From babbf6f5fc0b29f77f80fffe37c0d448c0048d9c Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 18 Sep 2023 17:00:42 +0800 Subject: [PATCH] dev(narugo): add supported for vmap --- test/torch/funcs/test_wrapper.py | 100 +++++++++++++++++++++++++++++ treetensor/torch/funcs/__init__.py | 3 + treetensor/torch/funcs/base.py | 14 ++++ treetensor/torch/funcs/wrapper.py | 13 ++++ 4 files changed, 130 insertions(+) create mode 100644 test/torch/funcs/test_wrapper.py create mode 100644 treetensor/torch/funcs/wrapper.py diff --git a/test/torch/funcs/test_wrapper.py b/test/torch/funcs/test_wrapper.py new file mode 100644 index 0000000000..4a02f9114c --- /dev/null +++ b/test/torch/funcs/test_wrapper.py @@ -0,0 +1,100 @@ +from unittest import skipUnless + +import pytest +import torch +from hbutils.testing import vpip + +import treetensor.torch as ttorch +from treetensor.torch import Size + + +@pytest.fixture() +def treetensor_x(): + return ttorch.randn({ + 'a': (2, 5, 7), + 'b': { + 'x': (3, 4, 6), + } + }) + + +@pytest.fixture() +def treetensor_y(): + return ttorch.randn({ + 'a': (2, 5, 7), + 'b': { + 'x': (3, 4, 6), + } + }) + + +@pytest.mark.unittest +class TestTorchTensorWrapper: + @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') + def test_vmap(self, treetensor_x, treetensor_y): + f = lambda x, y: (x.sum() + y.mean() * 2) + n_pow = torch.vmap(f) + batched_pow = ttorch.vmap(f) + r = batched_pow(treetensor_x, treetensor_y) + + assert r.shape == Size({ + 'a': (2,), + 'b': { + 'x': (3,) + }, + }) + assert ttorch.isclose( + r, + ttorch.tensor({ + 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'b': { + 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + } + }) + ).all() + + @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') + def test_vmap_in_dims(self, treetensor_x, treetensor_y): + f = lambda x, y: (x.sum() + y.mean() * 2) + n_pow = torch.vmap(f, in_dims=1) + batched_pow = ttorch.vmap(f, in_dims=1) + r = batched_pow(treetensor_x, treetensor_y) + + assert r.shape == Size({ + 'a': (5,), + 'b': { + 'x': (4,) + }, + }) + assert ttorch.isclose( + r, + ttorch.tensor({ + 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'b': { + 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + } + }) + ).all() + + @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') + def test_vmap_nested(self, treetensor_x, treetensor_y): + f = lambda x, y: (x.sum() + y.mean() * 2) + n_pow = torch.vmap(torch.vmap(f)) + batched_pow = ttorch.vmap(ttorch.vmap(f)) + r = batched_pow(treetensor_x, treetensor_y) + + assert r.shape == Size({ + 'a': (2, 5), + 'b': { + 'x': (3, 4) + }, + }) + assert ttorch.isclose( + r, + ttorch.tensor({ + 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'b': { + 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + } + }) + ).all() diff --git a/treetensor/torch/funcs/__init__.py b/treetensor/torch/funcs/__init__.py index 98b029bf89..51a4c89230 100644 --- a/treetensor/torch/funcs/__init__.py +++ b/treetensor/torch/funcs/__init__.py @@ -14,6 +14,8 @@ from .operation import __all__ as _operation_all from .reduction import * from .reduction import __all__ as _reduction_all +from .wrapper import * +from .wrapper import __all__ as _wrapper_all from ...utils import module_autoremove __all__ = [ @@ -24,6 +26,7 @@ *_matrix_all, *_operation_all, *_reduction_all, + *_wrapper_all, ] _current_module = sys.modules[__name__] diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py index d5ee52a939..3b5c78fbaa 100644 --- a/treetensor/torch/funcs/base.py +++ b/treetensor/torch/funcs/base.py @@ -1,3 +1,5 @@ +from functools import wraps + import torch from treevalue import func_treelize as original_func_treelize @@ -11,3 +13,15 @@ auto_tensor = replaceable_partial(auto_tree, cls=[(torch.is_tensor, Tensor)]) get_func_from_torch = module_func_loader(torch, Tensor, [(torch.is_tensor, Tensor)]) + + +def wrap_for_treelize(*args, **kwargs): + def _decorator(func): + @wraps(func) + def _new_func(*args_, **kwargs_): + retval = func(*args_, **kwargs_) + return func_treelize(*args, **kwargs)(retval) + + return _new_func + + return _decorator diff --git a/treetensor/torch/funcs/wrapper.py b/treetensor/torch/funcs/wrapper.py new file mode 100644 index 0000000000..cb33bc5270 --- /dev/null +++ b/treetensor/torch/funcs/wrapper.py @@ -0,0 +1,13 @@ +import torch + +from .base import doc_from_base, wrap_for_treelize + +__all__ = [ + 'vmap', +] + + +@doc_from_base() +@wrap_for_treelize() +def vmap(func, *args, **kwargs): + return torch.vmap(func, *args, **kwargs)