Skip to content

Commit

Permalink
Fix torch_dtype check.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Nov 13, 2023
1 parent 9f3cca3 commit 05cffab
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion spacy_llm/models/hf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def __init__(
self._config_init.pop("device")

# Fetch proper torch.dtype, if specified.
if has_torch and self._config_init.get("torch_dtype", "") not in ("", "auto"):
if (
has_torch
and "torch_dtype" in self._config_init
and self._config_init["torch_dtype"] != "auto"
):
try:
self._config_init["torch_dtype"] = getattr(
torch, self._config_init["torch_dtype"]
Expand Down

0 comments on commit 05cffab

Please sign in to comment.