From f0ffabe127180a56cb3b10c940e42eb89749de94 Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 18 Sep 2023 20:48:05 +0800 Subject: [PATCH] dev(narugo): fix bug on torch 1.x --- treetensor/torch/funcs/base.py | 3 +++ treetensor/torch/funcs/wrapper.py | 22 ++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py index 3b5c78fbaa..1bf21a72dc 100644 --- a/treetensor/torch/funcs/base.py +++ b/treetensor/torch/funcs/base.py @@ -1,6 +1,7 @@ from functools import wraps import torch +from hbutils.testing import vpip from treevalue import func_treelize as original_func_treelize from ..tensor import Tensor @@ -14,6 +15,8 @@ get_func_from_torch = module_func_loader(torch, Tensor, [(torch.is_tensor, Tensor)]) +_is_torch_2 = vpip('torch') >= '2' + def wrap_for_treelize(*args, **kwargs): def _decorator(func): diff --git a/treetensor/torch/funcs/wrapper.py b/treetensor/torch/funcs/wrapper.py index 8fdf5f131b..2f442c6d99 100644 --- a/treetensor/torch/funcs/wrapper.py +++ b/treetensor/torch/funcs/wrapper.py @@ -1,19 +1,21 @@ import torch -from hbutils.testing import vpip -from .base import doc_from_base, wrap_for_treelize +from .base import doc_from_base, wrap_for_treelize, _is_torch_2 __all__ = [ 'vmap', ] -_is_torch_2 = vpip('torch') >= '2' - - -@doc_from_base() -@wrap_for_treelize() -def vmap(func, *args, **kwargs): - if _is_torch_2: +if _is_torch_2: + @doc_from_base() + @wrap_for_treelize() + def vmap(func, *args, **kwargs): return torch.vmap(func, *args, **kwargs) - else: + +else: + def vmap(func, *args, **kwargs): + """ + .. warning: + :method:`treetensor.torch.vmap` is not supported for torch 1.x. + """ raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.')