[ExecuTorch][BE] Split kv cache and SDPA for better code sharing #7413

Open · wants to merge 3 commits into base: gh/kimishpatel/149/base
2 changes: 2 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -657,6 +657,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()

builder_exported.run_canonical_optimizations()

if args.export_only:
exit()

47 changes: 17 additions & 30 deletions examples/models/llama/llama_transformer.py
@@ -232,22 +232,16 @@ def __init__(
max_seq_length: int,
n_heads: int,
head_dim: int,
transpose_cache: bool,
enable_dynamic_shape: bool,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
self.is_transposed = transpose_cache
if transpose_cache:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
self.transpose_cache = transpose_cache
self.enable_dynamic_shape = enable_dynamic_shape
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__
def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
# input_pos: [S], k_val: [B, H, S, D]
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_length)
dim_to_slice = 2 if self.transpose_cache else 1
dim_to_slice = 2
seq_length = k_val.size(dim_to_slice)
# Replace the entry in the cache for this token
# The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update
else:
k_out = self.k_cache
v_out = self.v_cache
if self.transpose_cache:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out


class SDPA(nn.Module):
def __init__(
self,
kv_cache: KVCache,
dim: int,
head_dim: int,
n_rep: int,
max_seq_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.kv_cache = kv_cache
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
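
Taken together, the changes above mean KVCache always stores and returns tensors in [B, n_heads, max_seq_len, head_dim] layout, and SDPA no longer owns a cache instance. A minimal sketch of the resulting KVCache interface, assuming the import path implied by the file location and made-up sizes:

import torch

from executorch.examples.models.llama.llama_transformer import KVCache  # assumed import path

bsz, n_heads, head_dim, max_seq_len = 1, 8, 64, 128
cache = KVCache(
    max_batch_size=bsz,
    max_seq_length=max_seq_len,
    n_heads=n_heads,
    head_dim=head_dim,
    enable_dynamic_shape=False,  # the transpose_cache argument no longer exists
)

# k/v now always arrive in [B, H, S, D]; the old [B, S, H, D] path is gone.
seqlen = 4
input_pos = torch.arange(seqlen)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)
k_all, v_all = cache.update(input_pos, k, v)  # each [B, H, max_seq_len, D]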
@@ -314,18 +302,16 @@ def __init__
def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim)
q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim)
bsz,
seqlen,
mask: torch.Tensor,
) -> torch.Tensor:
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
# TODO(kimishpatel): Move this slicing logic to Attention block so that
# SDPA does not have to take input_pos as arg
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
@@ -336,6 +322,8 @@ def forward
else:
attn_mask = mask[None, None, input_pos]

# TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
# can natively support GQA now. But needs enable_gqa=True
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
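
As the TODO notes, newer PyTorch can handle GQA inside SDPA itself. A hedged sketch of that alternative (requires a PyTorch build whose F.scaled_dot_product_attention accepts enable_gqa):

# Instead of materializing repeated KV heads with repeat_interleave, let SDPA
# broadcast them: q keeps n_local_heads while k/v keep n_local_kv_heads
# (n_local_heads must be a multiple of n_local_kv_heads).
y = F.scaled_dot_product_attention(
    q,  # [B, n_local_heads, S_q, D]
    k,  # [B, n_local_kv_heads, S_kv, D]
    v,  # [B, n_local_kv_heads, S_kv, D]
    attn_mask=attn_mask,
    dropout_p=0.0,
    enable_gqa=True,
)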
@@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
args.max_seq_len,
self.n_kv_heads,
self.head_dim,
not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v
args.enable_dynamic_shape,
)
self.SDPA = SDPA(
kv_cache=self.kv_cache,
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
@@ -414,15 +400,16 @@ def forward
# RoPE relative positional embeddings
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if self.use_kv_cache:
assert input_pos is not None
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# grouped multiquery attention: expand out keys and values
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
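
Moving the cache update out of SDPA is presumably the code sharing the title refers to: a source transformation can now swap the SDPA module without re-implementing any cache bookkeeping. A hypothetical sketch of such a swap (MySDPA and replace_sdpa are illustrative names, not ExecuTorch APIs):

import torch
import torch.nn.functional as F

from executorch.examples.models.llama.llama_transformer import SDPA  # assumed import path


class MySDPA(torch.nn.Module):
    # Drop-in SDPA variant: k/v already contain the full cache in
    # [B, n_kv_heads, max_seq_len, D], so no cache state is needed here.
    def __init__(self, head_dim: int, n_rep: int):
        super().__init__()
        self.head_dim = head_dim
        self.n_rep = n_rep

    def forward(self, input_pos, q, k, v, bsz, seqlen, mask):
        attn_mask = mask[None, None, input_pos]
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
        return y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)


def replace_sdpa(module: torch.nn.Module) -> torch.nn.Module:
    # Mirrors the replace_kv_cache_with_* helpers further down: swap every SDPA child.
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(module, name, MySDPA(child.head_dim, child.n_rep))
        else:
            replace_sdpa(child)
    return module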
145 changes: 42 additions & 103 deletions examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -37,8 +37,6 @@ def __init__
n_heads,
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
tranposed=False,
enable_dynamic_shape=False,
):
super().__init__()
if cache_type not in (
@@ -52,14 +50,8 @@ def __init__
# For now supporting int8 only
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
self.is_transposed = tranposed
self.enable_dynamic_shape = enable_dynamic_shape
if self.is_transposed:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
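
With the transposed path removed, the quantized cache always stores int8 values in [B, max_seq_len, n_heads, head_dim] with one scale per token per head. A small sketch of the simplified constructor (import path and sizes are assumptions):

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (  # assumed path
    QuantizedCacheType,
    QuantizedKVCache,
)

# Positional arguments follow the from_float() call later in this file:
# (max_batch_size, max_seq_length, n_heads, head_dim, cache_type).
qcache = QuantizedKVCache(1, 128, 8, 64, QuantizedCacheType.AffineSymmetric)
assert qcache.k_cache.shape == (1, 128, 8, 64)        # int8 storage, [B, S, H, D]
assert qcache.k_cache_scales.shape == (1, 128, 8, 1)  # per-token, per-head scales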
@@ -98,71 +90,37 @@ def _quantize(self, value):
return quantized_value, scales, zero_points

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
# quantize current k_val and store it in the cache
quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

if self.is_transposed:
# We cannot use update_cache op at the moment
# if the cache is transposed
# Also note that we shold not need separate paths
# for dynamic shape vs !
# Only reason it is done this way is to accommodate
# for lowering pains of backends that work better
# with index_put op.
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k_scales = self.k_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k_zp = self.k_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k.copy_(quantized_k_val)
narrowed_k_scales.copy_(k_scales)
narrowed_k_zp.copy_(k_zero_points)
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v_scales = self.v_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v_zp = self.v_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v.copy_(quantized_v_val)
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
self.k_cache[:, :, input_pos] = quantized_k_val
self.k_cache_scales[:, :, input_pos] = k_scales
self.k_cache_zero_points[:, :, input_pos] = k_zero_points
self.v_cache[:, :, input_pos] = quantized_v_val
self.v_cache_scales[:, :, input_pos] = v_scales
self.v_cache_zero_points[:, :, input_pos] = v_zero_points
else:
# Right now using custom ops on this path.
# In future we can update custom op to handle transposed cache
# as well.
# Note that we may have to revert this change if other ET
# backends such as QNN want to use quantized cache, with dynamic shape,
# instead of quantizing on their own.
# But until this opting for code simplicity
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)
# Right now using custom ops on this path.
# In future we can update custom op to handle transposed cache
# as well.
# Note that we may have to revert this change if other ET
# backends such as QNN want to use quantized cache, with dynamic shape,
# instead of quantizing on their own.
# But until this opting for code simplicity
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
self.cache_fp_type,
)

if self.is_transposed:
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k.copy_(k_val)
narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v.copy_(v_val)
else:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

return k_out, v_out
return k_out.transpose(1, 2), v_out.transpose(1, 2)

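A sketch of the resulting calling convention: [B, H, S, D] at the module boundary, [B, S, H, D] in storage, with the extra transposes expected to be removed by a post-export graph pass. Running this requires the ExecuTorch llama custom ops that provide torch.ops.llama.update_cache; imports and sizes continue the constructor sketch above:

import torch

qcache = QuantizedKVCache(1, 128, 8, 64)  # default cache_type, as above
input_pos = torch.arange(4)
k_val = torch.randn(1, 8, 4, 64)  # [B, H, S, D], same layout the float KVCache uses
v_val = torch.randn(1, 8, 4, 64)
k_out, v_out = qcache.update(input_pos, k_val, v_val)
# Dequantized full caches come back in [B, H, max_seq_len, D].
assert k_out.shape == (1, 8, 128, 64)
assert v_out.shape == (1, 8, 128, 64)
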
@classmethod
def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
cache_shape = kv_cache.k_cache.shape
if kv_cache.is_transposed:
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
else:
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
return cls(
max_batch_size,
max_seq_length,
n_heads,
head_dim,
cache_type,
kv_cache.is_transposed,
kv_cache.enable_dynamic_shape,
)


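from_float now infers the layout from the source cache type: a plain KVCache holds [B, H, S, D], while a CustomKVCache already holds [B, S, H, D]. A hedged usage sketch, where attention.kv_cache stands in for whichever cache the module currently holds:

# Works for either cache flavor after this change; the shape is unpacked
# according to the source cache's layout.
attention.kv_cache = QuantizedKVCache.from_float(
    attention.kv_cache, QuantizedCacheType.AffineSymmetric
)
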
@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
)
for name, child in module.named_children():
if isinstance(child, KVCache):
if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
setattr(
module,
name,
@@ -291,11 +231,13 @@ def __init__
def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, S, H, D]
# input_pos: [S], k_val: [B, H, S, D]
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
return self.k_cache, self.v_cache
return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)


def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
assert (
child.is_transposed is False
), "CustomKVCache does not support transposed cache"
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
setattr(
module,
name,
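
Because the quantization pass now matches CustomKVCache as well, the two source transforms compose. A hedged sketch of one ordering these changes appear to support (model is any module tree containing KVCache instances; both helpers modify it in place, per their log messages):

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (  # assumed path
    replace_kv_cache_with_custom_kv_cache,
    replace_kv_cache_with_quantized_kv_cache,
)

# 1) Swap float KVCache -> CustomKVCache ([B, S, H, D] storage, custom update op).
replace_kv_cache_with_custom_kv_cache(model)
# 2) Quantize: the pass now recognizes CustomKVCache too and converts it via
#    QuantizedKVCache.from_float.
replace_kv_cache_with_quantized_kv_cache(model)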