Example of Gather/Scatter from CrossEntropyLoss #2556
Comments
This is a great example, as it's used in any multiclass classification problem or learned compression network. Should we focus on separate forward and backward, or also look at the "turnaround" fusion? In this case the backward is just a softmax minus one-hot, and the softmax is much simpler if we hold on to the log_softmax from the forward computation.
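For reference, a minimal PyTorch sketch of that identity (my own illustration; shapes and names are arbitrary), reusing the saved log_softmax so the backward is just softmax minus one-hot:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
labels = torch.randint(0, 10, (4,))

# Forward: hold on to log_softmax; the loss is a gather of the true-label entries.
log_probs = F.log_softmax(logits, dim=-1)
loss = F.nll_loss(log_probs, labels, reduction="sum")

# Backward w.r.t. logits is softmax minus one-hot, where softmax = exp(log_softmax).
grad_manual = (log_probs.exp() - F.one_hot(labels, num_classes=10).to(log_probs.dtype)).detach()

loss.backward()
assert torch.allclose(logits.grad, grad_manual, atol=1e-5)
```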
Is it always gathered to a size of 1? It looks like it, since it's followed by a squeeze. Assuming yes, is it completely unknown which value is gathered? Could it always be, for example, the first value? I'm asking because a generic implementation of gather would need to make conservative assumptions about, in this case, the output of log softmax.
Yeah, always size 1, and it can be any of the values 0-32767. The index is the true label for an example, and the tensor we're indexing would be the predicted log probabilities (logits) for each class.
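Concretely, the pattern being described might look like this in PyTorch (a sketch with made-up sizes; 32768 matches the class count mentioned above):

```python
import torch

N, C = 4, 32768
log_probs = torch.randn(N, C).log_softmax(dim=-1)   # predicted log probabilities, [N, C]
labels = torch.randint(0, C, (N,))                  # true label per example, [N]

# Gather with a size-1 index along the class dim, then squeeze it away.
picked = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # [N]
loss = -picked.mean()
```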
The fact that it's always size 1 would be valuable for code generation. If we could assume it, we could take a different code generation strategy. That information would need to be communicated to the backend somehow. What would be the best way? It seems to me that if a frontend could translate
Hahaha, turns out I was dumb and totally speaking nonsense about how cross-entropy loss could be done with numpy.take... 😮💨 Thanks to @jacobhinkle for kindly pointing it out.
Having said that, if I'm reading this ^^^ correctly, we don't need to expose this at the user-facing API. As long as we put this, we can hide

I think with this, we should be able to support
Hmmm, wait a sec here. Are we referring to the existing one? Jacob mentioned it here.
So input is of shape [N, C], while index is of shape [N], and we want an output of shape [N]. (Note that CrossEntropyLoss actually supports arbitrary-rank inputs, but that's not really important in this discussion: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html.) I don't think this is what that is, though. I know we've had this conversation a few times in separate contexts and groups. Lol, let's go over this on Monday and make sure that we are on the same page regarding:
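A small sketch of the distinction that tripped things up earlier (my own illustration): a plain axis-wise numpy.take picks the same columns for every row, while the cross-entropy pattern needs a per-row pick:

```python
import numpy as np

N, C = 3, 5
log_probs = np.random.rand(N, C)        # stand-in for log_softmax output, [N, C]
labels = np.array([0, 2, 4])            # one class index per row, [N]

# np.take selects the same set of columns for every row -> [N, N]; not what we want.
same_cols = np.take(log_probs, labels, axis=1)

# A per-row gather picks one element per row -> [N]; this matches the loss.
per_row = np.take_along_axis(log_probs, labels[:, None], axis=1).squeeze(1)
```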
Oh, I see. I thought the problem was much simpler. Looks like we do need a simplified version of gather, whereas the real torch.gather is more general: https://pytorch.org/docs/stable/generated/torch.gather.html

It may actually be relatively trivial to lift the second assumption, as I think it's just cutting off the final part. In any case, this isn't as trivial as I initially thought. One question on gathering from
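For comparison, a small sketch of what the full torch.gather semantics allow (based on the linked docs; values are made up):

```python
import torch

inp = torch.arange(12.0).reshape(3, 4)           # [3, 4]

# Full torch.gather along dim=1: out[i][j] = inp[i][index[i][j]].
# index must have the same rank as inp; its extent along dim=1 is arbitrary,
# and its extents along the other dims may be smaller than inp's.
index = torch.tensor([[0, 3], [1, 1]])           # [2, 2], smaller than inp in dim 0
out = torch.gather(inp, 1, index)                # [2, 2]
```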
It might complicate things, but I'll note that this pattern, when used in a loss, is typically followed immediately by a reduction, so that the combo could be implemented by an iota() + where() composed with a sum. Would that allow us to avoid writing to global memory, and the cooperative launch?
Probably not. The reason we would need to use global memory is that the input to the gather op may be parallelized by blockIdx and threadIdx. In general, in order to allow ops after the gather to be parallelized independently from the ops before it, the gather input needs to be accessible by any thread, which means global memory followed by a grid sync. For example:

Let's say

The sizes

The problem here is that

Alternatively, the below would work without using global memory:

Note that all dependent ops after
Interesting. In the special case where the gather output has a single use, which is a reduction including the gather axis, rewriting the graph from

```cpp
tv2 = torch_gather(tv1, tv_index, dim); // assume tv_index has size 1 in dimension dim
tv3 = sum(tv2, {dim});
```

to

```cpp
tv4 = iota(tv1->axis(dim)->extent());
tv5 = broadcast(... // bcast along all dimensions other than dim
tv6 = where(eq(tv5, tv_index), tv1, zeroVal); // forms a "one-hot" TensorView
tv7 = sum(tv6, {dim});
```

seems like it would compute the right thing, even though the reduction is trivial. In that case we know tv6 has only the one use in the following sum, so the dim axis can be parallelized the same way as tv1, as it's just a product of pointwise ops.
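A quick eager-mode PyTorch check of the same rewrite (my own sketch mirroring the TensorView pseudocode above):

```python
import torch

N, C, dim = 4, 8, 1
tv1 = torch.randn(N, C)
tv_index = torch.randint(0, C, (N, 1))            # size 1 along `dim`

# Original form: gather, then the (trivial) sum over the gathered dim.
out_gather = torch.gather(tv1, dim, tv_index).sum(dim)

# Rewritten form: iota + where builds the masked ("one-hot") tensor, then a real sum.
iota = torch.arange(C).unsqueeze(0)                # broadcast along the batch dim
masked = torch.where(iota == tv_index, tv1, torch.zeros_like(tv1))
out_rewrite = masked.sum(dim)

assert torch.allclose(out_gather, out_rewrite)
```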
A trivial question regarding this:

vs

Can't the second part be handled with a Slice prior to the gather op? Would having
Oh, that translation is interesting. The final sum reduction would still need to be a grid reduction, even though most of the contribution by each thread should be zero, but it's certainly more efficient than writing to global memory, memory flush, global sync, and reading it again. We should definitely try this formulation as well.
Ideally, it should not, but in reality I'm not sure. We haven't looked at the performance of these ops, so I'm pretty sure there's a lot to consider. |
🚀 The feature, motivation and pitch
The task is to fuse `log_softmax + gather`. Naoya said it depends on his resize function work. The idea is that the tensor output of `log_softmax` can be roughly `[64, 128, 32768]`, which in float is ~1 GB. It is expensive to re-read that tensor versus "gathering" it down to a size of `[64, 128, 1]`, which is a trivially sized tensor. There are some people working on `index_select`, `gather`, and `scatter`, but those ops have so far only been allowed to fuse as the first operation of a fusion; the `gather`, in this instance, would be at the end of the fusion. The `CrossEntropyLoss` `forward` includes a `log_softmax` followed by a `gather` operation. It is notably used in NLP networks like Bert, as seen here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1139-L1143
Code example:
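A minimal sketch of the `log_softmax + gather` pattern in question, using the shapes from the description above (the function name and exact shapes here are illustrative, not pulled from the real graph):

```python
import torch
import torch.nn.functional as F

def log_softmax_gather(logits, labels):
    # logits: [64, 128, 32768] in the motivating example (~1 GB in float32)
    # labels: [64, 128] integer class indices
    log_probs = F.log_softmax(logits, dim=-1)                    # [64, 128, 32768]
    picked = torch.gather(log_probs, -1, labels.unsqueeze(-1))   # [64, 128, 1]
    return -picked.squeeze(-1)                                   # per-token NLL, [64, 128]
```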
How to view graph?
This section in particular is the `log_softmax + gather`. This is from printing out `torch.compile`'s graph.