From 39115f2ce4ea0d7740e27a81e3c9d3c4c30e0b49 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 3 Aug 2024 05:10:11 -0400 Subject: [PATCH] fix: support `shuffle_poscar` for other formats (#1610) Fix #1570. ## Summary by CodeRabbit - **New Features** - Enhanced efficiency in handling atomic configurations by directly modifying coordinates in memory, eliminating the need for intermediate files. - Simplified logic for shuffling atomic indices. - **Bug Fixes** - Removed obsolete function that may have caused confusion in the workflow related to POSCAR file management. Signed-off-by: Jinzhe Zeng --- dpgen/generator/run.py | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 805bd5695..8f6265a3d 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -216,19 +216,6 @@ def poscar_natoms(lines): return numb_atoms -def poscar_shuffle(poscar_in, poscar_out): - with open(poscar_in) as fin: - lines = list(fin) - numb_atoms = poscar_natoms(lines) - idx = np.arange(8, 8 + numb_atoms) - np.random.shuffle(idx) - out_lines = lines[0:8] - for ii in range(numb_atoms): - out_lines.append(lines[idx[ii]]) - with open(poscar_out, "w") as fout: - fout.write("".join(out_lines)) - - def expand_idx(in_list): ret = [] for ii in in_list: @@ -1272,6 +1259,7 @@ def make_model_devi(iter_index, jdata, mdata): conf_path = os.path.join(work_path, "confs") create_path(conf_path) sys_counter = 0 + rng = np.random.default_rng() for ss in conf_systems: conf_counter = 0 for cc in ss: @@ -1279,17 +1267,9 @@ def make_model_devi(iter_index, jdata, mdata): conf_name = make_model_devi_conf_name( sys_idx[sys_counter], conf_counter ) - orig_poscar_name = conf_name + ".orig.poscar" poscar_name = conf_name + ".poscar" lmp_name = conf_name + ".lmp" - if shuffle_poscar: - os.symlink(cc, os.path.join(conf_path, orig_poscar_name)) - poscar_shuffle( - os.path.join(conf_path, orig_poscar_name), - os.path.join(conf_path, poscar_name), - ) - else: - os.symlink(cc, os.path.join(conf_path, poscar_name)) + os.symlink(cc, os.path.join(conf_path, poscar_name)) if "sys_format" in jdata: fmt = jdata["sys_format"] else: @@ -1299,6 +1279,8 @@ def make_model_devi(iter_index, jdata, mdata): fmt=fmt, type_map=jdata["type_map"], ) + if shuffle_poscar: + system.data["coords"] = rng.permuted(system.data["coords"], axis=1) if jdata.get("model_devi_nopbc", False): system.remove_pbc() system.to_lammps_lmp(os.path.join(conf_path, lmp_name))