fix test_trainer_composition_model
    def test_trainer_composition_model(tmp_path: Path) -> None:
        for param in chgnet.composition_model.parameters():
            assert param.requires_grad is False
        trainer = Trainer(
            model=chgnet,
            targets="efsm",
            optimizer="Adam",
            criterion="MSE",
            learning_rate=1e-2,
            epochs=5,
        )
        initial_weights = chgnet.composition_model.state_dict()["fc.weight"].clone()
>       trainer.train(
            train_loader, val_loader, save_dir=tmp_path, train_composition_model=True
        )

tests/test_trainer.py:106:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
chgnet/trainer/trainer.py:305: in train
    train_mae = self._train(train_loader, epoch, wandb_log_freq)
chgnet/trainer/trainer.py:400: in _train
    combined_loss = self.criterion(targets, prediction)
../../.venv/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.venv/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
chgnet/trainer/trainer.py:861: in forward
    if mag_target is not None and not np.isnan(mag_target).any():
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = tensor([0.9620, 0.0657], device='mps:0'), dtype = None

    def __array__(self, dtype=None):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
        if dtype is None:
>           return self.numpy()
E           TypeError: can't convert mps:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

../../.venv/py312/lib/python3.12/site-packages/torch/_tensor.py:1083: TypeError
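Root cause: np.isnan on a torch tensor forces a NumPy conversion via Tensor.__array__ -> Tensor.numpy(), which PyTorch only allows for CPU tensors, so the check blows up when training on Apple-silicon ("mps") devices. torch.isnan performs the same check on the tensor's own device. A minimal sketch of the failure mode and the fix (not CHGNet code; assumes a machine with an MPS device, and the tensor values are made up):

    import numpy as np
    import torch

    mag_target = torch.tensor([0.9620, 0.0657], device="mps")

    try:
        np.isnan(mag_target)  # converts via Tensor.__array__ -> Tensor.numpy()
    except TypeError as exc:
        print(exc)  # can't convert mps:0 device type tensor to numpy. Use Tensor.cpu() ...

    # torch.isnan runs on the tensor's own device, so no host copy is needed
    assert not torch.isnan(mag_target).any()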
janosh committed Nov 16, 2024
1 parent 0cb93e6 commit 1f29985
Showing 3 changed files with 5 additions and 7 deletions.
4 changes: 1 addition & 3 deletions chgnet/model/dynamics.py
@@ -158,9 +158,7 @@ def calculate(
         )
 
         # Convert Result
-        extensive_factor = (
-            1 if not self.model.is_intensive else structure.composition.num_atoms
-        )
+        extensive_factor = len(structure) if self.model.is_intensive else 1
         key_map = dict(
             e=("energy", extensive_factor),
             f=("forces", 1),
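The dynamics.py hunk is a drive-by cleanup of the intensive-to-extensive conversion: when the model predicts energy per atom (is_intensive), the ASE calculator scales it by the number of sites to report a total energy. A rough sketch of that scaling, using hypothetical stand-ins (model_is_intensive, energy_per_atom) rather than the calculator's actual attributes:

    from pymatgen.core import Lattice, Structure

    structure = Structure(Lattice.cubic(3.0), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    model_is_intensive = True  # stand-in for self.model.is_intensive
    energy_per_atom = -3.5     # stand-in for the model's per-atom energy prediction (eV/atom)

    # len(structure) counts sites, matching composition.num_atoms for ordered structures
    extensive_factor = len(structure) if model_is_intensive else 1
    total_energy = energy_per_atom * extensive_factor  # -7.0 eV for this 2-site cell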
2 changes: 1 addition & 1 deletion chgnet/trainer/trainer.py
@@ -858,7 +858,7 @@ def forward(
         for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
             # exclude structures without magmom labels
             if self.allow_missing_labels:
-                if mag_target is not None and not np.isnan(mag_target).any():
+                if mag_target is not None and not torch.isnan(mag_target).any():
                     mag_preds.append(mag_pred)
                     mag_targets.append(mag_target)
                     m_mae_size += mag_target.shape[0]
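With torch.isnan, the magmom-label filtering in the loss's forward pass stays on the model's device. A device-agnostic sketch of that filtering pattern with made-up per-structure tensors; reducing the survivors to an MAE with l1_loss is an illustration, not the Trainer's exact reduction:

    import torch

    preds = [torch.tensor([0.05, 0.12]), torch.tensor([0.3]), torch.tensor([0.4])]
    targets = [torch.tensor([0.0, 0.1]), torch.tensor([float("nan")]), None]

    mag_preds: list[torch.Tensor] = []
    mag_targets: list[torch.Tensor] = []
    for mag_pred, mag_target in zip(preds, targets):
        # keep only structures whose magmom labels exist and are NaN-free
        if mag_target is not None and not torch.isnan(mag_target).any():
            mag_preds.append(mag_pred)
            mag_targets.append(mag_target)

    mag_mae = torch.nn.functional.l1_loss(torch.cat(mag_preds), torch.cat(mag_targets))
    print(mag_mae)  # tensor(0.0350)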
6 changes: 3 additions & 3 deletions tests/test_trainer.py
@@ -74,9 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     for param in chgnet.composition_model.parameters():
         assert param.requires_grad is False
     assert tmp_path.is_dir(), "Training dir was not created"
-    for target_str in ["e", "f", "s", "m"]:
-        assert ~np.isnan(trainer.training_history[target_str]["train"]).any()
-        assert ~np.isnan(trainer.training_history[target_str]["val"]).any()
+    for prop in "efsm":
+        assert ~np.isnan(trainer.training_history[prop]["train"]).any()
+        assert ~np.isnan(trainer.training_history[prop]["val"]).any()
     output_files = [file.name for file in tmp_path.iterdir()]
     for prefix in ("epoch", "bestE_", "bestF_"):
         n_matches = sum(file.startswith(prefix) for file in output_files)
