diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index aaef3cd980..176b597a94 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
         )
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
@@ -283,12 +277,8 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val

             return k_out, v_out

@@ -296,7 +286,6 @@ def update(
 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
@@ -304,7 +293,6 @@ def __init__(
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -314,18 +302,16 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-        k, v = self.kv_cache.update(input_pos, k, v)
+        # TODO(kimishpatel): Move this slicing logic to Attention block so that
+        # SDPA does not have to take input_pos as arg
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +322,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]

+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
@@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             args.max_seq_len,
             self.n_kv_heads,
             self.head_dim,
-            not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
             args.enable_dynamic_shape,
         )
         self.SDPA = SDPA(
-            kv_cache=self.kv_cache,
             dim=self.n_local_heads * self.head_dim,
             head_dim=self.head_dim,
             n_rep=self.n_rep,
@@ -414,15 +400,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index d8ac99656f..fa0b3f9251 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -37,8 +37,6 @@ def __init__(
         n_heads,
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
-        tranposed=False,
-        enable_dynamic_shape=False,
     ):
         super().__init__()
         if cache_type not in (
@@ -52,14 +50,8 @@ def __init__(
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.is_transposed = tranposed
-        self.enable_dynamic_shape = enable_dynamic_shape
-        if self.is_transposed:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
-            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
@@ -98,71 +90,37 @@ def _quantize(self, value):
         return quantized_value, scales, zero_points

     def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         # quantize current k_val and store it in the cache
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

-        if self.is_transposed:
-            # We cannot use update_cache op at the moment
-            # if the cache is transposed
-            # Also note that we shold not need separate paths
-            # for dynamic shape vs !
-            # Only reason it is done this way is to accommodate
-            # for lowering pains of backends that work better
-            # with index_put op.
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k_scales = self.k_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k_zp = self.k_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k.copy_(quantized_k_val)
-                narrowed_k_scales.copy_(k_scales)
-                narrowed_k_zp.copy_(k_zero_points)
-                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v_scales = self.v_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v_zp = self.v_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v.copy_(quantized_v_val)
-                narrowed_v_scales.copy_(v_scales)
-                narrowed_v_zp.copy_(v_zero_points)
-            else:
-                self.k_cache[:, :, input_pos] = quantized_k_val
-                self.k_cache_scales[:, :, input_pos] = k_scales
-                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
-                self.v_cache[:, :, input_pos] = quantized_v_val
-                self.v_cache_scales[:, :, input_pos] = v_scales
-                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-        else:
-            # Right now using custom ops on this path.
-            # In future we can update custom op to handle transposed cache
-            # as well.
-            # Note that we may have to revert this change if other ET
-            # backends such as QNN want to use quantized cache, with dynamic shape,
-            # instead of quantizing on their own.
-            # But until this opting for code simplicity
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+        # Right now using custom ops on this path.
+        # In future we can update custom op to handle transposed cache
+        # as well.
+        # Note that we may have to revert this change if other ET
+        # backends such as QNN want to use quantized cache, with dynamic shape,
+        # instead of quantizing on their own.
+        # But until this opting for code simplicity
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            k_zero_points, self.k_cache_zero_points, start_pos
+        )
+        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            v_zero_points, self.v_cache_zero_points, start_pos
+        )

         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )

-        if self.is_transposed:
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k.copy_(k_val)
-                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v.copy_(v_val)
-            else:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-        else:
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

-        return k_out, v_out
+        return k_out.transpose(1, 2), v_out.transpose(1, 2)

     @classmethod
     def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
-        cache_shape = kv_cache.k_cache.shape
-        if kv_cache.is_transposed:
-            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
-        else:
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
         return cls(
             max_batch_size,
             max_seq_length,
             n_heads,
             head_dim,
             cache_type,
-            kv_cache.is_transposed,
-            kv_cache.enable_dynamic_shape,
         )
@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
     for name, child in module.named_children():
-        if isinstance(child, KVCache):
+        if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
             setattr(
                 module,
                 name,
@@ -291,11 +231,13 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, S, H, D]
+        # input_pos: [S], k_val: [B, H, S, D]
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
         _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
         _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
-        return self.k_cache, self.v_cache
+        return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)


 def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
         if isinstance(child, KVCache):
             cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
-            assert (
-                child.is_transposed is False
-            ), "CustomKVCache does not support transposed cache"
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
             setattr(
                 module,
                 name,
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 4d4b3bf7f5..f68e43cbcd 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -22,19 +22,9 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
         dim: int,
     ):
         super().__init__()
-        # Custom op only supports float32 currently. Converting to/from float32 is
-        # faster than not having the op.
-        self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
-            self.kv_cache = kv_cache.to(torch.float)
-        else:
-            assert (
-                kv_cache.cache_fp_type == torch.float32
-            ), "Only float32 is supported for custom SDPA"
         self.dim = dim

     def forward(
         self,
@@ -47,6 +37,10 @@ def forward(
         seqlen,
         mask,
     ):
+        q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         input_dtype = q.dtype
@@ -54,13 +48,10 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)

-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
         output = torch.ops.llama.custom_sdpa(
             q,
-            k_cache,
-            v_cache,
+            k,
+            v,
             input_pos[0].item(),
             None,  # Attention mask
             0,  # dropout probability. Ignored by the code
@@ -75,7 +66,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.kv_cache, child.dim),
+                SDPACustom(child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)
@@ -91,13 +82,11 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -112,11 +101,6 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         attn_mask = mask[None, None, input_pos]

         k = k.repeat_interleave(self.n_rep, dim=1)
@@ -150,12 +134,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class SDPAFlex(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.n_rep = n_rep

@@ -169,9 +151,10 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
+        """
+        q: (bs, n_heads, seqlen, head_dim)
+        k, v: (bs, n_local_heads, seqlen, head_dim)
+        """
         k = repeat_kv(k, self.n_rep)
         v = repeat_kv(v, self.n_rep)
         attn_mask = mask[input_pos]
@@ -191,7 +174,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPASimple(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_simple_sdpa(child)
@@ -204,7 +187,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
+                SDPAFlex(child.dim, child.n_rep),
             )
         else:
             replace_sdpa_with_flex_sdpa(child)
@@ -236,13 +219,11 @@ class SDPACoreML(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -257,11 +238,6 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         attn_mask = mask[None, None, input_pos]

         if self.n_rep > 1:
@@ -279,7 +255,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPACoreML(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_coreml_sdpa(child)
@@ -366,6 +342,9 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # can we combine this with KVCacheCoreML?
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val)
         v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val)
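
For illustration only, a minimal, self-contained sketch of the call pattern the diff moves to: the attention block transposes q/k/v to [B, H, S, D], performs the cache update itself, and only then calls an SDPA module that no longer owns a KVCache. ToyKVCache and ToySDPA below are hypothetical stand-ins for the KVCache and SDPA classes touched above, not the ExecuTorch code; shapes and sizes are arbitrary.

# Hypothetical sketch (assumes only PyTorch); mirrors the decoupled layout above.
import torch
import torch.nn.functional as F


class ToyKVCache(torch.nn.Module):
    def __init__(self, max_batch_size, n_heads, max_seq_length, head_dim):
        super().__init__()
        # Cache is stored as [B, H, max_seq, D], matching the non-quantized KVCache above.
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(self, input_pos, k_val, v_val):
        # k_val, v_val: [B, H, S, D]; write along the sequence dimension (dim 2).
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class ToySDPA(torch.nn.Module):
    def forward(self, q, k, v, mask):
        # q: [B, H, S, D]; k, v: full caches [B, H, max_seq, D]. No cache update here.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


bsz, n_heads, seqlen, head_dim, max_seq = 1, 2, 3, 4, 8
cache, sdpa = ToyKVCache(bsz, n_heads, max_seq, head_dim), ToySDPA()

# Attention block: produce q/k/v as [B, S, H, D], transpose to [B, H, S, D] up front.
q = torch.randn(bsz, seqlen, n_heads, head_dim).transpose(1, 2)
k = torch.randn(bsz, seqlen, n_heads, head_dim).transpose(1, 2)
v = torch.randn(bsz, seqlen, n_heads, head_dim).transpose(1, 2)
input_pos = torch.arange(seqlen)

# The attention block owns the cache update; SDPA just consumes the returned tensors.
k_full, v_full = cache.update(input_pos, k, v)

# Causal mask over the valid prefix; unwritten cache slots stay masked out.
mask = torch.zeros(seqlen, max_seq, dtype=torch.bool)
mask[:, :seqlen] = torch.tril(torch.ones(seqlen, seqlen)).bool()

out = sdpa(q, k_full, v_full, mask)  # [B, H, S, D]
print(out.shape)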