diff --git a/tests/generator/comp_sys.py b/tests/generator/comp_sys.py index 8806ddb5e..db37ad843 100644 --- a/tests/generator/comp_sys.py +++ b/tests/generator/comp_sys.py @@ -86,6 +86,9 @@ def test_coord(self): tmp_cell = self.system_1.data["cells"] tmp_cell = np.reshape(tmp_cell, [-1, 3]) tmp_cell_norm = np.reshape(np.linalg.norm(tmp_cell, axis=1), [-1, 3]) + if np.max(np.abs(tmp_cell_norm)) < 1e-12: + # zero cell, no pbc case, set to [1., 1., 1.] + tmp_cell_norm = np.ones(tmp_cell_norm.shape) for ff in range(self.system_1.get_nframes()): for ii in range(sum(self.system_1.data["atom_numbs"])): for jj in range(3): @@ -103,12 +106,21 @@ class CompLabeledSys(CompSys): def test_energy(self): self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes()) for ff in range(self.system_1.get_nframes()): - self.assertAlmostEqual( - self.system_1.data["energies"][ff], - self.system_2.data["energies"][ff], - places=self.e_places, - msg="energies[%d] failed" % (ff), - ) + if abs(self.system_2.data["energies"][ff]) < 1e-12: + self.assertAlmostEqual( + self.system_1.data["energies"][ff], + self.system_2.data["energies"][ff], + places=self.e_places, + msg="energies[%d] failed" % (ff), + ) + else: + self.assertAlmostEqual( + self.system_1.data["energies"][ff] + / self.system_2.data["energies"][ff], + 1.0, + places=self.e_places, + msg="energies[%d] failed" % (ff), + ) def test_force(self): self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes())