Skip to content

Commit

Permalink
Merge pull request #253 from UC-Davis-molecular-computing/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
dave-doty authored Feb 19, 2024
2 parents bd65b5c + fb9793c commit 7bc00bb
Show file tree
Hide file tree
Showing 11 changed files with 215 additions and 129 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
.idea

output

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ In more detail, there are five main types of objects you create to describe your

- `DomainConstraint`: This only looks at a single `Domain`. In practice this is not used much, since there's not much information in a `Domain` other than its DNA sequence, so a `SequenceConstraint` or `NumpyConstraint` typically would already have filtered out any DNA sequence not satisfying such a constraint.

- `StrandConstraint`: This evaluates a whole `Strand`. A common example is that NUPACK's `pfunc` should indicate a complex free energy above a certain threshold, indicating the `Strand` has little secondary structure. This example constraint is available in the library by calling [nupack_strand_complex_free_energy_constraint](https://nuad.readthedocs.io/en/latest/#constraints.nupack_strand_complex_free_energy_constraint).
- `StrandConstraint`: This evaluates a whole `Strand`. A common example is that NUPACK's `pfunc` should indicate a complex free energy above a certain threshold, indicating the `Strand` has little secondary structure. This example constraint is available in the library by calling [nupack_strand_free_energy_constraint](https://nuad.readthedocs.io/en/latest/#constraints.nupack_strand_free_energy_constraint).

- `DomainPairConstraint`: This evaluates a pair of `Domain`'s.

Expand Down
94 changes: 40 additions & 54 deletions examples/sst_canvas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from dataclasses import dataclass
from typing import Optional
import argparse
Expand All @@ -16,16 +15,19 @@ def main() -> None:
args: CLArgs = parse_command_line_arguments()

design = create_design(width=args.width, height=args.height)
thresholds = Thresholds()
constraints = create_constraints(design, thresholds)

constraints = create_constraints(design)

params = ns.SearchParameters(
constraints=constraints,
out_directory=args.directory,
restart=args.restart,
random_seed=args.seed,
log_time=True,
scrolling_output=False,
save_report_for_all_updates=True,
force_overwrite=args.force_overwrite,
# log_time=True,
)

ns.search_for_sequences(design, params)


Expand All @@ -47,6 +49,9 @@ class CLArgs:
seed: Optional[int] = None
"""seed for random number generator; set to fixed integer for reproducibility"""

force_overwrite: bool = False
"""whether to overwrite output files without prompting the user"""


def parse_command_line_arguments() -> CLArgs:
default_directory = os.path.join('output', ns.script_name_no_ext())
Expand Down Expand Up @@ -74,13 +79,19 @@ def parse_command_line_arguments() -> CLArgs:
'numbering from there (i.e., the next files to be written upon improving the '
'design will have index 85).')

parser.add_argument('-f', '--force', action='store_true',
help='If true, then overwrites the output files without prompting the user.')

args = parser.parse_args()

return CLArgs(directory=args.output_dir,
width=args.width,
height=args.height,
seed=args.seed,
restart=args.restart)
return CLArgs(
directory=args.output_dir,
width=args.width,
height=args.height,
seed=args.seed,
restart=args.restart,
force_overwrite=args.force,
)


def create_design(width: int, height: int) -> nc.Design:
Expand Down Expand Up @@ -205,54 +216,20 @@ class Thresholds:
"""RNAduplex complex free energy threshold for pairs tiles with 1 complementary domain."""


def create_constraints(design: nc.Design, thresholds: Thresholds) -> List[nc.Constraint]:
def create_constraints(design: nc.Design) -> List[nc.Constraint]:
thresholds = Thresholds()

strand_individual_ss_constraint = nc.nupack_strand_free_energy_constraint(
threshold=thresholds.tile_ss, temperature=thresholds.temperature, short_description='StrandSS')

# This reduces the number of times we have to create these sets from quadratic to linear
unstarred_domains_sets = {}
starred_domains_sets = {}
for strand in design.strands:
unstarred_domains_sets[strand.name] = strand.unstarred_domains_set()
starred_domains_sets[strand.name] = strand.starred_domains_set()

# determine which pairs of strands have 0 complementary domains and which have 1
# so we can set different RNAduplex energy constraints for each of them
strand_pairs_0_comp = []
strand_pairs_1_comp = []
for strand1, strand2 in itertools.combinations_with_replacement(design.strands, 2):
domains1_unstarred = unstarred_domains_sets[strand1.name]
domains2_unstarred = unstarred_domains_sets[strand2.name]
domains1_starred = starred_domains_sets[strand1.name]
domains2_starred = starred_domains_sets[strand2.name]

complementary_domains = (domains1_unstarred & domains2_starred) | \
(domains2_unstarred & domains1_starred)
complementary_domain_names = [domain.name for domain in complementary_domains]
num_complementary_domains = len(complementary_domain_names)

if num_complementary_domains == 0:
strand_pairs_0_comp.append((strand1, strand2))
elif num_complementary_domains == 1:
strand_pairs_1_comp.append((strand1, strand2))
else:
raise AssertionError('each pair of strands should have exactly 0 or 1 complementary domains')

strand_pairs_rna_duplex_constraint_0comp = nc.rna_duplex_strand_pairs_constraint(
threshold=thresholds.tile_pair_0comp, temperature=thresholds.temperature,
short_description='StrandPairRNA0Comp', pairs=strand_pairs_0_comp)
strand_pairs_rna_duplex_constraint_1comp = nc.rna_duplex_strand_pairs_constraint(
threshold=thresholds.tile_pair_1comp, temperature=thresholds.temperature,
short_description='StrandPairRNA1Comp', pairs=strand_pairs_1_comp)
strand_pairs_rna_duplex_constraint_0comp, strand_pairs_rna_duplex_constraint_1comp = \
nc.rna_duplex_strand_pairs_constraints_by_number_matching_domains(
thresholds={0: thresholds.tile_pair_0comp, 1: thresholds.tile_pair_1comp},
temperature=thresholds.temperature,
short_descriptions={0: 'StrandPairRNA0Comp', 1: 'StrandPairRNA1Comp'},
strands=design.strands,
)

# We already forbid GGGG in any domain, but let's also ensure we don't get GGGG in any strand
# i.e., forbid GGGG that comes from concatenating domains, e.g.,
#
# * ***
# ACGATCGATG GGGATGCATGA
# +==========--===========>
# |
# +==========--===========]
no_gggg_constraint = create_tile_no_gggg_constraint(weight=100)

return [
Expand All @@ -268,6 +245,15 @@ def create_tile_no_gggg_constraint(weight: float) -> nc.StrandConstraint:
# sufficient. See also source code of provided constraints in dsd/constraints.py for more examples,
# particularly for examples that call NUPACK or ViennaRNA.

# We already forbid GGGG in any domain, but let's also ensure we don't get GGGG in any strand
# i.e., forbid GGGG that comes from concatenating domains, e.g.,
#
# * ***
# ACGATCGATG GGGATGCATGA
# +==========--===========>
# |
# +==========--===========]

def evaluate(seqs: Tuple[str, ...], strand: Optional[nc.Strand]) -> nc.Result: # noqa
sequence = seqs[0]
if 'GGGG' in sequence:
Expand Down
63 changes: 58 additions & 5 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions notebooks/nuad_parallel_time_trials.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,35 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"id": "ed9e9749-2d37-4bbc-9996-9621fbfe6efb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.61 s ± 366 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"2.04 s ± 45.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"293 ms ± 8.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"import nuad.vienna_nupack as nv\n",
"\n",
"def pfunc_all(seqs):\n",
" for s1, s2 in seqs:\n",
" p = nv.pfunc((s1,s2))\n",
"\n",
"length = 90\n",
"seqs = [(nv.random_dna_seq(length), nv.random_dna_seq(length)) for _ in range(500)]\n",
"%timeit pfunc_all(seqs)\n",
"%timeit nv.rna_duplex_multiple(seqs)\n",
"%timeit nv.rna_plex_multiple(seqs)"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down
2 changes: 1 addition & 1 deletion nuad/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = '0.4.6' # version line; WARNING: do not remove or change this line or comment
version = '0.4.7' # version line; WARNING: do not remove or change this line or comment
50 changes: 27 additions & 23 deletions nuad/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ class DomainPool(JSONSerializable):

replace_with_close_sequences: bool = True
"""
If True, instead of picking a sequence uniformly at random from all those satisfying the constraints
If True, instead of picking a sequence uniformly at random from all those satisfying the filters
when returning a sequence from :meth:`DomainPool.generate_sequence`,
one is picked "close" in Hamming distance to the previous sequence of the :any:`Domain`.
The field :data:`DomainPool.hamming_probability` is used to pick a distance at random, after which
Expand Down Expand Up @@ -1633,7 +1633,7 @@ def from_json_serializable(json_map: Dict[str, Any],
@property
def name(self) -> str:
"""
:return: :any:`DomainPool` of this :any:`Domain`
:return: name of this :any:`Domain`
"""
return self._name

Expand Down Expand Up @@ -2651,7 +2651,7 @@ def fixed(self) -> bool:
"""True if every :any:`Domain` on this :any:`Strand` has a fixed DNA sequence."""
return all(domain.fixed for domain in self.domains)

def unfixed_domains(self) -> Tuple[Domain]:
def unfixed_domains(self) -> Tuple[Domain, ...]:
"""
:return: all :any:`Domain`'s in this :any:`Strand` where :data:`Domain.fixed` is False
"""
Expand Down Expand Up @@ -4851,7 +4851,7 @@ def nupack_domain_free_energy_constraint(
"""
_check_nupack_installed()

def evaluate(seqs: Tuple[str], _: Domain | None) -> Result:
def evaluate(seqs: Tuple[str, ...], _: Domain | None) -> Result:
sequence = seqs[0]
energy = nv.free_energy_single_strand(sequence, temperature, sodium, magnesium)
excess = max(0.0, threshold - energy)
Expand Down Expand Up @@ -4919,7 +4919,7 @@ def nupack_strand_free_energy_constraint(
"""
_check_nupack_installed()

def evaluate(seqs: Tuple[str], _: Strand | None) -> Result:
def evaluate(seqs: Tuple[str, ...], _: Strand | None) -> Result:
sequence = seqs[0]
energy = nv.free_energy_single_strand(sequence, temperature, sodium, magnesium)
excess = max(0.0, threshold - energy)
Expand Down Expand Up @@ -5124,7 +5124,7 @@ def nupack_strand_pair_constraints_by_number_matching_domains(
if descriptions is None:
descriptions = {
num_matching: (_pair_default_description('strand', 'NUPACK', threshold, temperature) +
f'\nfor strands with {num_matching} complementary '
f' for strands with {num_matching} complementary '
f'{"domain" if num_matching == 1 else "domains"}')
for num_matching, threshold in thresholds.items()
}
Expand Down Expand Up @@ -5485,7 +5485,7 @@ def evaluate_bulk(domain_pairs: Iterable[DomainPair]) -> List[Result]:

def get_domain_pairs_from_thresholds_dict(
thresholds: Dict[Tuple[Domain, bool, Domain, bool] | Tuple[Domain, Domain], Tuple[float, float]]
) -> Tuple[DomainPair]:
) -> Tuple[DomainPair, ...]:
# gather pairs of domains referenced in `thresholds`
domain_pairs = []
for key, _ in thresholds.items():
Expand All @@ -5508,9 +5508,11 @@ def get_domain_pairs_from_thresholds_dict(
return domain_pairs


S = TypeVar('S', str, bytes, bytearray)

PairsEvaluationFunction = Callable[
[Sequence[Tuple[str, str]], logging.Logger, float, str, float],
Tuple[float]
[Sequence[Tuple[S, S]], logging.Logger, float, str, float],
Tuple[float, ...]
]


Expand Down Expand Up @@ -5701,6 +5703,9 @@ def rna_plex_domain_pairs_nonorthogonal_constraint(
:param parameters_filename:
name of parameters file for ViennaRNA; default is
same as :py:meth:`vienna_nupack.rna_duplex_multiple`
:param max_energy:
maximum energy to return; if the RNAplex returns a value larger than this, then
this value is used instead
:return:
constraint
"""
Expand Down Expand Up @@ -5889,7 +5894,7 @@ def __call__(self, *,
weight: float = 1.0,
score_transfer_function: Callable[[float], float] = default_score_transfer_function,
description: str | None = None,
short_description: str,
short_description: str = '',
parallel: bool = False,
pairs: Iterable[Tuple[Strand, Strand]] | None = None,
) -> SPC: ...
Expand Down Expand Up @@ -5955,7 +5960,7 @@ def _strand_pairs_constraints_by_number_matching_domains(
def _normalize_domains_pairs_disjoint_parameters(
domains: Iterable[Domain] | None,
pairs: Iterable[Tuple[Domain, Domain]],
check_domain_against_itself: bool) -> Iterable[Tuple[Domain, Domain]]:
check_domain_against_itself: bool) -> Tuple[Tuple[Domain, Domain], ...]:
# Enforce that exactly one of domains or pairs is not None, and if domains is specified,
# set pairs to be all pairs from domains. Return those pairs; if pairs is specified,
# just return it. Also normalize to return a tuple.
Expand Down Expand Up @@ -6020,7 +6025,7 @@ def rna_cofold_strand_pairs_constraints_by_number_matching_domains(
if descriptions is None:
descriptions = {
num_matching: (_pair_default_description('strand', 'RNAcofold', threshold, temperature) +
f'\nfor strands with {num_matching} complementary '
f' for strands with {num_matching} complementary '
f'{"domain" if num_matching == 1 else "domains"}')
for num_matching, threshold in thresholds.items()
}
Expand Down Expand Up @@ -6100,7 +6105,7 @@ def rna_duplex_strand_pairs_constraints_by_number_matching_domains(
if descriptions is None:
descriptions = {
num_matching: (_pair_default_description('strand', 'RNAduplex', threshold, temperature) +
f'\nfor strands with {num_matching} complementary '
f' for strands with {num_matching} complementary '
f'{"domain" if num_matching == 1 else "domains"}')
for num_matching, threshold in thresholds.items()
}
Expand Down Expand Up @@ -6187,7 +6192,7 @@ def rna_plex_strand_pairs_constraints_by_number_matching_domains(
if descriptions is None:
descriptions = {
num_matching: (_pair_default_description('strand', 'RNAplex', threshold, temperature) +
f'\nfor strands with {num_matching} complementary '
f' for strands with {num_matching} complementary '
f'{"domain" if num_matching == 1 else "domains"}')
for num_matching, threshold in thresholds.items()
}
Expand Down Expand Up @@ -6639,22 +6644,21 @@ def lcs_strand_pairs_constraint_with_dummy_parameters(
*,
threshold: float,
temperature: float = nv.default_temperature,
weight_: float = 1.0,
score_transfer_function_: Callable[[float], float] = default_score_transfer_function,
weight: float = 1.0,
score_transfer_function: Callable[[float], float] = default_score_transfer_function,
description: str | None = None,
short_description: str = 'lcs strand pairs',
parallel_: bool = False,
pairs_: Iterable[Tuple[Strand, Strand]] | None = None,
parameters_filename_: str = nv.default_vienna_rna_parameter_filename
parallel: bool = False,
pairs: Iterable[Tuple[Strand, Strand]] | None = None,
) -> StrandPairsConstraint:
threshold_int = int(threshold)
return lcs_strand_pairs_constraint(
threshold=threshold_int,
weight=weight_,
score_transfer_function=score_transfer_function_,
weight=weight,
score_transfer_function=score_transfer_function,
description=description,
short_description=short_description,
pairs=pairs_,
pairs=pairs,
check_strand_against_itself=True,
# TODO: rewrite signature of other strand pair constraints to include this
gc_double=gc_double,
Expand All @@ -6663,7 +6667,7 @@ def lcs_strand_pairs_constraint_with_dummy_parameters(
if descriptions is None:
descriptions = {
num_matching: (f'Longest complementary subsequence between strands is > {threshold}' +
f'\nfor strands with {num_matching} complementary '
f' for strands with {num_matching} complementary '
f'{"domain" if num_matching == 1 else "domains"}')
for num_matching, threshold in thresholds.items()
}
Expand Down
Loading

0 comments on commit 7bc00bb

Please sign in to comment.