Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support NumPy 2 #202

Merged
merged 18 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ jobs:
pip install uv
uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system
# TODO: remove pin once reverse readline fixed
uv pip install monty==2024.7.12 --system
- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
env:
Expand Down
40 changes: 22 additions & 18 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from chgnet import TrainTask

warnings.filterwarnings("ignore")
datatype = torch.float32
TORCH_DTYPE = torch.float32


class StructureData(Dataset):
Expand Down Expand Up @@ -163,21 +163,21 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
struct, graph_id=graph_id, mp_id=mp_id
)
targets = {
"e": torch.tensor(self.energies[graph_id], dtype=datatype),
"f": torch.tensor(self.forces[graph_id], dtype=datatype),
"e": torch.tensor(self.energies[graph_id], dtype=TORCH_DTYPE),
"f": torch.tensor(self.forces[graph_id], dtype=TORCH_DTYPE),
}
if self.stresses is not None:
# Convert VASP stress
targets["s"] = torch.tensor(
self.stresses[graph_id], dtype=datatype
self.stresses[graph_id], dtype=TORCH_DTYPE
) * (-0.1)
if self.magmoms is not None:
mag = self.magmoms[graph_id]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE))

return crystal_graph, targets

Expand Down Expand Up @@ -275,18 +275,18 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.data[graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.data[graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.data[graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * -0.1
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * -0.1
elif key == "m":
mag = self.data[graph_id][self.magmom_key]
# use absolute value for magnetic moments
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE))
return crystal_graph, targets

# Omit structures with isolated atoms.
Expand Down Expand Up @@ -404,21 +404,23 @@ def __getitem__(self, idx) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.labels[mp_id][graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.labels[mp_id][graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.labels[mp_id][graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1)
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1)
elif key == "m":
mag = self.labels[mp_id][graph_id][self.magmom_key]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(
torch.tensor(mag, dtype=TORCH_DTYPE)
)
return crystal_graph, targets

# Omit failed structures. Return another randomly selected structure
Expand Down Expand Up @@ -629,21 +631,23 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
for key in self.targets:
if key == "e":
energy = self.data[mp_id][graph_id][self.energy_key]
targets["e"] = torch.tensor(energy, dtype=datatype)
targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE)
elif key == "f":
force = self.data[mp_id][graph_id][self.force_key]
targets["f"] = torch.tensor(force, dtype=datatype)
targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE)
elif key == "s":
stress = self.data[mp_id][graph_id][self.stress_key]
# Convert VASP stress
targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1)
targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1)
elif key == "m":
mag = self.data[mp_id][graph_id][self.magmom_key]
# use absolute value for magnetic moments
if mag is None:
targets["m"] = None
else:
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
targets["m"] = torch.abs(
torch.tensor(mag, dtype=TORCH_DTYPE)
)
return crystal_graph, targets

# Omit structures with isolated atoms. Return another randomly selected
Expand Down Expand Up @@ -773,7 +777,7 @@ def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tens
graphs = [graph for graph, _ in batch_data]
all_targets = {key: [] for key in batch_data[0][1]}
all_targets["e"] = torch.tensor(
[targets["e"] for _, targets in batch_data], dtype=datatype
[targets["e"] for _, targets in batch_data], dtype=TORCH_DTYPE
)

for _, targets in batch_data:
Expand Down
10 changes: 5 additions & 5 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
except (ImportError, AttributeError):
make_graph = None

DATATYPE = torch.float32
TORCH_DTYPE = torch.float32


