From 03821526e4590f27e034b16b4ea67af193a81e18 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 2 Aug 2024 16:50:44 -0400 Subject: [PATCH] fix: support shuffle_poscar for other formats Fix #1570. Signed-off-by: Jinzhe Zeng --- dpgen/generator/run.py | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 805bd5695..8f6265a3d 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -216,19 +216,6 @@ def poscar_natoms(lines): return numb_atoms -def poscar_shuffle(poscar_in, poscar_out): - with open(poscar_in) as fin: - lines = list(fin) - numb_atoms = poscar_natoms(lines) - idx = np.arange(8, 8 + numb_atoms) - np.random.shuffle(idx) - out_lines = lines[0:8] - for ii in range(numb_atoms): - out_lines.append(lines[idx[ii]]) - with open(poscar_out, "w") as fout: - fout.write("".join(out_lines)) - - def expand_idx(in_list): ret = [] for ii in in_list: @@ -1272,6 +1259,7 @@ def make_model_devi(iter_index, jdata, mdata): conf_path = os.path.join(work_path, "confs") create_path(conf_path) sys_counter = 0 + rng = np.random.default_rng() for ss in conf_systems: conf_counter = 0 for cc in ss: @@ -1279,17 +1267,9 @@ def make_model_devi(iter_index, jdata, mdata): conf_name = make_model_devi_conf_name( sys_idx[sys_counter], conf_counter ) - orig_poscar_name = conf_name + ".orig.poscar" poscar_name = conf_name + ".poscar" lmp_name = conf_name + ".lmp" - if shuffle_poscar: - os.symlink(cc, os.path.join(conf_path, orig_poscar_name)) - poscar_shuffle( - os.path.join(conf_path, orig_poscar_name), - os.path.join(conf_path, poscar_name), - ) - else: - os.symlink(cc, os.path.join(conf_path, poscar_name)) + os.symlink(cc, os.path.join(conf_path, poscar_name)) if "sys_format" in jdata: fmt = jdata["sys_format"] else: @@ -1299,6 +1279,8 @@ def make_model_devi(iter_index, jdata, mdata): fmt=fmt, type_map=jdata["type_map"], ) + if shuffle_poscar: + system.data["coords"] = rng.permuted(system.data["coords"], axis=1) if jdata.get("model_devi_nopbc", False): system.remove_pbc() system.to_lammps_lmp(os.path.join(conf_path, lmp_name))