diff --git a/dpdata/system.py b/dpdata/system.py index 2613166a..42b3f4e7 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -849,6 +849,7 @@ def perturb( cell_pert_fraction: float, atom_pert_distance: float, atom_pert_style: str = "normal", + atom_pert_prob: float = 1.0, ): """Perturb each frame in the system randomly. The cell will be deformed randomly, and atoms will be displaced by a random distance in random direction. @@ -877,6 +878,8 @@ def perturb( These points are treated as vector used by atoms to move. Obviously, the max length of the distance atoms move is `atom_pert_distance`. - `'const'`: The distance atoms move will be a constant `atom_pert_distance`. + atom_pert_prob : float + Determine the proportion of the total number of atoms in a frame that are perturbed. Returns ------- @@ -900,7 +903,15 @@ def perturb( tmp_system.data["coords"][0] = np.matmul( tmp_system.data["coords"][0], cell_perturb_matrix ) - for kk in range(len(tmp_system.data["coords"][0])): + pert_natoms = int(atom_pert_prob * len(tmp_system.data["coords"][0])) + pert_atom_id = sorted( + np.random.choice( + range(len(tmp_system.data["coords"][0])), + pert_natoms, + replace=False, + ).tolist() + ) + for kk in pert_atom_id: atom_perturb_vector = get_atom_perturb_vector( atom_pert_distance, atom_pert_style ) diff --git a/tests/poscars/POSCAR.SiC.partpert b/tests/poscars/POSCAR.SiC.partpert new file mode 100644 index 00000000..859de7ca --- /dev/null +++ b/tests/poscars/POSCAR.SiC.partpert @@ -0,0 +1,16 @@ +C4 Si4 +1.0 +4.0354487481064565e+00 1.1027270790560616e-17 2.5642993008475204e-17 +2.0693526054669642e-01 4.1066892997402196e+00 -8.6715682899078028e-18 +4.2891472979598610e-01 5.5796885749827474e-01 4.1100061517204542e+00 +C Si +4 4 +Cartesian + 0.03122504 0.15559669 2.1913045 + 1.93908836 -0.08678864 0.06748919 + 0.13114716 2.15827511 0.06333341 + 2.36161952 1.42824405 2.58837618 +-0.03895165 0.12197669 0.05496244 + 1.79528462 2.48830207 -0.55733221 + 2.11363589 0.09280028 2.0301803 + 0.19221505 2.16245144 2.07930701 diff --git a/tests/test_perturb.py b/tests/test_perturb.py index eea71116..5c2f4a7b 100644 --- a/tests/test_perturb.py +++ b/tests/test_perturb.py @@ -12,6 +12,7 @@ class NormalGenerator: def __init__(self): self.randn_generator = self.get_randn_generator() self.rand_generator = self.get_rand_generator() + self.choice_generator = self.get_choice_generator() def randn(self, number): return next(self.randn_generator) @@ -19,6 +20,9 @@ def randn(self, number): def rand(self, number): return next(self.rand_generator) + def choice(self, total_natoms, pert_natoms, replace): + return next(self.choice_generator)[:pert_natoms] + @staticmethod def get_randn_generator(): data = np.asarray( @@ -44,11 +48,16 @@ def get_rand_generator(): [0.23182233, 0.87106847, 0.68728511, 0.94180274, 0.92860453, 0.69191187] ) + @staticmethod + def get_choice_generator(): + yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0]) + class UniformGenerator: def __init__(self): self.randn_generator = self.get_randn_generator() self.rand_generator = self.get_rand_generator() + self.choice_generator = self.get_choice_generator() def randn(self, number): return next(self.randn_generator) @@ -56,6 +65,9 @@ def randn(self, number): def rand(self, number): return next(self.rand_generator) + def choice(self, total_natoms, pert_natoms, replace): + return next(self.choice_generator) + @staticmethod def get_randn_generator(): data = [ @@ -97,11 +109,16 @@ def get_rand_generator(): yield np.asarray(data[count]) count += 1 + @staticmethod + def get_choice_generator(): + yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0]) + class ConstGenerator: def __init__(self): self.randn_generator = self.get_randn_generator() self.rand_generator = self.get_rand_generator() + self.choice_generator = self.get_choice_generator() def randn(self, number): return next(self.randn_generator) @@ -109,6 +126,9 @@ def randn(self, number): def rand(self, number): return next(self.rand_generator) + def choice(self, total_natoms, pert_natoms, replace): + return next(self.choice_generator) + @staticmethod def get_randn_generator(): data = np.asarray( @@ -135,6 +155,10 @@ def get_rand_generator(): [0.01525907, 0.68387374, 0.39768541, 0.55596047, 0.26557088, 0.60883073] ) + @staticmethod + def get_choice_generator(): + yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0]) + # %% class TestPerturbNormal(unittest.TestCase, CompSys, IsPBC): @@ -142,6 +166,7 @@ class TestPerturbNormal(unittest.TestCase, CompSys, IsPBC): def setUp(self, random_mock): random_mock.rand = NormalGenerator().rand random_mock.randn = NormalGenerator().randn + random_mock.choice = NormalGenerator().choice system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar") self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "normal") self.system_2 = dpdata.System("poscars/POSCAR.SiC.normal", fmt="vasp/poscar") @@ -153,6 +178,7 @@ class TestPerturbUniform(unittest.TestCase, CompSys, IsPBC): def setUp(self, random_mock): random_mock.rand = UniformGenerator().rand random_mock.randn = UniformGenerator().randn + random_mock.choice = UniformGenerator().choice system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar") self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "uniform") self.system_2 = dpdata.System("poscars/POSCAR.SiC.uniform", fmt="vasp/poscar") @@ -164,11 +190,24 @@ class TestPerturbConst(unittest.TestCase, CompSys, IsPBC): def setUp(self, random_mock): random_mock.rand = ConstGenerator().rand random_mock.randn = ConstGenerator().randn + random_mock.choice = ConstGenerator().choice system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar") self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "const") self.system_2 = dpdata.System("poscars/POSCAR.SiC.const", fmt="vasp/poscar") self.places = 6 +class TestPerturbPartAtoms(unittest.TestCase, CompSys, IsPBC): + @patch("numpy.random") + def setUp(self, random_mock): + random_mock.rand = NormalGenerator().rand + random_mock.randn = NormalGenerator().randn + random_mock.choice = NormalGenerator().choice + system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar") + self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "normal", 0.25) + self.system_2 = dpdata.System("poscars/POSCAR.SiC.partpert", fmt="vasp/poscar") + self.places = 6 + + if __name__ == "__main__": unittest.main()