Commit
dev(narugo): add support for vmap
HansBug committed Sep 18, 2023
1 parent c8a07a3 commit babbf6f
Showing 4 changed files with 130 additions and 0 deletions.
100 changes: 100 additions & 0 deletions test/torch/funcs/test_wrapper.py
@@ -0,0 +1,100 @@
from unittest import skipUnless

import pytest
import torch
from hbutils.testing import vpip

import treetensor.torch as ttorch
from treetensor.torch import Size


@pytest.fixture()
def treetensor_x():
    return ttorch.randn({
        'a': (2, 5, 7),
        'b': {
            'x': (3, 4, 6),
        }
    })


@pytest.fixture()
def treetensor_y():
    return ttorch.randn({
        'a': (2, 5, 7),
        'b': {
            'x': (3, 4, 6),
        }
    })


@pytest.mark.unittest
class TestTorchTensorWrapper:
    @skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
    def test_vmap(self, treetensor_x, treetensor_y):
        f = lambda x, y: (x.sum() + y.mean() * 2)
        n_pow = torch.vmap(f)
        batched_pow = ttorch.vmap(f)
        r = batched_pow(treetensor_x, treetensor_y)

        assert r.shape == Size({
            'a': (2,),
            'b': {
                'x': (3,)
            },
        })
        assert ttorch.isclose(
            r,
            ttorch.tensor({
                'a': n_pow(treetensor_x.a, treetensor_y.a),
                'b': {
                    'x': n_pow(treetensor_x.b.x, treetensor_y.b.x),
                }
            })
        ).all()

    @skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
    def test_vmap_in_dims(self, treetensor_x, treetensor_y):
        f = lambda x, y: (x.sum() + y.mean() * 2)
        n_pow = torch.vmap(f, in_dims=1)
        batched_pow = ttorch.vmap(f, in_dims=1)
        r = batched_pow(treetensor_x, treetensor_y)

        assert r.shape == Size({
            'a': (5,),
            'b': {
                'x': (4,)
            },
        })
        assert ttorch.isclose(
            r,
            ttorch.tensor({
                'a': n_pow(treetensor_x.a, treetensor_y.a),
                'b': {
                    'x': n_pow(treetensor_x.b.x, treetensor_y.b.x),
                }
            })
        ).all()

    @skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
    def test_vmap_nested(self, treetensor_x, treetensor_y):
        f = lambda x, y: (x.sum() + y.mean() * 2)
        n_pow = torch.vmap(torch.vmap(f))
        batched_pow = ttorch.vmap(ttorch.vmap(f))
        r = batched_pow(treetensor_x, treetensor_y)

        assert r.shape == Size({
            'a': (2, 5),
            'b': {
                'x': (3, 4)
            },
        })
        assert ttorch.isclose(
            r,
            ttorch.tensor({
                'a': n_pow(treetensor_x.a, treetensor_y.a),
                'b': {
                    'x': n_pow(treetensor_x.b.x, treetensor_y.b.x),
                }
            })
        ).all()
3 changes: 3 additions & 0 deletions treetensor/torch/funcs/__init__.py
@@ -14,6 +14,8 @@
from .operation import __all__ as _operation_all
from .reduction import *
from .reduction import __all__ as _reduction_all
from .wrapper import *
from .wrapper import __all__ as _wrapper_all
from ...utils import module_autoremove

__all__ = [
@@ -24,6 +26,7 @@
    *_matrix_all,
    *_operation_all,
    *_reduction_all,
    *_wrapper_all,
]

_current_module = sys.modules[__name__]
14 changes: 14 additions & 0 deletions treetensor/torch/funcs/base.py
@@ -1,3 +1,5 @@
from functools import wraps

import torch
from treevalue import func_treelize as original_func_treelize

@@ -11,3 +13,15 @@
auto_tensor = replaceable_partial(auto_tree, cls=[(torch.is_tensor, Tensor)])
get_func_from_torch = module_func_loader(torch, Tensor,
                                         [(torch.is_tensor, Tensor)])


def wrap_for_treelize(*args, **kwargs):
    def _decorator(func):
        @wraps(func)
        def _new_func(*args_, **kwargs_):
            retval = func(*args_, **kwargs_)
            return func_treelize(*args, **kwargs)(retval)

        return _new_func

    return _decorator
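
Note that wrap_for_treelize targets higher-order functions: the decorated function must return a callable, and it is that returned callable which gets treelized. A minimal sketch of the equivalent expansion, assuming the module-level func_treelize helper defined earlier in base.py (elided from this hunk); vmap_by_hand is a hypothetical name, not part of the commit:

import torch

# Hypothetical expansion of the decorator (a sketch, not committed code).
# torch.vmap returns a plain callable; treelizing that callable makes it
# accept TreeTensor arguments and map over their leaves.
def vmap_by_hand(func, *args, **kwargs):
    mapped = torch.vmap(func, *args, **kwargs)
    return func_treelize()(mapped)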
13 changes: 13 additions & 0 deletions treetensor/torch/funcs/wrapper.py
@@ -0,0 +1,13 @@
import torch

from .base import doc_from_base, wrap_for_treelize

__all__ = [
    'vmap',
]


@doc_from_base()
@wrap_for_treelize()
def vmap(func, *args, **kwargs):
    return torch.vmap(func, *args, **kwargs)
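
For reference, a minimal usage sketch of the new wrapper, mirroring test_vmap above (the shapes and the lambda come straight from the test fixtures):

import treetensor.torch as ttorch

# The function is vmapped over the leading dimension of each leaf tensor
# in the tree independently, so every leaf loses its batch dimension.
x = ttorch.randn({'a': (2, 5, 7), 'b': {'x': (3, 4, 6)}})
y = ttorch.randn({'a': (2, 5, 7), 'b': {'x': (3, 4, 6)}})
batched = ttorch.vmap(lambda a, b: a.sum() + b.mean() * 2)
r = batched(x, y)  # r.a: shape (2,), r.b.x: shape (3,)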
