Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix label handling during flatten #1208

Merged
merged 8 commits into from
Dec 6, 2024
Merged
41 changes: 30 additions & 11 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,12 +699,13 @@ def add(

if label.endswith("[$]"):
label = label[:-3]
if label not in self.labels:
self.labels[label] = []
all_label = "all-" + label + "s"
if all_label not in self.labels:
self.labels[all_label] = []
label_pattern = label + "[{}]"

count = len(self.labels[label])
self.labels[label].append(new_child)
count = len(self.labels[all_label])
self.labels[all_label].append(new_child)
label = label_pattern.format(count)

if not replace and label in self.labels:
Expand Down Expand Up @@ -825,7 +826,21 @@ def _check_if_empty(child):
self.reset_labels()

def reset_labels(self):
"""Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports."""
"""Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports.

Notes
-----
Will renumber the labels in a given Compound. Duplicated labels are named in the format "{name}[$]", where the $ stands in for the 0-indexed
number in the Compound hierarchy with given "name".

i.e. self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"]
and
i.e. self.labels.keys() = ["CH2[1]", "CH2[3]", "CH2[5]"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"]

Additonally, if it doesn't exist, duplicated labels that are numbered as above with the "[$]" will also be put into a list index.
self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] as shown above, but also
have a label of self.labels["all-CH2s"], which is a list of all CH2 children in the Compound.
"""
new_labels = OrderedDict()
hoisted_children = {
key: val
Expand Down Expand Up @@ -856,16 +871,16 @@ def reset_labels(self):
if "port" in label:
label = "port[$]"
else:
label = "{0}[$]".format(child.name)

label = f"{child.name}[$]"
if label.endswith("[$]"):
label = label[:-3]
if label not in new_labels:
new_labels[label] = []
all_label = "all-" + label + "s"
if all_label not in new_labels:
new_labels[all_label] = []
label_pattern = label + "[{}]"

count = len(new_labels[label])
new_labels[label].append(child)
count = len(new_labels[all_label])
new_labels[all_label].append(child)
label = label_pattern.format(count)
new_labels[label] = child
self.labels = new_labels
Expand Down Expand Up @@ -1880,6 +1895,9 @@ def flatten(self, inplace=True):
for neighbor in nx.neighbors(bond_graph, particle):
new_bonds.append((particle, neighbor))

# Remove all labels which refer to children in the hierarchy
self.labels.clear()

# Remove all the children
if inplace:
for child in children_list:
Expand All @@ -1896,6 +1914,7 @@ def flatten(self, inplace=True):
comp = clone(self)
comp.flatten(inplace=True)
return comp
self.reset_labels()

def update_coordinates(self, filename, update_port_locations=True):
"""Update the coordinates of this Compound from a file.
Expand Down
92 changes: 48 additions & 44 deletions mbuild/tests/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def test_add_by_list(self, h2o):
temp_comp.add(comp_list, label=label_list)
a = [k for k, v in temp_comp.labels.items()]
assert a == [
"water",
"all-waters",
"water[0]",
"water[1]",
"water[2]",
Expand Down Expand Up @@ -783,42 +783,14 @@ def test_remove(self, ethane):

# Test to reset labels after hydrogens
ethane6 = mb.clone(ethane)
ethane6.flatten()
hydrogens = ethane6.particles_by_name("H")
ethane6.remove(hydrogens)
ethane6.remove(hydrogens, reset_labels=True)
assert list(ethane6.labels.keys()) == [
"methyl1",
"methyl2",
"C",
"C[0]",
"H",
"C[1]",
"port",
"port[1]",
"port[3]",
"port[5]",
"port[7]",
"port[9]",
"port[11]",
]

ethane7 = mb.clone(ethane)
ethane7.flatten()
hydrogens = ethane7.particles_by_name("H")
ethane7.remove(hydrogens, reset_labels=True)

assert list(ethane7.labels.keys()) == [
"C",
"C[0]",
"C[1]",
"port",
"port[0]",
"port[1]",
"port[2]",
"port[3]",
"port[4]",
"port[5]",
]
assert ethane6.available_ports() == []
assert len(ethane6.all_ports()) == 6

def test_remove_many(self, ethane):
ethane.remove([ethane.children[0], ethane.children[1]])
Expand Down Expand Up @@ -1041,6 +1013,31 @@ def test_flatten_box_of_eth(self, ethane):
box_of_eth.flatten()
assert len(box_of_eth.children) == box_of_eth.n_particles == 8 * 2
assert box_of_eth.n_bonds == 7 * 2
assert list(box_of_eth.labels.keys()) == [
"all-Cs",
"C[0]",
"all-Hs",
"H[0]",
"H[1]",
"H[2]",
"C[1]",
"H[3]",
"H[4]",
"H[5]",
"C[2]",
"H[6]",
"H[7]",
"H[8]",
"C[3]",
"H[9]",
"H[10]",
"H[11]",
]

def test_flatten_then_fill_box(self, benzene):
benzene.flatten(inplace=True)
benzene_box = mb.packing.fill_box(compound=benzene, n_compounds=2, density=0.3)
assert next(iter(benzene_box.particles())).root.bond_graph

def test_flatten_with_port(self, ethane):
ethane.remove(ethane[2])
Expand Down Expand Up @@ -1726,7 +1723,7 @@ def test_energy_minimize_shift_com(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_shift_anchor(self, octane):
anchor_compound = octane.labels["chain"].labels["CH3"][0]
anchor_compound = octane.labels["chain"].labels["CH3[0]"]
pos_old = anchor_compound.pos
octane.energy_minimize(anchor=anchor_compound)
# check to see if COM of the anchor Compound
Expand All @@ -1738,9 +1735,9 @@ def test_energy_minimize_shift_anchor(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_fix_compounds(self, octane):
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
methyl_end1 = octane.labels["chain"].labels["CH3[0]"]
carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
not_in_compound = mb.Compound(name="H")

# fix the whole molecule and make sure positions are close
Expand Down Expand Up @@ -1827,9 +1824,9 @@ def test_energy_minimize_fix_compounds(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_ignore_compounds(self, octane):
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
methyl_end1 = octane.labels["chain"].labels["CH3[1]"]
carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
not_in_compound = mb.Compound(name="H")

# fix the whole molecule and make sure positions are close
Expand Down Expand Up @@ -1859,12 +1856,12 @@ def test_energy_minimize_ignore_compounds(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_distance_constraints(self, octane):
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
methyl_end1 = octane.labels["chain"].labels["CH3[1]"]

carbon_end0 = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
carbon_end1 = octane.labels["chain"].labels["CH3"][1].labels["C"][0]
h_end0 = octane.labels["chain"].labels["CH3"][0].labels["H"][0]
carbon_end0 = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
carbon_end1 = octane.labels["chain"].labels["CH3[1]"].labels["C[0]"]
h_end0 = octane.labels["chain"].labels["CH3[0]"].labels["H[0]"]

not_in_compound = mb.Compound(name="H")

Expand Down Expand Up @@ -2539,3 +2536,10 @@ def test_catalog_bondgraph_types(self, benzene):
catalog_bondgraph_type(compound.children[1][0], compound.bond_graph)
== "particle_graph"
)

def test_reset_labels(self):
ethane = mb.load("CC", smiles=True)
Hs = ethane.particles_by_name("H")
ethane.remove(Hs, reset_labels=True)
ports = set(f"port[{i}]" for i in range(6))
assert ports.issubset(set(ethane.labels.keys()))
2 changes: 1 addition & 1 deletion mbuild/tests/test_json_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_label_consistency(self):
parent.add(CH3())
compound_to_json(parent, "parent.json", include_ports=True)
parent_copy = compound_from_json("parent.json")
assert len(parent_copy["CH2"]) == len(parent["CH2"])
assert len(parent_copy["all-CH2s"]) == len(parent["all-CH2s"])
assert parent_copy.labels.keys() == parent.labels.keys()
for child, child_copy in zip(parent.successors(), parent_copy.successors()):
assert child.labels.keys() == child_copy.labels.keys()
Expand Down
Loading