-
Notifications
You must be signed in to change notification settings - Fork 17
/
utils.h
43 lines (38 loc) · 1.5 KB
/
utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#define check_tensors(L, a, b) \
do { \
if ((a)->nDimension != (b)->nDimension) \
return LUA_HANDLE_ERROR_STR((L), "different tensor dimensions"); \
for (int __local__var = 0; __local__var < (a)->nDimension; __local__var++) \
if ((a)->size[__local__var] != (b)->size[__local__var]) \
return LUA_HANDLE_ERROR_STR((L), "different tensor sizes"); \
} while (0)
#define check_tensor(L, t, type) \
do { \
if (!type##_isContiguous(t)) \
return LUA_HANDLE_ERROR_STR((L), "tensor should be contiguous"); \
} while (0)
#define get_tensor_size(t, type) \
(TH##type##Tensor_nElement(t))
#define get_tensor(L, idx, type) \
(TH##type##Tensor *)luaT_checkudata(L, idx, "torch." #type "Tensor")
static int push_table_contents(lua_State *L, int arg)
{
int size = 0;
while(1) {
lua_checkstack(L, 1);
lua_rawgeti(L, arg, size + 1);
if (lua_isnil(L, -1)) {
lua_pop(L, 1);
break;
}
size++;
}
return size;
}
#define verify_push_table_contents(L, idx, count) do { \
int __tmp_count = push_table_contents(L, idx); \
if (__tmp_count != count) { \
lua_pop(L, __tmp_count); \
LUA_HANDLE_ERROR_STR(L, "Table sizes do not match"); \
} \
} while(0)