Skip to content

Commit

Permalink
Feat: Support specifying proportion of atoms to be perturbed in System (
Browse files Browse the repository at this point in the history
#716)

See title.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Introduced a new parameter for controlled atom perturbation in the
perturb function, enhancing flexibility.

- **Bug Fixes**
- Improved logic for selecting atoms to perturb, ensuring only a
specified proportion is affected.

- **Tests**
- Added a new test class to validate the perturbation functionality for
atomic systems, increasing test coverage and reliability.
- Introduced a structured representation of a Silicon Carbide crystal
for validation in tests.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Chengqian-Zhang and pre-commit-ci[bot] authored Sep 11, 2024
1 parent 1de5ace commit fb942bb
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
13 changes: 12 additions & 1 deletion dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ def perturb(
cell_pert_fraction: float,
atom_pert_distance: float,
atom_pert_style: str = "normal",
atom_pert_prob: float = 1.0,
):
"""Perturb each frame in the system randomly.
The cell will be deformed randomly, and atoms will be displaced by a random distance in random direction.
Expand Down Expand Up @@ -877,6 +878,8 @@ def perturb(
These points are treated as vector used by atoms to move.
Obviously, the max length of the distance atoms move is `atom_pert_distance`.
- `'const'`: The distance atoms move will be a constant `atom_pert_distance`.
atom_pert_prob : float
Determine the proportion of the total number of atoms in a frame that are perturbed.
Returns
-------
Expand All @@ -900,7 +903,15 @@ def perturb(
tmp_system.data["coords"][0] = np.matmul(
tmp_system.data["coords"][0], cell_perturb_matrix
)
for kk in range(len(tmp_system.data["coords"][0])):
pert_natoms = int(atom_pert_prob * len(tmp_system.data["coords"][0]))
pert_atom_id = sorted(
np.random.choice(
range(len(tmp_system.data["coords"][0])),
pert_natoms,
replace=False,
).tolist()
)
for kk in pert_atom_id:
atom_perturb_vector = get_atom_perturb_vector(
atom_pert_distance, atom_pert_style
)
Expand Down
16 changes: 16 additions & 0 deletions tests/poscars/POSCAR.SiC.partpert
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
C4 Si4
1.0
4.0354487481064565e+00 1.1027270790560616e-17 2.5642993008475204e-17
2.0693526054669642e-01 4.1066892997402196e+00 -8.6715682899078028e-18
4.2891472979598610e-01 5.5796885749827474e-01 4.1100061517204542e+00
C Si
4 4
Cartesian
0.03122504 0.15559669 2.1913045
1.93908836 -0.08678864 0.06748919
0.13114716 2.15827511 0.06333341
2.36161952 1.42824405 2.58837618
-0.03895165 0.12197669 0.05496244
1.79528462 2.48830207 -0.55733221
2.11363589 0.09280028 2.0301803
0.19221505 2.16245144 2.07930701
39 changes: 39 additions & 0 deletions tests/test_perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ class NormalGenerator:
def __init__(self):
self.randn_generator = self.get_randn_generator()
self.rand_generator = self.get_rand_generator()
self.choice_generator = self.get_choice_generator()

def randn(self, number):
return next(self.randn_generator)

def rand(self, number):
return next(self.rand_generator)

def choice(self, total_natoms, pert_natoms, replace):
return next(self.choice_generator)[:pert_natoms]

@staticmethod
def get_randn_generator():
data = np.asarray(
Expand All @@ -44,18 +48,26 @@ def get_rand_generator():
[0.23182233, 0.87106847, 0.68728511, 0.94180274, 0.92860453, 0.69191187]
)

@staticmethod
def get_choice_generator():
yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0])


class UniformGenerator:
def __init__(self):
self.randn_generator = self.get_randn_generator()
self.rand_generator = self.get_rand_generator()
self.choice_generator = self.get_choice_generator()

def randn(self, number):
return next(self.randn_generator)

def rand(self, number):
return next(self.rand_generator)

def choice(self, total_natoms, pert_natoms, replace):
return next(self.choice_generator)

@staticmethod
def get_randn_generator():
data = [
Expand Down Expand Up @@ -97,18 +109,26 @@ def get_rand_generator():
yield np.asarray(data[count])
count += 1

@staticmethod
def get_choice_generator():
yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0])


class ConstGenerator:
def __init__(self):
self.randn_generator = self.get_randn_generator()
self.rand_generator = self.get_rand_generator()
self.choice_generator = self.get_choice_generator()

def randn(self, number):
return next(self.randn_generator)

def rand(self, number):
return next(self.rand_generator)

def choice(self, total_natoms, pert_natoms, replace):
return next(self.choice_generator)

@staticmethod
def get_randn_generator():
data = np.asarray(
Expand All @@ -135,13 +155,18 @@ def get_rand_generator():
[0.01525907, 0.68387374, 0.39768541, 0.55596047, 0.26557088, 0.60883073]
)

@staticmethod
def get_choice_generator():
yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0])


# %%
class TestPerturbNormal(unittest.TestCase, CompSys, IsPBC):
@patch("numpy.random")
def setUp(self, random_mock):
random_mock.rand = NormalGenerator().rand
random_mock.randn = NormalGenerator().randn
random_mock.choice = NormalGenerator().choice
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "normal")
self.system_2 = dpdata.System("poscars/POSCAR.SiC.normal", fmt="vasp/poscar")
Expand All @@ -153,6 +178,7 @@ class TestPerturbUniform(unittest.TestCase, CompSys, IsPBC):
def setUp(self, random_mock):
random_mock.rand = UniformGenerator().rand
random_mock.randn = UniformGenerator().randn
random_mock.choice = UniformGenerator().choice
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "uniform")
self.system_2 = dpdata.System("poscars/POSCAR.SiC.uniform", fmt="vasp/poscar")
Expand All @@ -164,11 +190,24 @@ class TestPerturbConst(unittest.TestCase, CompSys, IsPBC):
def setUp(self, random_mock):
random_mock.rand = ConstGenerator().rand
random_mock.randn = ConstGenerator().randn
random_mock.choice = ConstGenerator().choice
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "const")
self.system_2 = dpdata.System("poscars/POSCAR.SiC.const", fmt="vasp/poscar")
self.places = 6


class TestPerturbPartAtoms(unittest.TestCase, CompSys, IsPBC):
@patch("numpy.random")
def setUp(self, random_mock):
random_mock.rand = NormalGenerator().rand
random_mock.randn = NormalGenerator().randn
random_mock.choice = NormalGenerator().choice
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "normal", 0.25)
self.system_2 = dpdata.System("poscars/POSCAR.SiC.partpert", fmt="vasp/poscar")
self.places = 6


if __name__ == "__main__":
unittest.main()

0 comments on commit fb942bb

Please sign in to comment.