class CrystalGraphConverter(nn.Module):
Expand Down Expand Up @@ -124,10 +124,10 @@ def forward(
requires_grad=False,
)
atom_frac_coord = torch.tensor(
structure.frac_coords, dtype=DATATYPE, requires_grad=True
structure.frac_coords, dtype=TORCH_DTYPE, requires_grad=True
)
lattice = torch.tensor(
structure.lattice.matrix, dtype=DATATYPE, requires_grad=True
structure.lattice.matrix, dtype=TORCH_DTYPE, requires_grad=True
)
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8
Expand Down Expand Up @@ -177,7 +177,7 @@ def forward(
atomic_number=atomic_number,
atom_frac_coord=atom_frac_coord,
atom_graph=atom_graph,
neighbor_image=torch.tensor(image, dtype=DATATYPE),
neighbor_image=torch.tensor(image, dtype=TORCH_DTYPE),
directed2undirected=directed2undirected,
undirected2directed=undirected2directed,
bond_graph=bond_graph,
Expand Down Expand Up @@ -250,7 +250,7 @@ def _create_graph_fast(
"""
center_index = np.ascontiguousarray(center_index)
neighbor_index = np.ascontiguousarray(neighbor_index)
image = np.ascontiguousarray(image, dtype=np.int_)
image = np.ascontiguousarray(image, dtype=np.int64)
distance = np.ascontiguousarray(distance)
gc_saved = gc.get_threshold()
gc.set_threshold(0)
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/crystalgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

datatype = torch.float32
TORCH_DTYPE = torch.float32


class CrystalGraph:
Expand Down
67 changes: 36 additions & 31 deletions chgnet/graph/cygraph.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,70 +7,75 @@
# cython: profile=False
# distutils: language = c

import chgnet.graph.graph

import numpy as np
cimport numpy as np

import chgnet.graph.graph

from libc.stdlib cimport free


cdef extern from 'fast_converter_libraries/create_graph.c':
ctypedef struct Node:
long index
np.int64_t index
LongToDirectedEdgeList* neighbors
long num_neighbors
np.int64_t num_neighbors

ctypedef struct NodeIndexPair:
long center
long neighbor
np.int64_t center
np.int64_t neighbor

ctypedef struct UndirectedEdge:
NodeIndexPair nodes
long index
long* directed_edge_indices
long num_directed_edges
double distance
np.int64_t index
np.int64_t* directed_edge_indices
np.int64_t num_directed_edges
np.float64_t distance

ctypedef struct DirectedEdge:
NodeIndexPair nodes
long index
const long* image
long undirected_edge_index
double distance
np.int64_t index
const np.int64_t* image
np.int64_t undirected_edge_index
np.float64_t distance

ctypedef struct LongToDirectedEdgeList:
long key
np.int64_t key
DirectedEdge** directed_edges_list
int num_directed_edges_in_group

ctypedef struct ReturnElems2:
long num_nodes
long num_directed_edges
long num_undirected_edges
np.int64_t num_nodes
np.int64_t num_directed_edges
np.int64_t num_undirected_edges
Node* nodes
UndirectedEdge** undirected_edges_list
DirectedEdge** directed_edges_list

ReturnElems2* create_graph(
long* center_index,
long n_e,
long* neighbor_index,
long* image,
double* distance,
long num_atoms)
np.int64_t* center_index,
np.int64_t n_e,
np.int64_t* neighbor_index,
np.int64_t* image,
np.float64_t* distance,
np.int64_t num_atoms)

void free_LongToDirectedEdgeList_in_nodes(Node* nodes, long num_nodes)
void free_LongToDirectedEdgeList_in_nodes(Node* nodes, np.int64_t num_nodes)


LongToDirectedEdgeList** get_neighbors(Node* node)

def make_graph(
const long[::1] center_index,
const long n_e,
const long[::1] neighbor_index,
const long[:, ::1] image,
const double[::1] distance,
const long num_atoms
const np.int64_t[::1] center_index,
const np.int64_t n_e,
const np.int64_t[::1] neighbor_index,
const np.int64_t[:, ::1] image,
const np.float64_t[::1] distance,
const np.int64_t num_atoms
):
cdef ReturnElems2* returned
returned = <ReturnElems2*> create_graph(<long*> &center_index[0], n_e, <long*> &neighbor_index[0], <long*> &image[0][0], <double*> &distance[0], num_atoms)
returned = <ReturnElems2*> create_graph(<np.int64_t*> &center_index[0], n_e, <np.int64_t*> &neighbor_index[0], <np.int64_t*> &image[0][0], <np.float64_t*> &distance[0], num_atoms)

chg_DirectedEdge = chgnet.graph.graph.DirectedEdge
chg_Node = chgnet.graph.graph.Node
Expand Down
Loading