From d33ba07adffa27784a99eec2ddaecb87d14dc8b4 Mon Sep 17 00:00:00 2001 From: CalCraven <54594941+CalCraven@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:35:24 -0600 Subject: [PATCH] fix label handling during flatten (#1208) * fix label handling during flatten * Change the reset_labels method for compound.py to label the container lists with the format 'all-{name}s' for clarity * Add Ruff to pre-commit hooks (#1207) * update CI and precommit files * add ruff changes * remove gmso lines * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change error type in test * raise the error that is created * remove duplicate windows 3.12 test * fix precommit errors * fix import error * fix CI error --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix labeling of windows compounds * fix references in monomers tests --------- Co-authored-by: Chris Jones <50423140+chrisjonesBSU@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- mbuild/compound.py | 41 ++++++++++---- mbuild/tests/test_compound.py | 92 ++++++++++++++++--------------- mbuild/tests/test_json_formats.py | 2 +- 3 files changed, 79 insertions(+), 56 deletions(-) diff --git a/mbuild/compound.py b/mbuild/compound.py index e8966f46b..ec61e83b5 100644 --- a/mbuild/compound.py +++ b/mbuild/compound.py @@ -699,12 +699,13 @@ def add( if label.endswith("[$]"): label = label[:-3] - if label not in self.labels: - self.labels[label] = [] + all_label = "all-" + label + "s" + if all_label not in self.labels: + self.labels[all_label] = [] label_pattern = label + "[{}]" - count = len(self.labels[label]) - self.labels[label].append(new_child) + count = len(self.labels[all_label]) + self.labels[all_label].append(new_child) label = label_pattern.format(count) 
if not replace and label in self.labels: @@ -825,7 +826,21 @@ def _check_if_empty(child): self.reset_labels() def reset_labels(self): - """Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports.""" + """Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N+1 is the number of ports. + + Notes + ----- + Will renumber the labels in a given Compound. Duplicated labels are named in the format "{name}[$]", where the $ stands in for the 0-indexed + number in the Compound hierarchy with given "name". + + i.e. self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] + and + i.e. self.labels.keys() = ["CH2[1]", "CH2[3]", "CH2[5]"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] + + Additionally, if it doesn't exist, duplicated labels that are numbered as above with the "[$]" will also be put into a list index. + self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] as shown above, but also + have a label of self.labels["all-CH2s"], which is a list of all CH2 children in the Compound. 
+ """ new_labels = OrderedDict() hoisted_children = { key: val @@ -856,16 +871,16 @@ def reset_labels(self): if "port" in label: label = "port[$]" else: - label = "{0}[$]".format(child.name) - + label = f"{child.name}[$]" if label.endswith("[$]"): label = label[:-3] - if label not in new_labels: - new_labels[label] = [] + all_label = "all-" + label + "s" + if all_label not in new_labels: + new_labels[all_label] = [] label_pattern = label + "[{}]" - count = len(new_labels[label]) - new_labels[label].append(child) + count = len(new_labels[all_label]) + new_labels[all_label].append(child) label = label_pattern.format(count) new_labels[label] = child self.labels = new_labels @@ -1880,6 +1895,9 @@ def flatten(self, inplace=True): for neighbor in nx.neighbors(bond_graph, particle): new_bonds.append((particle, neighbor)) + # Remove all labels which refer to children in the hierarchy + self.labels.clear() + # Remove all the children if inplace: for child in children_list: @@ -1896,6 +1914,7 @@ def flatten(self, inplace=True): comp = clone(self) comp.flatten(inplace=True) return comp + self.reset_labels() def update_coordinates(self, filename, update_port_locations=True): """Update the coordinates of this Compound from a file. 
diff --git a/mbuild/tests/test_compound.py b/mbuild/tests/test_compound.py index bb8761187..5345c49a9 100644 --- a/mbuild/tests/test_compound.py +++ b/mbuild/tests/test_compound.py @@ -604,7 +604,7 @@ def test_add_by_list(self, h2o): temp_comp.add(comp_list, label=label_list) a = [k for k, v in temp_comp.labels.items()] assert a == [ - "water", + "all-waters", "water[0]", "water[1]", "water[2]", @@ -783,42 +783,14 @@ def test_remove(self, ethane): # Test to reset labels after hydrogens ethane6 = mb.clone(ethane) - ethane6.flatten() hydrogens = ethane6.particles_by_name("H") - ethane6.remove(hydrogens) + ethane6.remove(hydrogens, reset_labels=True) assert list(ethane6.labels.keys()) == [ "methyl1", "methyl2", - "C", - "C[0]", - "H", - "C[1]", - "port", - "port[1]", - "port[3]", - "port[5]", - "port[7]", - "port[9]", - "port[11]", - ] - - ethane7 = mb.clone(ethane) - ethane7.flatten() - hydrogens = ethane7.particles_by_name("H") - ethane7.remove(hydrogens, reset_labels=True) - - assert list(ethane7.labels.keys()) == [ - "C", - "C[0]", - "C[1]", - "port", - "port[0]", - "port[1]", - "port[2]", - "port[3]", - "port[4]", - "port[5]", ] + assert ethane6.available_ports() == [] + assert len(ethane6.all_ports()) == 6 def test_remove_many(self, ethane): ethane.remove([ethane.children[0], ethane.children[1]]) @@ -1041,6 +1013,31 @@ def test_flatten_box_of_eth(self, ethane): box_of_eth.flatten() assert len(box_of_eth.children) == box_of_eth.n_particles == 8 * 2 assert box_of_eth.n_bonds == 7 * 2 + assert list(box_of_eth.labels.keys()) == [ + "all-Cs", + "C[0]", + "all-Hs", + "H[0]", + "H[1]", + "H[2]", + "C[1]", + "H[3]", + "H[4]", + "H[5]", + "C[2]", + "H[6]", + "H[7]", + "H[8]", + "C[3]", + "H[9]", + "H[10]", + "H[11]", + ] + + def test_flatten_then_fill_box(self, benzene): + benzene.flatten(inplace=True) + benzene_box = mb.packing.fill_box(compound=benzene, n_compounds=2, density=0.3) + assert next(iter(benzene_box.particles())).root.bond_graph def 
test_flatten_with_port(self, ethane): ethane.remove(ethane[2]) @@ -1726,7 +1723,7 @@ def test_energy_minimize_shift_com(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def test_energy_minimize_shift_anchor(self, octane): - anchor_compound = octane.labels["chain"].labels["CH3"][0] + anchor_compound = octane.labels["chain"].labels["CH3[0]"] pos_old = anchor_compound.pos octane.energy_minimize(anchor=anchor_compound) # check to see if COM of the anchor Compound @@ -1738,9 +1735,9 @@ def test_energy_minimize_shift_anchor(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def test_energy_minimize_fix_compounds(self, octane): - methyl_end0 = octane.labels["chain"].labels["CH3"][0] - methyl_end1 = octane.labels["chain"].labels["CH3"][1] - carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0] + methyl_end0 = octane.labels["chain"].labels["CH3[0]"] + methyl_end1 = octane.labels["chain"].labels["CH3[0]"] + carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"] not_in_compound = mb.Compound(name="H") # fix the whole molecule and make sure positions are close @@ -1827,9 +1824,9 @@ def test_energy_minimize_fix_compounds(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def test_energy_minimize_ignore_compounds(self, octane): - methyl_end0 = octane.labels["chain"].labels["CH3"][0] - methyl_end1 = octane.labels["chain"].labels["CH3"][1] - carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0] + methyl_end0 = octane.labels["chain"].labels["CH3[0]"] + methyl_end1 = octane.labels["chain"].labels["CH3[1]"] + carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"] not_in_compound = mb.Compound(name="H") # fix the whole molecule and make sure positions are close @@ -1859,12 +1856,12 @@ def test_energy_minimize_ignore_compounds(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def 
test_energy_minimize_distance_constraints(self, octane): - methyl_end0 = octane.labels["chain"].labels["CH3"][0] - methyl_end1 = octane.labels["chain"].labels["CH3"][1] + methyl_end0 = octane.labels["chain"].labels["CH3[0]"] + methyl_end1 = octane.labels["chain"].labels["CH3[1]"] - carbon_end0 = octane.labels["chain"].labels["CH3"][0].labels["C"][0] - carbon_end1 = octane.labels["chain"].labels["CH3"][1].labels["C"][0] - h_end0 = octane.labels["chain"].labels["CH3"][0].labels["H"][0] + carbon_end0 = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"] + carbon_end1 = octane.labels["chain"].labels["CH3[1]"].labels["C[0]"] + h_end0 = octane.labels["chain"].labels["CH3[0]"].labels["H[0]"] not_in_compound = mb.Compound(name="H") @@ -2539,3 +2536,10 @@ def test_catalog_bondgraph_types(self, benzene): catalog_bondgraph_type(compound.children[1][0], compound.bond_graph) == "particle_graph" ) + + def test_reset_labels(self): + ethane = mb.load("CC", smiles=True) + Hs = ethane.particles_by_name("H") + ethane.remove(Hs, reset_labels=True) + ports = set(f"port[{i}]" for i in range(6)) + assert ports.issubset(set(ethane.labels.keys())) diff --git a/mbuild/tests/test_json_formats.py b/mbuild/tests/test_json_formats.py index edcfe9604..523ea65f5 100644 --- a/mbuild/tests/test_json_formats.py +++ b/mbuild/tests/test_json_formats.py @@ -99,7 +99,7 @@ def test_label_consistency(self): parent.add(CH3()) compound_to_json(parent, "parent.json", include_ports=True) parent_copy = compound_from_json("parent.json") - assert len(parent_copy["CH2"]) == len(parent["CH2"]) + assert len(parent_copy["all-CH2s"]) == len(parent["all-CH2s"]) assert parent_copy.labels.keys() == parent.labels.keys() for child, child_copy in zip(parent.successors(), parent_copy.successors()): assert child.labels.keys() == child_copy.labels.keys()