Skip to content

Commit

Permalink
fix: support shuffle_poscar for other formats (#1610)
Browse files Browse the repository at this point in the history
Fix #1570.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced efficiency in handling atomic configurations by directly
modifying coordinates in memory, eliminating the need for intermediate
files.
	- Simplified logic for shuffling atomic indices.

- **Bug Fixes**
- Removed obsolete function that may have caused confusion in the
workflow related to POSCAR file management.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Aug 3, 2024
1 parent cdfef61 commit 39115f2
Showing 1 changed file with 4 additions and 22 deletions.
26 changes: 4 additions & 22 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,6 @@ def poscar_natoms(lines):
return numb_atoms


def poscar_shuffle(poscar_in, poscar_out):
with open(poscar_in) as fin:
lines = list(fin)
numb_atoms = poscar_natoms(lines)
idx = np.arange(8, 8 + numb_atoms)
np.random.shuffle(idx)
out_lines = lines[0:8]
for ii in range(numb_atoms):
out_lines.append(lines[idx[ii]])
with open(poscar_out, "w") as fout:
fout.write("".join(out_lines))


def expand_idx(in_list):
ret = []
for ii in in_list:
Expand Down Expand Up @@ -1272,24 +1259,17 @@ def make_model_devi(iter_index, jdata, mdata):
conf_path = os.path.join(work_path, "confs")
create_path(conf_path)
sys_counter = 0
rng = np.random.default_rng()
for ss in conf_systems:
conf_counter = 0
for cc in ss:
if model_devi_engine == "lammps":
conf_name = make_model_devi_conf_name(
sys_idx[sys_counter], conf_counter
)
orig_poscar_name = conf_name + ".orig.poscar"
poscar_name = conf_name + ".poscar"
lmp_name = conf_name + ".lmp"
if shuffle_poscar:
os.symlink(cc, os.path.join(conf_path, orig_poscar_name))
poscar_shuffle(
os.path.join(conf_path, orig_poscar_name),
os.path.join(conf_path, poscar_name),
)
else:
os.symlink(cc, os.path.join(conf_path, poscar_name))
os.symlink(cc, os.path.join(conf_path, poscar_name))
if "sys_format" in jdata:
fmt = jdata["sys_format"]
else:
Expand All @@ -1299,6 +1279,8 @@ def make_model_devi(iter_index, jdata, mdata):
fmt=fmt,
type_map=jdata["type_map"],
)
if shuffle_poscar:
system.data["coords"] = rng.permuted(system.data["coords"], axis=1)
if jdata.get("model_devi_nopbc", False):
system.remove_pbc()
system.to_lammps_lmp(os.path.join(conf_path, lmp_name))
Expand Down

0 comments on commit 39115f2

Please sign in to comment.