From f8a1b6b20eb9ac7f1aa6d20ee806a59e6fb00c11 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Sat, 31 Aug 2024 18:33:46 +0800 Subject: [PATCH 1/2] test: support comparison between two multi systems (#705) ## Summary by CodeRabbit - **New Features** - Introduced functions and classes to enhance testing capabilities for multi-system comparisons. - Added validation classes for periodic boundary conditions across multi-system objects. - **Bug Fixes** - Updated test classes to utilize new multi-system handling, improving clarity and functionality. - **Documentation** - Enhanced clarity in variable naming for better alignment with multi-system concepts. --------- Co-authored-by: Han Wang --- tests/comp_sys.py | 84 ++++++++++++++++++++++++++++++++++++++ tests/test_deepmd_mixed.py | 26 +++++++----- 2 files changed, 100 insertions(+), 10 deletions(-) diff --git a/tests/comp_sys.py b/tests/comp_sys.py index 99879af6..a3a916a0 100644 --- a/tests/comp_sys.py +++ b/tests/comp_sys.py @@ -105,6 +105,72 @@ def test_virial(self): ) +def _make_comp_ms_test_func(comp_sys_test_func): + """ + Dynamically generates a test function for multi-system comparisons. + + Args: + comp_sys_test_func (Callable): The original test function for single systems. + + Returns + ------- + Callable: A new test function that can handle comparisons between multi-systems. 
+ """ + + def comp_ms_test_func(iobj): + assert hasattr(iobj, "ms_1") and hasattr( + iobj, "ms_2" + ), "Multi-system objects must be present" + iobj.assertEqual(len(iobj.ms_1), len(iobj.ms_2)) + keys = [ii.formula for ii in iobj.ms_1] + keys_2 = [ii.formula for ii in iobj.ms_2] + assert sorted(keys) == sorted( + keys_2 + ), f"Keys of two MS are not equal: {keys} != {keys_2}" + for kk in keys: + iobj.system_1 = iobj.ms_1[kk] + iobj.system_2 = iobj.ms_2[kk] + comp_sys_test_func(iobj) + del iobj.system_1 + del iobj.system_2 + + return comp_ms_test_func + + +def _make_comp_ms_class(comp_class): + """ + Dynamically generates a test class for multi-system comparisons. + + Args: + comp_class (type): The original test class for single systems. + + Returns + ------- + type: A new test class that can handle comparisons between multi-systems. + """ + + class CompMS: + pass + + test_methods = [ + func + for func in dir(comp_class) + if callable(getattr(comp_class, func)) and func.startswith("test_") + ] + + for func in test_methods: + setattr(CompMS, func, _make_comp_ms_test_func(getattr(comp_class, func))) + + return CompMS + + +# MultiSystems comparison from single System comparison +CompMultiSys = _make_comp_ms_class(CompSys) + +# LabeledMultiSystems comparison from single LabeledSystem comparison +CompLabeledMultiSys = _make_comp_ms_class(CompLabeledSys) + + class MultiSystems: def test_systems_name(self): self.assertEqual(set(self.systems.systems), set(self.system_names)) @@ -127,3 +193,21 @@ class IsNoPBC: def test_is_nopbc(self): self.assertTrue(self.system_1.nopbc) self.assertTrue(self.system_2.nopbc) + + +class MSAllIsPBC: + def test_is_pbc(self): + assert hasattr(self, "ms_1") and hasattr( + self, "ms_2" + ), "Multi-system objects must be present and iterable" + self.assertTrue(all([not ss.nopbc for ss in self.ms_1])) + self.assertTrue(all([not ss.nopbc for ss in self.ms_2])) + + +class MSAllIsNoPBC: + def test_is_nopbc(self): + assert hasattr(self, "ms_1") and 
hasattr( + self, "ms_2" + ), "Multi-system objects must be present and iterable" + self.assertTrue(all([ss.nopbc for ss in self.ms_1])) + self.assertTrue(all([ss.nopbc for ss in self.ms_2])) diff --git a/tests/test_deepmd_mixed.py b/tests/test_deepmd_mixed.py index 70a09dbe..3c28d1b1 100644 --- a/tests/test_deepmd_mixed.py +++ b/tests/test_deepmd_mixed.py @@ -6,12 +6,18 @@ from glob import glob import numpy as np -from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems +from comp_sys import ( + CompLabeledMultiSys, + CompLabeledSys, + IsNoPBC, + MSAllIsNoPBC, + MultiSystems, +) from context import dpdata class TestMixedMultiSystemsDumpLoad( - unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC ): def setUp(self): self.places = 6 @@ -62,8 +68,8 @@ def setUp(self): self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy") self.systems = dpdata.MultiSystems() self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed") - self.system_1 = self.ms["C1H4A0B0D0"] - self.system_2 = self.systems["C1H4A0B0D0"] + self.ms_1 = self.ms + self.ms_2 = self.systems mixed_sets = glob("tmp.deepmd.mixed/*/set.*") self.assertEqual(len(mixed_sets), 2) for i in mixed_sets: @@ -112,7 +118,7 @@ def test_str(self): class TestMixedMultiSystemsDumpLoadSetSize( - unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC ): def setUp(self): self.places = 6 @@ -163,8 +169,8 @@ def setUp(self): self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy") self.systems = dpdata.MultiSystems() self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed") - self.system_1 = self.ms["C1H4A0B0D0"] - self.system_2 = self.systems["C1H4A0B0D0"] + self.ms_1 = self.ms + self.ms_2 = self.systems mixed_sets = glob("tmp.deepmd.mixed/*/set.*") self.assertEqual(len(mixed_sets), 5) for i in mixed_sets: @@ 
-213,7 +219,7 @@ def test_str(self): class TestMixedMultiSystemsTypeChange( - unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC ): def setUp(self): self.places = 6 @@ -265,8 +271,8 @@ def setUp(self): self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy") self.systems = dpdata.MultiSystems(type_map=["TOKEN"]) self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed") - self.system_1 = self.ms["TOKEN0C1H4A0B0D0"] - self.system_2 = self.systems["TOKEN0C1H4A0B0D0"] + self.ms_1 = self.ms + self.ms_2 = self.systems mixed_sets = glob("tmp.deepmd.mixed/*/set.*") self.assertEqual(len(mixed_sets), 2) for i in mixed_sets: From 6d082f162dc05c0980567a7544a8024b7c024a78 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Sat, 31 Aug 2024 18:39:10 +0800 Subject: [PATCH 2/2] test: mixed data format: test if the index_map (when type_map is provided) works (#706) consider the PR after #705 ## Summary by CodeRabbit - **New Features** - Enhanced testing framework for multi-system comparisons, allowing for dynamic generation of test functions and classes. - Introduced new test classes to validate properties of multi-system objects regarding periodic boundary conditions. - Added a new test class for handling type mapping in labeled systems. - **Bug Fixes** - Updated existing test classes to improve clarity and consistency in naming conventions for multi-system variables. 
--------- Co-authored-by: Han Wang --- tests/test_deepmd_mixed.py | 113 +++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/tests/test_deepmd_mixed.py b/tests/test_deepmd_mixed.py index 3c28d1b1..9f0c5ed4 100644 --- a/tests/test_deepmd_mixed.py +++ b/tests/test_deepmd_mixed.py @@ -117,6 +117,119 @@ def test_str(self): ) +class TestMixedMultiSystemsDumpLoadTypeMap( + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC +): + def setUp(self): + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 6 + + # C1H4 + system_1 = dpdata.LabeledSystem( + "gaussian/methane.gaussianlog", fmt="gaussian/log" + ) + + # C1H3 + system_2 = dpdata.LabeledSystem( + "gaussian/methane_sub.gaussianlog", fmt="gaussian/log" + ) + + tmp_data = system_1.data.copy() + tmp_data["atom_numbs"] = [1, 1, 1, 2] + tmp_data["atom_names"] = ["C", "H", "A", "B"] + tmp_data["atom_types"] = np.array([0, 1, 2, 3, 3]) + # C1H1A1B2 + system_1_modified_type_1 = dpdata.LabeledSystem(data=tmp_data) + + tmp_data = system_1.data.copy() + tmp_data["atom_numbs"] = [1, 1, 2, 1] + tmp_data["atom_names"] = ["C", "H", "A", "B"] + tmp_data["atom_types"] = np.array([0, 1, 2, 2, 3]) + # C1H1A2B1 + system_1_modified_type_2 = dpdata.LabeledSystem(data=tmp_data) + + tmp_data = system_1.data.copy() + tmp_data["atom_numbs"] = [1, 1, 1, 2] + tmp_data["atom_names"] = ["C", "H", "A", "D"] + tmp_data["atom_types"] = np.array([0, 1, 2, 3, 3]) + # C1H1A1D2 + system_1_modified_type_3 = dpdata.LabeledSystem(data=tmp_data) + + self.ms = dpdata.MultiSystems( + system_1, + system_2, + system_1_modified_type_1, + system_1_modified_type_2, + system_1_modified_type_3, + ) + + self.ms.to_deepmd_npy_mixed("tmp.deepmd.mixed") + self.place_holder_ms = dpdata.MultiSystems() + self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy") + + new_type_map = ["H", "C", "D", "A", "B"] + self.systems = dpdata.MultiSystems() + self.systems.from_deepmd_npy_mixed( + 
"tmp.deepmd.mixed", fmt="deepmd/npy/mixed", type_map=new_type_map + ) + for kk in [ii.formula for ii in self.ms]: + # apply type_map to each system + self.ms[kk].apply_type_map(new_type_map) + # revise keys in dict accordingly because the type_map is updated. + tmp_ss = self.ms.systems.pop(kk) + self.ms.systems[tmp_ss.formula] = tmp_ss + + self.ms_1 = self.ms + self.ms_2 = self.systems + mixed_sets = glob("tmp.deepmd.mixed/*/set.*") + self.assertEqual(len(mixed_sets), 2) + for i in mixed_sets: + self.assertEqual( + os.path.exists(os.path.join(i, "real_atom_types.npy")), True + ) + + self.system_names = [ + "H4C1D0A0B0", + "H3C1D0A0B0", + "H1C1D0A1B2", + "H1C1D0A2B1", + "H1C1D2A1B0", + ] + self.system_sizes = { + "H4C1D0A0B0": 1, + "H3C1D0A0B0": 1, + "H1C1D0A1B2": 1, + "H1C1D0A2B1": 1, + "H1C1D2A1B0": 1, + } + self.atom_names = ["H", "C", "D", "A", "B"] + + def tearDown(self): + if os.path.exists("tmp.deepmd.mixed"): + shutil.rmtree("tmp.deepmd.mixed") + + def test_len(self): + self.assertEqual(len(self.ms), 5) + self.assertEqual(len(self.place_holder_ms), 2) + self.assertEqual(len(self.systems), 5) + + def test_get_nframes(self): + self.assertEqual(self.ms.get_nframes(), 5) + self.assertEqual(self.place_holder_ms.get_nframes(), 5) + self.assertEqual(self.systems.get_nframes(), 5) + + def test_str(self): + self.assertEqual(str(self.ms), "MultiSystems (5 systems containing 5 frames)") + self.assertEqual( + str(self.place_holder_ms), "MultiSystems (2 systems containing 5 frames)" + ) + self.assertEqual( + str(self.systems), "MultiSystems (5 systems containing 5 frames)" + ) + + class TestMixedMultiSystemsDumpLoadSetSize( unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC ):