Skip to content

Commit

Permalink
dev(narugo): add check for torch 2
Browse files Browse the repository at this point in the history
  • Loading branch information
HansBug committed Sep 18, 2023
1 parent b298918 commit b539dc7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ 3.7 ]
python-version: [ 3.8 ]

services:
plantuml:
Expand Down
6 changes: 6 additions & 0 deletions test/torch/funcs/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,9 @@ def test_vmap_nested(self, treetensor_x, treetensor_y):
}
})
).all()

@skipUnless(vpip('torch') < '2', 'Torch 1.x required.')
def test_vmap_torch_1x(self, treetensor_x, treetensor_y):
f = lambda x, y: (x.sum() + y.mean() * 2)
with pytest.raises(NotImplementedError):
_ = ttorch.vmap(f)
8 changes: 7 additions & 1 deletion treetensor/torch/funcs/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import torch
from hbutils.testing import vpip

from .base import doc_from_base, wrap_for_treelize

__all__ = [
'vmap',
]

_is_torch_2 = vpip('torch') >= '2'


@doc_from_base()
@wrap_for_treelize()
def vmap(func, *args, **kwargs):
return torch.vmap(func, *args, **kwargs)
if _is_torch_2:
return torch.vmap(func, *args, **kwargs)
else:
raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.')

Check warning on line 19 in treetensor/torch/funcs/wrapper.py

View check run for this annotation

Codecov / codecov/patch

treetensor/torch/funcs/wrapper.py#L19

Added line #L19 was not covered by tests

0 comments on commit b539dc7

Please sign in to comment.