Skip to content

Commit

Permalink
improve codes
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Oct 17, 2024
1 parent ad5808a commit daa38ba
Showing 1 changed file with 43 additions and 42 deletions.
85 changes: 43 additions & 42 deletions dpdata/lammps/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,29 +196,31 @@ def load_file(fname: FileType, begin=0, step=1):
if cc >= begin and (cc - begin) % step == 0:
buff.append(line)


def get_spin_keys(inputfile):
# raed input file and get the keys for spin info in dump
# 1. find the index of "compute X X X sp spx spy spz ..." in the input file
# 2. construct the keys for spin info: c_X[idx],...
if inputfile is not None and os.path.isfile(inputfile):
with open(inputfile) as f:
lines = f.readlines()
for line in lines:
"""
Read input file and get the keys for spin info in dump.
Parameters:
-----------
inputfile : str
Path to the input file.
Returns:
--------
list or None
List of spin info keys if found, None otherwise.
"""
if inputfile is None or not os.path.isfile(inputfile):
return None

with open(inputfile) as f:
for line in f.readlines():
ls = line.split()
if (
len(ls) > 7
and ls[0] == "compute"
and "sp" in ls
and "spx" in ls
and "spy" in ls
and "spz" in ls
):
idx_sp = "c_" + ls[1] + "[" + str(ls.index("sp") - 3) + "]"
idx_spx = "c_" + ls[1] + "[" + str(ls.index("spx") - 3) + "]"
idx_spy = "c_" + ls[1] + "[" + str(ls.index("spy") - 3) + "]"
idx_spz = "c_" + ls[1] + "[" + str(ls.index("spz") - 3) + "]"
return [idx_sp, idx_spx, idx_spy, idx_spz]
if (len(ls) > 7 and ls[0] == "compute" and
all(key in ls for key in ["sp", "spx", "spy", "spz"])):
compute_name = ls[1]
return [f"c_{compute_name}[{ls.index(key) - 3}]" for key in ["sp", "spx", "spy", "spz"]]

return None


Expand All @@ -235,39 +237,38 @@ def get_spin(lines, spin_keys):
the spin info is stored in sp, spx, spy, spz or spin_keys, which is the spin norm and the spin vector
1 1 0.00141160 5.64868599 0.01005602 1.54706291 0.00000000 0.00000000 1.00000000 -1.40772100 -2.03739417 -1522.64797384 -0.00397809 -0.00190426 -0.00743976
"""

blk, head = _get_block(lines, "ATOMS")
heads = head.split()

key1 = ["sp", "spx", "spy", "spz"]

# check if head contains spin info
if all(i in heads for i in key1):
key = key1
elif spin_keys is not None and all(i in heads for i in spin_keys):
key = spin_keys
else:
return None
idx_id = heads.index("id") - 2
idx_sp = heads.index(key[0]) - 2
idx_spx = heads.index(key[1]) - 2
idx_spy = heads.index(key[2]) - 2
idx_spz = heads.index(key[3]) - 2

norm = []
vec = []
id = []
for ii in blk:
words = ii.split()
norm.append([float(words[idx_sp])])
vec.append(
[float(words[idx_spx]), float(words[idx_spy]), float(words[idx_spz])]
)
id.append(int(words[idx_id]))

spin = np.array(norm) * np.array(vec)
id, spin = zip(*sorted(zip(id, spin)))
return np.array(spin)

try:
idx_id = heads.index("id") - 2
idx_sp, idx_spx, idx_spy, idx_spz = [heads.index(k) - 2 for k in key]

norm = []
vec = []
atom_ids = []
for line in blk:
words = line.split()
norm.append([float(words[idx_sp])])
vec.append([float(words[idx_spx]), float(words[idx_spy]), float(words[idx_spz])])
atom_ids.append(int(words[idx_id]))

spin = np.array(norm) * np.array(vec)
atom_ids, spin = zip(*sorted(zip(atom_ids, spin)))
return np.array(spin)
except (ValueError, IndexError) as e:
warnings.warn(f"Error processing spin data: {str(e)}")
return None

def system_data(
lines, type_map=None, type_idx_zero=True, unwrap=False, input_name=None
Expand Down

0 comments on commit daa38ba

Please sign in to comment.