Skip to content

Commit

Permalink
Fix xformers to work on v0.0.22 - 0.0.25 (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonluca authored Mar 27, 2024
1 parent 5e37d6b commit a5255dc
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/sfast/libs/xformers/xformers_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
from xformers import ops
from sfast.utils.custom_python_operator import register_custom_python_operator

OP_STR_MAP = {
ops.MemoryEfficientAttentionCutlassFwdFlashBwOp:
OP_STR_MAP = {}

for attr_name in [
'MemoryEfficientAttentionCutlassFwdFlashBwOp',
ops.MemoryEfficientAttentionCutlassOp: 'MemoryEfficientAttentionCutlassOp',
ops.MemoryEfficientAttentionFlashAttentionOp:
'MemoryEfficientAttentionCutlassOp',
'MemoryEfficientAttentionFlashAttentionOp',
ops.MemoryEfficientAttentionOp: 'MemoryEfficientAttentionOp',
ops.MemoryEfficientAttentionTritonFwdFlashBwOp:
'MemoryEfficientAttentionOp',
'MemoryEfficientAttentionTritonFwdFlashBwOp',
ops.TritonFlashAttentionOp: 'TritonFlashAttentionOp',
}
'TritonFlashAttentionOp',
'MemoryEfficientAttentionCkOp',
'MemoryEfficientAttentionSplitKCkOp'
]:
op_attr = getattr(ops, attr_name, None)
if op_attr is not None:
OP_STR_MAP[op_attr] = attr_name

STR_OP_MAP = {v: k for k, v in OP_STR_MAP.items()}

Expand Down

0 comments on commit a5255dc

Please sign in to comment.