diff --git a/dpdata/lammps/dump.py b/dpdata/lammps/dump.py index 0d29828a..c82b01c7 100644 --- a/dpdata/lammps/dump.py +++ b/dpdata/lammps/dump.py @@ -196,29 +196,31 @@ def load_file(fname: FileType, begin=0, step=1): if cc >= begin and (cc - begin) % step == 0: buff.append(line) - def get_spin_keys(inputfile): - # raed input file and get the keys for spin info in dump - # 1. find the index of "compute X X X sp spx spy spz ..." in the input file - # 2. construct the keys for spin info: c_X[idx],... - if inputfile is not None and os.path.isfile(inputfile): - with open(inputfile) as f: - lines = f.readlines() - for line in lines: + """ + Read input file and get the keys for spin info in dump. + + Parameters: + ----------- + inputfile : str + Path to the input file. + + Returns: + -------- + list or None + List of spin info keys if found, None otherwise. + """ + if inputfile is None or not os.path.isfile(inputfile): + return None + + with open(inputfile) as f: + for line in f.readlines(): ls = line.split() - if ( - len(ls) > 7 - and ls[0] == "compute" - and "sp" in ls - and "spx" in ls - and "spy" in ls - and "spz" in ls - ): - idx_sp = "c_" + ls[1] + "[" + str(ls.index("sp") - 3) + "]" - idx_spx = "c_" + ls[1] + "[" + str(ls.index("spx") - 3) + "]" - idx_spy = "c_" + ls[1] + "[" + str(ls.index("spy") - 3) + "]" - idx_spz = "c_" + ls[1] + "[" + str(ls.index("spz") - 3) + "]" - return [idx_sp, idx_spx, idx_spy, idx_spz] + if (len(ls) > 7 and ls[0] == "compute" and + all(key in ls for key in ["sp", "spx", "spy", "spz"])): + compute_name = ls[1] + return [f"c_{compute_name}[{ls.index(key) - 3}]" for key in ["sp", "spx", "spy", "spz"]] + return None @@ -235,39 +237,38 @@ def get_spin(lines, spin_keys): the spin info is stored in sp, spx, spy, spz or spin_keys, which is the spin norm and the spin vector 1 1 0.00141160 5.64868599 0.01005602 1.54706291 0.00000000 0.00000000 1.00000000 -1.40772100 -2.03739417 -1522.64797384 -0.00397809 -0.00190426 -0.00743976 """ + blk, head = _get_block(lines, "ATOMS") heads = head.split() key1 = ["sp", "spx", "spy", "spz"] - # check if head contains spin info if all(i in heads for i in key1): key = key1 elif spin_keys is not None and all(i in heads for i in spin_keys): key = spin_keys else: return None - idx_id = heads.index("id") - 2 - idx_sp = heads.index(key[0]) - 2 - idx_spx = heads.index(key[1]) - 2 - idx_spy = heads.index(key[2]) - 2 - idx_spz = heads.index(key[3]) - 2 - - norm = [] - vec = [] - id = [] - for ii in blk: - words = ii.split() - norm.append([float(words[idx_sp])]) - vec.append( - [float(words[idx_spx]), float(words[idx_spy]), float(words[idx_spz])] - ) - id.append(int(words[idx_id])) - - spin = np.array(norm) * np.array(vec) - id, spin = zip(*sorted(zip(id, spin))) - return np.array(spin) + try: + idx_id = heads.index("id") - 2 + idx_sp, idx_spx, idx_spy, idx_spz = [heads.index(k) - 2 for k in key] + + norm = [] + vec = [] + atom_ids = [] + for line in blk: + words = line.split() + norm.append([float(words[idx_sp])]) + vec.append([float(words[idx_spx]), float(words[idx_spy]), float(words[idx_spz])]) + atom_ids.append(int(words[idx_id])) + + spin = np.array(norm) * np.array(vec) + atom_ids, spin = zip(*sorted(zip(atom_ids, spin))) + return np.array(spin) + except (ValueError, IndexError) as e: + warnings.warn(f"Error processing spin data: {str(e)}") + return None def system_data( lines, type_map=None, type_idx_zero=True, unwrap=False, input_name=None