-
Notifications
You must be signed in to change notification settings - Fork 226
/
numpyarray.h
107 lines (104 loc) · 2.84 KB
/
numpyarray.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#ifndef clstm_numpyarray_
#define clstm_numpyarray_
#include "numpy/arrayobject.h"
template <class T, int TYPENUM>
struct NumPyArray {
PyArrayObject *obj = 0;
NumPyArray() {}
NumPyArray(PyObject *object_) {
if (!object_) throw "null pointer";
if (!PyArray_Check(object_)) throw "expected a numpy array";
obj = (PyArrayObject *)object_;
Py_INCREF(obj);
valid();
}
NumPyArray(NumPyArray<T, TYPENUM> &other) {
Py_INCREF(other.obj);
Py_DECREF(obj);
obj = other.obj;
}
NumPyArray(int d0, int d1 = 0, int d2 = 0, int d3 = 0) {
npy_intp ndims[] = {d0, d1, d2, d3, 0};
int rank = 0;
while (ndims[rank]) rank++;
obj = PyArray_SimpleNew(rank, ndims, TYPENUM);
valid();
}
~NumPyArray() {
Py_DECREF(obj);
obj = 0;
}
void operator=(NumPyArray<T, TYPENUM> &other) {
Py_INCREF(other.obj);
Py_DECREF(obj);
obj = other.obj;
}
void valid() {
if (!obj) throw "no array set";
if (PyArray_TYPE(obj) != TYPENUM) throw "wrong numpy array type";
if ((PyArray_FLAGS(obj) & NPY_ARRAY_C_CONTIGUOUS) == 0)
throw "expected contiguous array";
}
int rank() {
valid();
return PyArray_NDIM(obj);
}
int dim(int i) {
valid();
return PyArray_DIM(obj, i);
}
int size() {
valid();
return PyArray_SIZE(obj);
}
void resize(int d0, int d1 = 0, int d2 = 0, int d3 = 0) {
npy_intp ndims[] = {d0, d1, d2, d3, 0};
int rank = 0;
while (ndims[rank]) rank++;
PyArray_Dims dims = {ndims, rank};
if (PyArray_Resize(obj, &dims, 0, NPY_CORDER) == nullptr)
throw "resize failed";
}
T &operator()(int i) {
assert(rank() == 1);
assert(unsigned(i) < unsigned(dim(0)));
T *data = (T *)PyArray_DATA(obj);
return data[i];
}
T &operator()(int i, int j) {
assert(rank() == 2);
assert(unsigned(i) < unsigned(dim(0)));
assert(unsigned(j) < unsigned(dim(1)));
T *data = (T *)PyArray_DATA(obj);
return data[i * dim(1) + j];
}
T &operator()(int i, int j, int k) {
assert(rank() == 3);
assert(unsigned(i) < unsigned(dim(0)));
assert(unsigned(j) < unsigned(dim(1)));
assert(unsigned(k) < unsigned(dim(2)));
T *data = (T *)PyArray_DATA(obj);
return data[(i * dim(1) + j) * dim(2) + k];
}
T &operator()(int i, int j, int k, int l) {
assert(rank() == 4);
assert(unsigned(i) < unsigned(dim(0)));
assert(unsigned(j) < unsigned(dim(1)));
assert(unsigned(k) < unsigned(dim(2)));
assert(unsigned(l) < unsigned(dim(3)));
T *data = (T *)PyArray_DATA(obj);
return data[((i * dim(1) + j) * dim(2) + k) * dim(3) + l];
}
T *data() {
valid();
return (T *)PyArray_DATA(obj);
}
void copyTo(T *dest) {
valid();
T *data = (T *)PyArray_DATA(obj);
int N = size();
for (int i = 0; i < N; i++) dest[i] = data[i];
}
};
// Convenience alias: a NumPyArray handle for single-precision arrays
// (element type float, NumPy dtype NPY_FLOAT).
using npa_float = NumPyArray<float, NPY_FLOAT>;
#endif