-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fast implementation of Select for most cases on CPU #687
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -654,40 +654,63 @@ void PasteCols(Tensor out_, | |
} | ||
} | ||
|
||
#if 0 // this version seems to actually be buggy, but also not used in decoding? | ||
// Optimized version of Select for axis=2 | ||
// @TODO: make this generally fast without this special version | ||
void SelectAxis2(Tensor out, | ||
const Tensor in, | ||
const Tensor indices) { | ||
|
||
std::cerr << indices->debug() << std::endl; | ||
|
||
matchOrAbort<IndexType>(indices->type()); | ||
|
||
functional::Shape outShape = out->shape(); | ||
functional::Shape inShape = in->shape(); | ||
|
||
auto idxData = indices->data<IndexType>(); | ||
auto odata = out->data(); | ||
const auto idata = in->data(); | ||
|
||
int size = outShape[3]; | ||
|
||
for(int k = 0; k < outShape[0]; ++k) { | ||
for(int j = 0; j < outShape[1]; ++j) { | ||
int outOffset = k * j * outShape[2] * size + j * outShape[2] * size; | ||
int inOffset = k * j * inShape[2] * size + j * inShape[2] * size; | ||
for(int i = 0; i < outShape[2]; ++i) { | ||
auto idx = idxData[i]; | ||
int outIndex = outOffset + i * size; | ||
int inIndex = inOffset + idx * size; | ||
std::copy(idata + inIndex, idata + inIndex + size, odata + outIndex); | ||
/* Recursive template to implement LoopBeforeAxis. */ | ||
template <class Backend, int Before> struct LoopBeforeAxisImpl { | ||
static inline void Loop( | ||
const functional::Shape &outShape, int outBase, | ||
const functional::Shape &inShape, int inBase, | ||
const functional::Shape &idxShape, int idxBase, | ||
int axisCPU, | ||
Backend backend) { | ||
// Loop over this dimension. | ||
const int dim = axisCPU - Before; | ||
if (dim < 0) { | ||
// This template is instantiated for every possible dimension, typically | ||
// more than before the axis. | ||
LoopBeforeAxisImpl<Backend, Before - 1>::Loop(outShape, outBase, inShape, inBase, idxShape, idxBase, axisCPU, backend); | ||
} else { | ||
const int outStride = outShape.stride(dim); | ||
const int end = outShape.dim(dim); | ||
const int inStride = inShape.stride(dim); | ||
const int idxStride = idxShape.bstride(dim); | ||
for (int i = 0; i < end; ++i) { | ||
LoopBeforeAxisImpl<Backend, Before - 1>::Loop(outShape, outBase, inShape, inBase, idxShape, idxBase, axisCPU, backend); | ||
outBase += outStride; | ||
inBase += inStride; | ||
idxBase += idxStride; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
/* We're at the axis, call the functor. */ | ||
template <class Backend> struct LoopBeforeAxisImpl<Backend, 0> { | ||
static inline void Loop( | ||
const functional::Shape &, int outBase, | ||
const functional::Shape &, int inBase, | ||
const functional::Shape &, int idxBase, | ||
int /*axisCPU*/, | ||
Backend backend) { | ||
backend(outBase, inBase, idxBase); | ||
} | ||
}; | ||
|
||
/* Jointly loop over dimensions [0, axisCPU) of three tensors out, in, and | ||
* indices. Call the Backend functor for each iteration of the loop. | ||
* Backend is a functor taking the tensors and base indices into them: | ||
* Backend::operator()( | ||
* int out_base, | ||
* int in_base, | ||
* int indices_base); | ||
*/ | ||
template <class Backend> inline void LoopBeforeAxis( | ||
const functional::Shape &outShape, | ||
const functional::Shape &inShape, | ||
const functional::Shape &idxShape, | ||
int axisCPU, | ||
Backend backend) { | ||
LoopBeforeAxisImpl<Backend, functional::Shape::size()>::Loop(outShape, 0, inShape, 0, idxShape, 0, axisCPU, backend); | ||
} | ||
#endif | ||
|
||
void Select(Tensor out, | ||
const Tensor in, | ||
|
@@ -696,19 +719,50 @@ void Select(Tensor out, | |
|
||
matchOrAbort<IndexType>(indices->type()); | ||
|
||
// @TODO: make this efficient | ||
functional::Shape outShape = out->shape(); | ||
functional::Shape inShape = in->shape(); | ||
functional::Shape idxShape = indices->shape(); | ||
int length = outShape.elements(); | ||
|
||
functional::Array<int, functional::Shape::size()> dims; | ||
int axisCPU = (int)(axis + functional::Shape::size() - out->shape().size()); | ||
|
||
#if 0 // buggy but not really used? | ||
if(axisCPU == 2 && outShape == idxShape) // specialization for axis==2 when there is no broadcasting, @TODO to be removed once we have a faster implementation below | ||
return SelectAxis2(out, in, indices); | ||
#endif | ||
// Are all index dimensions 1 after the axis? | ||
bool flatIndices = true; | ||
// Total dimensionality of input and output after the axis. | ||
int afterAxis = 1; | ||
for (int i = axisCPU + 1; i < functional::Shape::size(); ++i) { | ||
afterAxis *= outShape[i]; | ||
if (idxShape[i] != 1) { | ||
flatIndices = false; | ||
} | ||
} | ||
/* Faster version based on copying. Requirements: | ||
* input is contiguous for every dimension after the axis. | ||
* output is contiguous for every dimension after the axis. | ||
* indices have shape 1 for every dimension after the axis. | ||
*/ | ||
if (afterAxis == inShape.stride(axisCPU) && afterAxis == outShape.stride(axisCPU) && flatIndices) { | ||
const int end = outShape.dim(axisCPU); | ||
const int outStride = outShape.stride(axisCPU); | ||
const int idxStride = idxShape.bstride(axisCPU); | ||
// Loop over all dimensions before the axis. | ||
LoopBeforeAxis(outShape, inShape, idxShape, axisCPU, | ||
[out, in, indices, afterAxis, end, outStride, idxStride](int outBase, int inBase, int idxBase) { | ||
// Loop over the axis dimension. | ||
for (int i = 0; i < end; ++i) { | ||
int index = indices->data<IndexType>()[idxBase]; | ||
// Loop over all dimensions after the axis. | ||
std::copy(in->data() + inBase + index * afterAxis, in->data() + inBase + index * afterAxis + afterAxis, out->data() + outBase); | ||
outBase += outStride; | ||
idxBase += idxStride; | ||
} | ||
}); | ||
return; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's do an … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe let's also add a positive and a negative example where each branch is used? There are a couple of them in tests_operators; I think at least one will still go to the default version? |
||
} | ||
|
||
// @TODO: make this efficient | ||
int length = outShape.elements(); | ||
// Loop over outer dimensions (those before the axis). | ||
functional::Array<int, functional::Shape::size()> dims; | ||
|
||
for(int index = 0; index < length; ++index) { | ||
outShape.dims(index, dims); // compute dimension-based indices from global index; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use a different name than `Backend`. We use backend for CPU/GPU device things, had me confused for a sec.