From f8a1b6b20eb9ac7f1aa6d20ee806a59e6fb00c11 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Sat, 31 Aug 2024 18:33:46 +0800 Subject: [PATCH] test: support comparison between two multi systems (#705) ## Summary by CodeRabbit - **New Features** - Introduced functions and classes to enhance testing capabilities for multi-system comparisons. - Added validation classes for periodic boundary conditions across multi-system objects. - **Bug Fixes** - Updated test classes to utilize new multi-system handling, improving clarity and functionality. - **Documentation** - Enhanced clarity in variable naming for better alignment with multi-system concepts. --------- Co-authored-by: Han Wang --- tests/comp_sys.py | 84 ++++++++++++++++++++++++++++++++++++++ tests/test_deepmd_mixed.py | 26 +++++++----- 2 files changed, 100 insertions(+), 10 deletions(-) diff --git a/tests/comp_sys.py b/tests/comp_sys.py index 99879af6..a3a916a0 100644 --- a/tests/comp_sys.py +++ b/tests/comp_sys.py @@ -105,6 +105,72 @@ def test_virial(self): ) +def _make_comp_ms_test_func(comp_sys_test_func): + """ + Dynamically generates a test function for multi-system comparisons. + + Args: + comp_sys_test_func (Callable): The original test function for single systems. + + Returns + ------- + Callable: A new test function that can handle comparisons between multi-systems. + """ + + def comp_ms_test_func(iobj): + assert hasattr(iobj, "ms_1") and hasattr( + iobj, "ms_2" + ), "Multi-system objects must be present" + iobj.assertEqual(len(iobj.ms_1), len(iobj.ms_2)) + keys = [ii.formula for ii in iobj.ms_1] + keys_2 = [ii.formula for ii in iobj.ms_2] + assert sorted(keys) == sorted( + keys_2 + ), f"Keys of two MS are not equal: {keys} != {keys_2}" + for kk in keys: + iobj.system_1 = iobj.ms_1[kk] + iobj.system_2 = iobj.ms_2[kk] + comp_sys_test_func(iobj) + del iobj.system_1 + del iobj.system_2 + + return comp_ms_test_func + + +def _make_comp_ms_class(comp_class): + """ + Dynamically generates a test class for multi-system comparisons. + + Args: + comp_class (type): The original test class for single systems. + + Returns + ------- + type: A new test class that can handle comparisons between multi-systems. + """ + + class CompMS: + pass + + test_methods = [ + func + for func in dir(comp_class) + if callable(getattr(comp_class, func)) and func.startswith("test_") + ] + + for func in test_methods: + setattr(CompMS, func, _make_comp_ms_test_func(getattr(comp_class, func))) + + return CompMS + + +# MultiSystems comparison from single System comparison +CompMultiSys = _make_comp_ms_class(CompSys) + +# LabeledMultiSystems comparison from single LabeledSystem comparison +CompLabeledMultiSys = _make_comp_ms_class(CompLabeledSys) + + class MultiSystems: def test_systems_name(self): self.assertEqual(set(self.systems.systems), set(self.system_names)) @@ -127,3 +193,21 @@ class IsNoPBC: def test_is_nopbc(self): self.assertTrue(self.system_1.nopbc) self.assertTrue(self.system_2.nopbc) + + +class MSAllIsPBC: + def test_is_pbc(self): + assert hasattr(self, "ms_1") and hasattr( + self, "ms_2" + ), "Multi-system objects must be present and iterable" + self.assertTrue(all([not ss.nopbc for ss in self.ms_1])) + self.assertTrue(all([not ss.nopbc for ss in self.ms_2])) + + +class MSAllIsNoPBC: + def test_is_nopbc(self): + assert hasattr(self, "ms_1") and hasattr( + self, "ms_2" + ), "Multi-system objects must be present and iterable" + self.assertTrue(all([ss.nopbc for ss in self.ms_1])) + self.assertTrue(all([ss.nopbc for ss in self.ms_2])) diff --git a/tests/test_deepmd_mixed.py b/tests/test_deepmd_mixed.py index 70a09dbe..3c28d1b1 100644 --- a/tests/test_deepmd_mixed.py +++ b/tests/test_deepmd_mixed.py @@ -6,12 +6,18 @@ from glob import glob import numpy as np -from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems +from comp_sys import ( + CompLabeledMultiSys, + CompLabeledSys, + IsNoPBC, + MSAllIsNoPBC, + MultiSystems, +) from context import dpdata class TestMixedMultiSystemsDumpLoad( - unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC ): def setUp(self): self.places = 6 @@ -62,8 +68,8 @@ def setUp(self): self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy") self.systems = dpdata.MultiSystems() self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed") - self.system_1 = self.ms["C1H4A0B0D0"] - self.system_2 = self.systems["C1H4A0B0D0"] + self.ms_1 = self.ms + self.ms_2 = self.systems mixed_sets = glob("tmp.deepmd.mixed/*/set.*") self.assertEqual(len(mixed_sets), 2) for i in mixed_sets: @@ -112,7 +118,7 @@ def test_str(self): class TestMixedMultiSystemsDumpLoadSetSize( - unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC ): def setUp(self): self.places = 6 @@ -163,8 +169,8 @@ def setUp(self): self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy") self.systems = dpdata.MultiSystems() self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed") - self.system_1 = self.ms["C1H4A0B0D0"] - self.system_2 = self.systems["C1H4A0B0D0"] + self.ms_1 = self.ms + self.ms_2 = self.systems mixed_sets = glob("tmp.deepmd.mixed/*/set.*") self.assertEqual(len(mixed_sets), 5) for i in mixed_sets: @@ -213,7 +219,7 @@ def test_str(self): class TestMixedMultiSystemsTypeChange( - unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC ): def setUp(self): self.places = 6 @@ -265,8 +271,8 @@ def setUp(self): self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy") self.systems = dpdata.MultiSystems(type_map=["TOKEN"]) self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed") - self.system_1 = self.ms["TOKEN0C1H4A0B0D0"] - self.system_2 = self.systems["TOKEN0C1H4A0B0D0"] + self.ms_1 = self.ms + self.ms_2 = self.systems mixed_sets = glob("tmp.deepmd.mixed/*/set.*") self.assertEqual(len(mixed_sets), 2) for i in mixed_sets: