Skip to content

Commit

Permalink
fix: reuse regular methods for deepmd/mixed
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Aug 29, 2024
1 parent 676517a commit db6298a
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 170 deletions.
2 changes: 0 additions & 2 deletions dpdata/deepmd/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def to_system_data(folder, type_map=None, labels=True):
"orig",
"cells",
"coords",
"real_atom_types",
"real_atom_names",
"nopbc",
"energies",
Expand Down Expand Up @@ -189,7 +188,6 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
"orig",
"cells",
"coords",
"real_atom_types",
"real_atom_names",
"nopbc",
"energies",
Expand Down
179 changes: 17 additions & 162 deletions dpdata/deepmd/mixed.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,16 @@
from __future__ import annotations

import copy
import glob
import os
import shutil

import numpy as np


def load_type(folder):
    """Read the atom names of a mixed-type system from ``type_map.raw``.

    Parameters
    ----------
    folder : str
        Directory expected to contain ``type_map.raw``.

    Returns
    -------
    dict
        A dict with the single key ``"atom_names"`` holding the
        whitespace-separated tokens of ``type_map.raw`` as a list of str.

    Raises
    ------
    AssertionError
        If ``type_map.raw`` is not a file inside ``folder``.
    """
    type_map_path = os.path.join(folder, "type_map.raw")
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type and message are unchanged.
    if not os.path.isfile(type_map_path):
        raise AssertionError("Mixed type system must have type_map.raw!")
    with open(type_map_path) as fp:
        atom_names = fp.read().split()
    return {"atom_names": atom_names}


def formula(atom_names, atom_numbs):
    """Build the formula string of a system, e.g. ``C3H5O2``.

    Each symbol in ``atom_names`` is immediately followed by the matching
    entry of ``atom_numbs``; pairing stops at the shorter of the two inputs.
    """
    pieces = []
    for name, count in zip(atom_names, atom_numbs):
        pieces.append(f"{name}{count}")
    return "".join(pieces)


def _cond_load_data(fname):
    """Return ``np.load(fname)`` when ``fname`` is an existing file, else ``None``."""
    return np.load(fname) if os.path.isfile(fname) else None


def _load_set(folder, nopbc: bool):
    """Read the arrays stored in one set directory.

    ``coord.npy`` and ``real_atom_types.npy`` are always loaded.
    ``box.npy`` is loaded unless ``nopbc`` is true, in which case a zero
    array of shape ``(coords.shape[0], 3, 3)`` is used instead.
    ``energy.npy``, ``force.npy`` and ``virial.npy`` are optional and come
    back as ``None`` when the file is absent.

    Returns
    -------
    tuple
        ``(cells, coords, eners, forces, virs, real_atom_types)``
    """

    def _in_set(name):
        return os.path.join(folder, name)

    coords = np.load(_in_set("coord.npy"))
    # Without periodic boundaries there is no box file; substitute zeros.
    cells = (
        np.zeros((coords.shape[0], 3, 3)) if nopbc else np.load(_in_set("box.npy"))
    )
    eners, forces, virs = (
        _cond_load_data(_in_set(name))
        for name in ("energy.npy", "force.npy", "virial.npy")
    )
    real_atom_types = np.load(_in_set("real_atom_types.npy"))
    return cells, coords, eners, forces, virs, real_atom_types
from .comp import dump as comp_dump
from .comp import to_system_data as comp_to_system_data


def to_system_data(folder, type_map=None, labels=True):
data = comp_to_system_data(folder, type_map, labels)
# data is empty
data = load_type(folder)
old_type_map = data["atom_names"].copy()
if type_map is not None:
assert isinstance(type_map, list)
Expand All @@ -60,50 +22,16 @@ def to_system_data(folder, type_map=None, labels=True):
data["atom_names"] = type_map.copy()
else:
index_map = None
data["orig"] = np.zeros([3])
if os.path.isfile(os.path.join(folder, "nopbc")):
data["nopbc"] = True
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
all_cells = []
all_coords = []
all_eners = []
all_forces = []
all_virs = []
all_real_atom_types = []
for ii in sets:
cells, coords, eners, forces, virs, real_atom_types = _load_set(
ii, data.get("nopbc", False)
)
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
if index_map is None:
all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1]))
else:
all_real_atom_types.append(
np.reshape(index_map[real_atom_types], [nframes, -1])
)
if eners is not None:
eners = np.reshape(eners, [nframes])
if labels:
if eners is not None and eners.size > 0:
all_eners.append(np.reshape(eners, [nframes]))
if forces is not None and forces.size > 0:
all_forces.append(np.reshape(forces, [nframes, -1, 3]))
if virs is not None and virs.size > 0:
all_virs.append(np.reshape(virs, [nframes, 3, 3]))
all_cells_concat = np.concatenate(all_cells, axis=0)
all_coords_concat = np.concatenate(all_coords, axis=0)
all_real_atom_types_concat = np.concatenate(all_real_atom_types, axis=0)
all_eners_concat = None
all_forces_concat = None
all_virs_concat = None
if len(all_eners) > 0:
all_eners_concat = np.concatenate(all_eners, axis=0)
if len(all_forces) > 0:
all_forces_concat = np.concatenate(all_forces, axis=0)
if len(all_virs) > 0:
all_virs_concat = np.concatenate(all_virs, axis=0)
all_real_atom_types_concat = data.pop("real_atom_types").astype(int)
if index_map is not None:
all_real_atom_types_concat = index_map[all_real_atom_types_concat]

Check warning on line 27 in dpdata/deepmd/mixed.py

View check run for this annotation

Codecov / codecov/patch

dpdata/deepmd/mixed.py#L27

Added line #L27 was not covered by tests
all_cells_concat = data["cells"]
all_coords_concat = data["coords"]
if labels:
all_eners_concat = data.get("energies")
all_forces_concat = data.get("forces")
all_virs_concat = data.get("virials")

data_list = []
while True:
if all_real_atom_types_concat.size == 0:
Expand Down Expand Up @@ -143,20 +71,6 @@ def to_system_data(folder, type_map=None, labels=True):


def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):
os.makedirs(folder, exist_ok=True)
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
if len(sets) > 0:
if remove_sets:
for ii in sets:
shutil.rmtree(ii)
else:
raise RuntimeError(
"found "
+ str(sets)
+ " in "
+ folder
+ "not a clean deepmd raw dir. please firstly clean set.* then try compress"
)
# if not converted to mixed
if "real_atom_types" not in data:
from dpdata import LabeledSystem, System
Expand All @@ -169,69 +83,10 @@ def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):
else:
temp_sys = System(data=data)
temp_sys.convert_to_mixed_type()
# dump raw
np.savetxt(os.path.join(folder, "type.raw"), data["atom_types"], fmt="%d")
np.savetxt(os.path.join(folder, "type_map.raw"), data["real_atom_names"], fmt="%s")
# BondOrder System
if "bonds" in data:
np.savetxt(
os.path.join(folder, "bonds.raw"),
data["bonds"],
header="begin_atom, end_atom, bond_order",
)
if "formal_charges" in data:
np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"])
# reshape frame properties and convert prec
nframes = data["cells"].shape[0]
cells = np.reshape(data["cells"], [nframes, 9]).astype(comp_prec)
coords = np.reshape(data["coords"], [nframes, -1]).astype(comp_prec)
eners = None
forces = None
virials = None
real_atom_types = None
if "energies" in data:
eners = np.reshape(data["energies"], [nframes]).astype(comp_prec)
if "forces" in data:
forces = np.reshape(data["forces"], [nframes, -1]).astype(comp_prec)
if "virials" in data:
virials = np.reshape(data["virials"], [nframes, 9]).astype(comp_prec)
if "atom_pref" in data:
atom_pref = np.reshape(data["atom_pref"], [nframes, -1]).astype(comp_prec)
if "real_atom_types" in data:
real_atom_types = np.reshape(data["real_atom_types"], [nframes, -1]).astype(
np.int64
)
# dump frame properties: cell, coord, energy, force and virial
nsets = nframes // set_size
if set_size * nsets < nframes:
nsets += 1
for ii in range(nsets):
set_stt = ii * set_size
set_end = (ii + 1) * set_size
set_folder = os.path.join(folder, "set.%06d" % ii)
os.makedirs(set_folder)
np.save(os.path.join(set_folder, "box"), cells[set_stt:set_end])
np.save(os.path.join(set_folder, "coord"), coords[set_stt:set_end])
if eners is not None:
np.save(os.path.join(set_folder, "energy"), eners[set_stt:set_end])
if forces is not None:
np.save(os.path.join(set_folder, "force"), forces[set_stt:set_end])
if virials is not None:
np.save(os.path.join(set_folder, "virial"), virials[set_stt:set_end])
if real_atom_types is not None:
np.save(
os.path.join(set_folder, "real_atom_types"),
real_atom_types[set_stt:set_end],
)
if "atom_pref" in data:
np.save(os.path.join(set_folder, "atom_pref"), atom_pref[set_stt:set_end])
try:
os.remove(os.path.join(folder, "nopbc"))
except OSError:
pass
if data.get("nopbc", False):
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
pass

data = data.copy()
data["atom_names"] = data.pop("real_atom_names")
comp_dump(folder, data, set_size, comp_prec, remove_sets)


def mix_system(*system, type_map, **kwargs):
Expand Down
10 changes: 4 additions & 6 deletions dpdata/deepmd/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ def load_type(folder, type_map=None):
int
)
ntypes = np.max(data["atom_types"]) + 1
data["atom_numbs"] = []
for ii in range(ntypes):
data["atom_numbs"].append(np.count_nonzero(data["atom_types"] == ii))
data["atom_names"] = []
# if find type_map.raw, use it
if os.path.isfile(os.path.join(folder, "type_map.raw")):
Expand All @@ -30,9 +27,10 @@ def load_type(folder, type_map=None):
my_type_map = []
for ii in range(ntypes):
my_type_map.append("Type_%d" % ii)
assert len(my_type_map) >= len(data["atom_numbs"])
for ii in range(len(data["atom_numbs"])):
data["atom_names"].append(my_type_map[ii])
data["atom_names"] = my_type_map
data["atom_numbs"] = []
for ii, _ in enumerate(data["atom_names"]):
data["atom_numbs"].append(np.count_nonzero(data["atom_types"] == ii))

return data

Expand Down

0 comments on commit db6298a

Please sign in to comment.