This repository contains the code for the FlashSigmoid approach from the paper: Theory, Analysis, and Best Practices for Sigmoid Self-Attention.
```
# (Softmax) Attention
out = softmax(q @ k.T / sqrt(d)) @ v

# Sigmoid Attention
out = sigmoid(q @ k.T / sqrt(d) + b) @ v  # b: scalar
```
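For intuition, here is a minimal eager-mode PyTorch sketch of the two variants above. It is illustrative only (not the fused FlashSigmoid kernel), and the `[B, H, T, D]` layout and helper names are assumptions for this example.

```python
import math
import torch

def naive_softmax_attn(q, k, v):
    # q, k, v: [B, H, T, D] tensors.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v

def naive_sigmoid_attn(q, k, v, b=0.0):
    # b is the scalar bias from the sigmoid attention equation above.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.sigmoid(scores + b) @ v
```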
- FlashSigmoid is motivated by the efficient, hardware-aware implementation of FlashAttention2.
- We compute `sigmoid(x)` as `sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))` and leverage fast tanh primitives (see the numerical check after this list).
- We remove the allocation and computation of variables (e.g., row-sum, row-max) that are not needed for sigmoid attention.
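As a quick sanity check of the tanh identity above (a standalone snippet, not part of the library):

```python
import torch

x = torch.randn(1024, dtype=torch.float64)
lhs = torch.sigmoid(x)
rhs = 0.5 * (1.0 + torch.tanh(0.5 * x))
assert torch.allclose(lhs, rhs)  # the identity holds up to floating-point error
```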
Our FlashSigmoid implementation builds on FlashAttention2 at commit `6c9e60de566800538fedad2ad5e6b7b55ca7f0c5` (version 2.5.6).
Consequently, we inherit the same requirements as FlashAttention2 (a quick version check is sketched after this list):
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux operating system.
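A minimal Python sketch of these version checks (the thresholds are the ones listed above; `packaging` is used here for version parsing and is also a prerequisite below):

```python
import platform
import torch
from packaging import version

assert platform.system() == "Linux", "A Linux operating system is required."
assert version.parse(torch.__version__.split("+")[0]) >= version.parse("1.12"), "PyTorch >= 1.12 is required."
assert torch.version.cuda is not None, "A CUDA build of PyTorch is required."
assert version.parse(torch.version.cuda) >= version.parse("11.6"), "CUDA >= 11.6 is required."
```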
Before installation, make sure of the following (see the sketch after this list for an automated check):
- PyTorch is installed.
- The `packaging` package is installed. If not, run `pip install packaging`.
- The `ninja` package is installed and works correctly:
  - This can be checked by running `ninja --version` followed by `echo $?`, which should return exit code `0`.
  - Otherwise, reinstall the package with `pip uninstall -y ninja && pip install ninja`.
  - Without `ninja`, compiling can take a very long time.
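The checks above can also be scripted; a small sketch (equivalent to running `ninja --version` and `echo $?` by hand) is:

```python
import importlib.util
import subprocess

# The packaging package must be importable.
if importlib.util.find_spec("packaging") is None:
    print("packaging is missing: run `pip install packaging`")

# ninja must be on PATH and exit with code 0.
try:
    result = subprocess.run(["ninja", "--version"], capture_output=True, text=True)
    print(f"ninja {result.stdout.strip()} (exit code {result.returncode})")
except FileNotFoundError:
    print("ninja is missing: run `pip uninstall -y ninja && pip install ninja`")
```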
From the `ml-sigmoid-attention` directory, run the following commands to install FlashSigmoid:
```bash
# Create an environment for sigmoid attention, if not done already.
conda create -n sigmoid-attn-py310 python=3.10
conda activate sigmoid-attn-py310

# Remove any pre-existing installation, and install.
cd flash_sigmoid
pip uninstall -y flash_sigmoid
rm -rf build dist flash_sigmoid.egg-info

# If the build fails with no apparent cause, try decreasing MAX_JOBS.
# Conversely, if your setup supports it, a higher value can speed up the install.
MAX_JOBS=8 python3 setup.py install

# You can also run the unit tests as follows.
# pytest -k test_flash_attn_output tests/test_flash_attn.py
```
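After installation, a quick smoke test along the following lines can confirm that the extension loads and runs. This is a sketch, not part of the repository: it assumes a CUDA device, and the shapes and `sigmoid_bias` value below are arbitrary choices for illustration.

```python
import torch
from flash_sigmoid import flash_attn_func

B, T, H, D = 2, 128, 4, 64  # illustrative sizes
q = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True, sigmoid_bias=-4.0)
print(out.shape)  # expected: torch.Size([2, 128, 4, 64])
```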
You can also install softmax FlashAttention2 at the above commit alongside FlashSigmoid:
```bash
# Create an environment for sigmoid attention, if not done already.
conda create -n sigmoid-attn-py310 python=3.10
conda activate sigmoid-attn-py310

git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention
git checkout 6c9e60de566800538fedad2ad5e6b7b55ca7f0c5

# If the build fails with no apparent cause, try decreasing MAX_JOBS.
# Conversely, if your setup supports it, a higher value can speed up the install.
MAX_JOBS=8 python3 setup.py install
cd .. && rm -rf flash-attention
```
To see the differences between FlashAttention2 and FlashSigmoid, open the GitHub repo in a browser and navigate to the following compare URL (the diff shows <FlashAttention2> .. <FlashSigmoid>):

```
https://<github-url-name>/compare/6c9e60de566800538fedad2ad5e6b7b55ca7f0c5..533c2691e05e05899eeaa546e8909f510e9cf657
```
The usage and signatures of the flash functions in FlashSigmoid are the same as those of FlashAttention2, except:
- We can pass an optional additional argument `sigmoid_bias: float` to the functions. This argument represents the `b` scalar in the defining equation of sigmoid attention above. If not passed, `sigmoid_bias` defaults to `0`.
- We do NOT support the `varlen` and `kvcache` variants of the flash functions.
- We do NOT support `dropout_p`, and thus `dropout_p` will always be `0`.
```python
from flash_sigmoid import flash_attn_func as flash_sigmoid_func

# Batch size: B
# Sequence length: T
# Query heads: H_q
# Feature dimension per head: D
# Key/value heads: H_kv
# q: torch.Tensor with dtype bf16/fp16 and shape [B, T, H_q, D].
# k: torch.Tensor with dtype bf16/fp16 and shape [B, T, H_kv, D].
# v: torch.Tensor with dtype bf16/fp16 and shape [B, T, H_kv, D].
# softmax_scale: Optional[float] that defaults to 1/sqrt(D) if None.
# dropout_p: Attention dropout, which is NOT yet supported and must be 0 for now.
# window_size: tuple[int, int] giving the left and right extents of windowed attention;
#              set to (-1, -1) to disable windowed attention.
# alibi_slopes: torch.Tensor with dtype fp32 and shape [H_q] or [B, H_q].
# causal: bool indicating whether to carry out causal attention.
# sigmoid_bias: float (not trainable) added to q @ k.T / sqrt(D).
# out: torch.Tensor with the dtype and shape of q: [B, T, H_q, D].
out = flash_sigmoid_func(
    q,
    k,
    v,
    softmax_scale=softmax_scale,
    dropout_p=dropout_p,
    window_size=window_size,
    alibi_slopes=alibi_slopes,
    causal=causal,
    sigmoid_bias=sigmoid_bias,
)
```
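To sanity-check outputs, one can compare against a naive eager-mode reference like the sketch near the top of this README. This is an assumption-laden example: it presumes keyword arguments are accepted as above, computes the reference in fp32 with a `[B, H, T, D]` layout, and uses an arbitrary bias value and tolerance expectation for fp16.

```python
import math
import torch
from flash_sigmoid import flash_attn_func as flash_sigmoid_func

B, T, H, D = 2, 256, 8, 64
q, k, v = (torch.randn(B, T, H, D, device="cuda", dtype=torch.float16) for _ in range(3))
bias = -2.0

out_flash = flash_sigmoid_func(q, k, v, sigmoid_bias=bias)

# Naive reference in fp32 with a [B, H, T, D] layout.
qf, kf, vf = (x.transpose(1, 2).float() for x in (q, k, v))
scores = qf @ kf.transpose(-2, -1) / math.sqrt(D) + bias
out_ref = (torch.sigmoid(scores) @ vf).transpose(1, 2).to(torch.float16)

print((out_flash - out_ref).abs().max())  # expect a small fp16-level difference
```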
- A more detailed, single-file usage example of FlashSigmoid can be found here.
- A more detailed, single-file usage example of FlashAttention2 can be found here.
*Figures: forward-pass kernels on H100 (left) and backward-pass kernels on H100 (right).*

*Figure: train losses comparing SigmoidAttn with SoftmaxAttn.*
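For a rough local look at kernel timings like those in the figures above, a simple CUDA-event benchmark can be used. This is a sketch under assumptions: both packages are installed side by side as described earlier, the shapes are arbitrary, and this is not the benchmarking setup used for the figures.

```python
import torch
from flash_attn import flash_attn_func                       # softmax FlashAttention2
from flash_sigmoid import flash_attn_func as flash_sigmoid_func

B, T, H, D = 4, 4096, 16, 64
q, k, v = (torch.randn(B, T, H, D, device="cuda", dtype=torch.float16) for _ in range(3))

def time_ms(fn, iters=50):
    # Warm up, then time with CUDA events.
    for _ in range(5):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("softmax forward:", time_ms(lambda: flash_attn_func(q, k, v, causal=True)), "ms")
print("sigmoid forward:", time_ms(lambda: flash_sigmoid_func(q, k, v, causal=True)), "ms")
```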