diff --git a/mbuild/compound.py b/mbuild/compound.py index e8966f46b..ec61e83b5 100644 --- a/mbuild/compound.py +++ b/mbuild/compound.py @@ -699,12 +699,13 @@ def add( if label.endswith("[$]"): label = label[:-3] - if label not in self.labels: - self.labels[label] = [] + all_label = "all-" + label + "s" + if all_label not in self.labels: + self.labels[all_label] = [] label_pattern = label + "[{}]" - count = len(self.labels[label]) - self.labels[label].append(new_child) + count = len(self.labels[all_label]) + self.labels[all_label].append(new_child) label = label_pattern.format(count) if not replace and label in self.labels: @@ -825,7 +826,21 @@ def _check_if_empty(child): self.reset_labels() def reset_labels(self): - """Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports.""" + """Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports. + + Notes + ----- + Will renumber the labels in a given Compound. Duplicated labels are named in the format "{name}[$]", where the $ stands in for the 0-indexed + number in the Compound hierarchy with given "name". + + i.e. self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] + and + i.e. self.labels.keys() = ["CH2[1]", "CH2[3]", "CH2[5]"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] + + Additonally, if it doesn't exist, duplicated labels that are numbered as above with the "[$]" will also be put into a list index. + self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] as shown above, but also + have a label of self.labels["all-CH2s"], which is a list of all CH2 children in the Compound. + """ new_labels = OrderedDict() hoisted_children = { key: val @@ -856,16 +871,16 @@ def reset_labels(self): if "port" in label: label = "port[$]" else: - label = "{0}[$]".format(child.name) - + label = f"{child.name}[$]" if label.endswith("[$]"): label = label[:-3] - if label not in new_labels: - new_labels[label] = [] + all_label = "all-" + label + "s" + if all_label not in new_labels: + new_labels[all_label] = [] label_pattern = label + "[{}]" - count = len(new_labels[label]) - new_labels[label].append(child) + count = len(new_labels[all_label]) + new_labels[all_label].append(child) label = label_pattern.format(count) new_labels[label] = child self.labels = new_labels @@ -1880,6 +1895,9 @@ def flatten(self, inplace=True): for neighbor in nx.neighbors(bond_graph, particle): new_bonds.append((particle, neighbor)) + # Remove all labels which refer to children in the hierarchy + self.labels.clear() + # Remove all the children if inplace: for child in children_list: @@ -1896,6 +1914,7 @@ def flatten(self, inplace=True): comp = clone(self) comp.flatten(inplace=True) return comp + self.reset_labels() def update_coordinates(self, filename, update_port_locations=True): """Update the coordinates of this Compound from a file. diff --git a/mbuild/tests/test_compound.py b/mbuild/tests/test_compound.py index b3aa548bd..7d10b78ec 100644 --- a/mbuild/tests/test_compound.py +++ b/mbuild/tests/test_compound.py @@ -604,7 +604,7 @@ def test_add_by_list(self, h2o): temp_comp.add(comp_list, label=label_list) a = [k for k, v in temp_comp.labels.items()] assert a == [ - "water", + "all-waters", "water[0]", "water[1]", "water[2]", @@ -783,42 +783,14 @@ def test_remove(self, ethane): # Test to reset labels after hydrogens ethane6 = mb.clone(ethane) - ethane6.flatten() hydrogens = ethane6.particles_by_name("H") - ethane6.remove(hydrogens) + ethane6.remove(hydrogens, reset_labels=True) assert list(ethane6.labels.keys()) == [ "methyl1", "methyl2", - "C", - "C[0]", - "H", - "C[1]", - "port", - "port[1]", - "port[3]", - "port[5]", - "port[7]", - "port[9]", - "port[11]", - ] - - ethane7 = mb.clone(ethane) - ethane7.flatten() - hydrogens = ethane7.particles_by_name("H") - ethane7.remove(hydrogens, reset_labels=True) - - assert list(ethane7.labels.keys()) == [ - "C", - "C[0]", - "C[1]", - "port", - "port[0]", - "port[1]", - "port[2]", - "port[3]", - "port[4]", - "port[5]", ] + assert ethane6.available_ports() == [] + assert len(ethane6.all_ports()) == 6 def test_remove_many(self, ethane): ethane.remove([ethane.children[0], ethane.children[1]]) @@ -1041,6 +1013,31 @@ def test_flatten_box_of_eth(self, ethane): box_of_eth.flatten() assert len(box_of_eth.children) == box_of_eth.n_particles == 8 * 2 assert box_of_eth.n_bonds == 7 * 2 + assert list(box_of_eth.labels.keys()) == [ + "all-Cs", + "C[0]", + "all-Hs", + "H[0]", + "H[1]", + "H[2]", + "C[1]", + "H[3]", + "H[4]", + "H[5]", + "C[2]", + "H[6]", + "H[7]", + "H[8]", + "C[3]", + "H[9]", + "H[10]", + "H[11]", + ] + + def test_flatten_then_fill_box(self, benzene): + benzene.flatten(inplace=True) + benzene_box = mb.packing.fill_box(compound=benzene, n_compounds=2, density=0.3) + assert next(iter(benzene_box.particles())).root.bond_graph def test_flatten_with_port(self, ethane): ethane.remove(ethane[2]) @@ -1726,7 +1723,7 @@ def test_energy_minimize_shift_com(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def test_energy_minimize_shift_anchor(self, octane): - anchor_compound = octane.labels["chain"].labels["CH3"][0] + anchor_compound = octane.labels["chain"].labels["CH3[0]"] pos_old = anchor_compound.pos octane.energy_minimize(anchor=anchor_compound) # check to see if COM of the anchor Compound @@ -1738,9 +1735,9 @@ def test_energy_minimize_shift_anchor(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def test_energy_minimize_fix_compounds(self, octane): - methyl_end0 = octane.labels["chain"].labels["CH3"][0] - methyl_end1 = octane.labels["chain"].labels["CH3"][1] - carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0] + methyl_end0 = octane.labels["chain"].labels["CH3[0]"] + methyl_end1 = octane.labels["chain"].labels["CH3[0]"] + carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"] not_in_compound = mb.Compound(name="H") # fix the whole molecule and make sure positions are close @@ -1827,9 +1824,9 @@ def test_energy_minimize_fix_compounds(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def test_energy_minimize_ignore_compounds(self, octane): - methyl_end0 = octane.labels["chain"].labels["CH3"][0] - methyl_end1 = octane.labels["chain"].labels["CH3"][1] - carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0] + methyl_end0 = octane.labels["chain"].labels["CH3[0]"] + methyl_end1 = octane.labels["chain"].labels["CH3[1]"] + carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"] not_in_compound = mb.Compound(name="H") # fix the whole molecule and make sure positions are close @@ -1859,12 +1856,12 @@ def test_energy_minimize_ignore_compounds(self, octane): "win" in sys.platform, reason="Unknown issue with Window's Open Babel " ) def test_energy_minimize_distance_constraints(self, octane): - methyl_end0 = octane.labels["chain"].labels["CH3"][0] - methyl_end1 = octane.labels["chain"].labels["CH3"][1] + methyl_end0 = octane.labels["chain"].labels["CH3[0]"] + methyl_end1 = octane.labels["chain"].labels["CH3[1]"] - carbon_end0 = octane.labels["chain"].labels["CH3"][0].labels["C"][0] - carbon_end1 = octane.labels["chain"].labels["CH3"][1].labels["C"][0] - h_end0 = octane.labels["chain"].labels["CH3"][0].labels["H"][0] + carbon_end0 = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"] + carbon_end1 = octane.labels["chain"].labels["CH3[1]"].labels["C[0]"] + h_end0 = octane.labels["chain"].labels["CH3[0]"].labels["H[0]"] not_in_compound = mb.Compound(name="H") @@ -2553,3 +2550,10 @@ def test_load_large_smiles(self): smiles=True, ) assert cpd.n_particles == 244 + + def test_reset_labels(self): + ethane = mb.load("CC", smiles=True) + Hs = ethane.particles_by_name("H") + ethane.remove(Hs, reset_labels=True) + ports = set(f"port[{i}]" for i in range(6)) + assert ports.issubset(set(ethane.labels.keys())) diff --git a/mbuild/tests/test_json_formats.py b/mbuild/tests/test_json_formats.py index edcfe9604..523ea65f5 100644 --- a/mbuild/tests/test_json_formats.py +++ b/mbuild/tests/test_json_formats.py @@ -99,7 +99,7 @@ def test_label_consistency(self): parent.add(CH3()) compound_to_json(parent, "parent.json", include_ports=True) parent_copy = compound_from_json("parent.json") - assert len(parent_copy["CH2"]) == len(parent["CH2"]) + assert len(parent_copy["all-CH2s"]) == len(parent["all-CH2s"]) assert parent_copy.labels.keys() == parent.labels.keys() for child, child_copy in zip(parent.successors(), parent_copy.successors()): assert child.labels.keys() == child_copy.labels.keys()