Changes to split kv cache and sdpa
Summary:

+ Make all backend-specific KV cache and SDPA implementations abide by the
  new API (see the sketch after the changed-files summary below)

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 369434c4d64e6d4500ecfea03b0fd99945b30461
Pull Request resolved: #7413
kimishpatel committed Dec 20, 2024
1 parent f94dda6 commit ff86217
Showing 3 changed files with 76 additions and 171 deletions.
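For orientation, here is a minimal sketch of the calling convention this diff establishes: KVCache.update() takes and returns [batch, n_heads, seq, head_dim] tensors (the default KVCache also stores that layout, while the custom and quantized caches may keep [B, S, H, D] internally and transpose at the boundary), the Attention block transposes q/k/v and performs the cache update itself, and SDPA no longer owns a KVCache. Class and method names follow the diff; the bodies are simplified stand-ins, not the exact ExecuTorch implementation.

import torch
from typing import Tuple

class KVCache(torch.nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float32):
        super().__init__()
        # Single canonical layout: [B, n_heads, max_seq_len, head_dim]
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # k_val, v_val: [B, H, S, D]; write the new tokens and return the full caches
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

# Inside Attention.forward (after RoPE), per the new split:
#     q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # -> [B, H, S, D]
#     k, v = self.kv_cache.update(input_pos, k, v)
#     output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)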
47 changes: 17 additions & 30 deletions examples/models/llama/llama_transformer.py
@@ -232,22 +232,16 @@ def __init__(
max_seq_length: int,
n_heads: int,
head_dim: int,
transpose_cache: bool,
enable_dynamic_shape: bool,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
self.is_transposed = transpose_cache
if transpose_cache:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
self.transpose_cache = transpose_cache
self.enable_dynamic_shape = enable_dynamic_shape
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
# input_pos: [S], k_val: [B, H, S, D]
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_length)
dim_to_slice = 2 if self.transpose_cache else 1
dim_to_slice = 2
seq_length = k_val.size(dim_to_slice)
# Replace the entry in the cache for this token
# The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
else:
k_out = self.k_cache
v_out = self.v_cache
if self.transpose_cache:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out


class SDPA(nn.Module):
def __init__(
self,
kv_cache: KVCache,
dim: int,
head_dim: int,
n_rep: int,
max_seq_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.kv_cache = kv_cache
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
@@ -314,18 +302,16 @@ def __init__(
def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim)
q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim)
bsz,
seqlen,
mask: torch.Tensor,
) -> torch.Tensor:
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
# TODO(kimishpatel): Move this slicing logic to Attention block so that
# SDPA does not have to take input_pos as arg
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
@@ -336,6 +322,8 @@ def forward(
else:
attn_mask = mask[None, None, input_pos]

# TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
# can natively support GQA now. But needs enable_gqa=True
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
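To illustrate the TODO above: in recent PyTorch releases (2.5 or newer, where scaled_dot_product_attention accepts enable_gqa), SDPA can broadcast grouped KV heads itself, so the repeat_interleave expansion would become unnecessary. A small self-contained check with made-up shapes (8 query heads, 2 KV heads, n_rep = 4); this is an illustration, not code from the commit.

import torch
import torch.nn.functional as F

bsz, seqlen, head_dim, n_rep = 1, 16, 64, 4
q = torch.randn(bsz, 8, seqlen, head_dim)  # (bs, n_local_heads, seqlen, head_dim)
k = torch.randn(bsz, 2, seqlen, head_dim)  # (bs, n_local_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, 2, seqlen, head_dim)

# Path used in the diff: expand KV heads to match the query heads.
y_repeat = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    dropout_p=0.0,
)

# What the TODO points at: let SDPA handle grouped-query attention natively.
y_gqa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, enable_gqa=True)

torch.testing.assert_close(y_repeat, y_gqa)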
@@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
args.max_seq_len,
self.n_kv_heads,
self.head_dim,
not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v
args.enable_dynamic_shape,
)
self.SDPA = SDPA(
kv_cache=self.kv_cache,
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
@@ -414,15 +400,16 @@ def forward(
# RoPE relative positional embeddings
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if self.use_kv_cache:
assert input_pos is not None
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# grouped multiquery attention: expand out keys and values
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
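As an aside on the dynamic-shape branch of KVCache.update above: it writes through narrow() + copy_() so the update is expressible with a symbolic start position, and is equivalent to a plain indexed assignment. A toy sanity check with made-up shapes (not code from the commit):

import torch

cache_a = torch.zeros(1, 2, 8, 4)   # [B, H, max_seq_len, D]
cache_b = torch.zeros(1, 2, 8, 4)
k_val = torch.randn(1, 2, 3, 4)     # S = 3 new tokens
start_pos, seq_length = 2, k_val.size(2)

# Dynamic-shape-friendly path (dim 2 is the sequence dimension in the new layout).
cache_a.narrow(2, start_pos, seq_length).copy_(k_val)
# The plain indexed assignment it replaces.
cache_b[:, :, start_pos : start_pos + seq_length] = k_val

assert torch.equal(cache_a, cache_b)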
145 changes: 42 additions & 103 deletions examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -37,8 +37,6 @@ def __init__(
n_heads,
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
tranposed=False,
enable_dynamic_shape=False,
):
super().__init__()
if cache_type not in (
@@ -52,14 +50,8 @@ def __init__(
# For now supporting int8 only
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
self.is_transposed = tranposed
self.enable_dynamic_shape = enable_dynamic_shape
if self.is_transposed:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
@@ -98,71 +90,37 @@ def _quantize(self, value):
return quantized_value, scales, zero_points

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
# quantize current k_val and store it in the cache
quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

if self.is_transposed:
# We cannot use update_cache op at the moment
# if the cache is transposed
# Also note that we shold not need separate paths
# for dynamic shape vs !
# Only reason it is done this way is to accommodate
# for lowering pains of backends that work better
# with index_put op.
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k_scales = self.k_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k_zp = self.k_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k.copy_(quantized_k_val)
narrowed_k_scales.copy_(k_scales)
narrowed_k_zp.copy_(k_zero_points)
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v_scales = self.v_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v_zp = self.v_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v.copy_(quantized_v_val)
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
self.k_cache[:, :, input_pos] = quantized_k_val
self.k_cache_scales[:, :, input_pos] = k_scales
self.k_cache_zero_points[:, :, input_pos] = k_zero_points
self.v_cache[:, :, input_pos] = quantized_v_val
self.v_cache_scales[:, :, input_pos] = v_scales
self.v_cache_zero_points[:, :, input_pos] = v_zero_points
else:
# Right now using custom ops on this path.
# In future we can update custom op to handle transposed cache
# as well.
# Note that we may have to revert this change if other ET
# backends such as QNN want to use quantized cache, with dynamic shape,
# instead of quantizing on their own.
# But until this opting for code simplicity
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)
# Right now using custom ops on this path.
# In future we can update custom op to handle transposed cache
# as well.
# Note that we may have to revert this change if other ET
# backends such as QNN want to use quantized cache, with dynamic shape,
# instead of quantizing on their own.
# But until this opting for code simplicity
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
self.cache_fp_type,
)

if self.is_transposed:
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k.copy_(k_val)
narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v.copy_(v_val)
else:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

return k_out, v_out
return k_out.transpose(1, 2), v_out.transpose(1, 2)

@classmethod
def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
cache_shape = kv_cache.k_cache.shape
if kv_cache.is_transposed:
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
else:
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
return cls(
max_batch_size,
max_seq_length,
n_heads,
head_dim,
cache_type,
kv_cache.is_transposed,
kv_cache.enable_dynamic_shape,
)


@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
)
for name, child in module.named_children():
if isinstance(child, KVCache):
if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
setattr(
module,
name,
@@ -291,11 +231,13 @@ def __init__(
def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, S, H, D]
# input_pos: [S], k_val: [B, H, S, D]
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
return self.k_cache, self.v_cache
return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)


def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
assert (
child.is_transposed is False
), "CustomKVCache does not support transposed cache"
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
setattr(
module,
name,
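QuantizedKVCache keeps int8 values with one scale and zero point per token and dequantizes on read through the quantized_decomposed ops. The snippet below is a plain-torch stand-in for the symmetric per-token round trip, meant to show the idea; it does not reproduce the exact custom-op signatures used in the file.

import torch

def quantize_per_token_symmetric(x: torch.Tensor, eps: float = 1e-9):
    # x: [..., D]; one scale per token (per slice over the last dimension), zero point = 0
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return q, scales

def dequantize_per_token(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scales

k_val = torch.randn(1, 4, 2, 64)            # [B, S, H, D], the storage layout of this cache
q_k, k_scales = quantize_per_token_symmetric(k_val)
k_restored = dequantize_per_token(q_k, k_scales)
print((k_val - k_restored).abs().max())     # quantization error stays small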
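The replace_kv_cache_with_quantized_kv_cache and replace_kv_cache_with_custom_kv_cache helpers both follow the same in-place module-swap pattern. A simplified sketch of that pattern; the import paths are an assumption based on the ExecuTorch source layout touched here, and the real helpers also carry over cache shape, dtype, and dynamic-shape settings.

import torch.nn as nn

# Assumed import paths (not confirmed by this diff):
from executorch.examples.models.llama.llama_transformer import KVCache
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    CustomKVCache,
    QuantizedCacheType,
    QuantizedKVCache,
)

def swap_kv_caches(module: nn.Module, cache_type: QuantizedCacheType) -> nn.Module:
    # Walk direct children; wherever a float KVCache (or CustomKVCache) lives,
    # replace it in place with a QuantizedKVCache built via from_float.
    for name, child in module.named_children():
        if isinstance(child, (KVCache, CustomKVCache)):
            setattr(module, name, QuantizedKVCache.from_float(child, cache_type))
        else:
            swap_kv_caches(child, cache_type)
    return module

# Usage: swap_kv_caches(model, QuantizedCacheType.AffineSymmetric)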