From 5a61919ceeca9f2d6c78e172b92831c60fba8ca9 Mon Sep 17 00:00:00 2001 From: Kerem Turgutlu Date: Tue, 22 Oct 2024 20:49:43 +0000 Subject: [PATCH 1/6] qwen cla, fp8 kv init --- .../models/qwen2/modeling_qwen2.py | 283 ++++++++++++++++-- tests/answerdotai/test_cla.py | 107 +++++++ 2 files changed, 357 insertions(+), 33 deletions(-) create mode 100644 tests/answerdotai/test_cla.py diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1941bca17add08..b9d2a2b6119576 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -148,7 +148,7 @@ def _dynamic_frequency_update(self, position_ids, device): self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, x, position_ids, device_type=None, dtype=None): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) @@ -156,7 +156,7 @@ def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type + device_type = x.device.type if x is not None else device_type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -168,7 +168,8 @@ def forward(self, x, position_ids): cos = cos * self.attention_scaling sin = sin * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + dtype = x.dtype if x is not None else dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -203,7 +204,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None return q_embed, k_embed @@ -234,6 +235,33 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def fp8_quant_dequant(x, scale): + # assert not x.isnan().any(), "key or value states contain NaN before fp8 quantization" + # Get min/max values for float8_e4m3fn + f8_min, f8_max = torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max + # Clamp x/scale to float8 range before conversion + x_fp32 = x.to(torch.float32) / scale + x_fp32_clamped = torch.clamp(x_fp32, f8_min, f8_max) + x_fp8 = x_fp32_clamped.to(torch.float8_e4m3fn) + # assert not x_fp8.isnan().any(), "key or value states contain NaN after fp8 quantization" + x_dequant = (x_fp8.to(torch.float32) * scale).to(x.dtype) + assert not x_dequant.isnan().any(), "key or value states contain NaN after fp8 dequantization" + return x_dequant + + +def compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: + "Returns a dict of decoder layer idxs and whether KV needs to be computed at that layer to be cached." 
+ if cla_kv_cache_map is None: return {} + comput_new_kv_map = {} + is_seen = set() + for k,v in cla_kv_cache_map.items(): + if v not in is_seen: + comput_new_kv_map[k] = True + is_seen.add(v) + else: + comput_new_kv_map[k] = False + return comput_new_kv_map + class Qwen2Attention(nn.Module): """ @@ -268,11 +296,51 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.palu_kv_compression_enabled = config.__dict__.get("palu_kv_compression_enabled", False) + if not self.palu_kv_compression_enabled: + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + # KV fp8 Quantization. + self.use_fp8_kv_scale = config.__dict__.get("use_fp8_kv_scale", None) + + if self.use_fp8_kv_scale: + logger.warning_once("KV fp8 quantization is enabled.") + self.k_scale = torch.nn.Parameter(torch.tensor(0.1)) + self.v_scale = torch.nn.Parameter(torch.tensor(0.1)) + + # Cross Layer Attention (CLA). + # Example of cla_kv_cache_map with 8 layers: + # Index of cached kv is enumerated as it is stored in a list. + # Every new index value corresponds to the layer_idx of the layer where the cached kv was created. + # layer_idx -> cache_idx map: {0:0, 1:1, 2:1, 3:0, 4:2, 5:3, 6:3, 7:2} + # layer 7: oooooo <---| + # layer 6: oooooo <-| | + # layer 5: oooooo --| | + # layer 4: oooooo ----| + # layer 3: oooooo <---| + # layer 2: oooooo <-| | + # layer 1: oooooo --| | + # layer 0: oooooo ----| + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) + if self.cla_kv_cache_map is not None: + logger.warning_once("Cross Layer Attention (CLA) is enabled.") + self.compute_new_kv = compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] + else: + self.compute_new_kv = True + self.debug_kv_sharing = config.__dict__.get("debug_kv_sharing", False) + + # TODO: MLRD PALU. 
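For reference, a minimal standalone sketch of the CLA bookkeeping introduced above: `compute_new_kv_map` marks a layer as a KV producer the first time its cache index appears in `cla_kv_cache_map`, and every later layer that maps to the same index becomes a consumer. The snippet condenses the helper from this patch and replays the eight-layer example from the comment block; it is illustrative only.

```python
def compute_new_kv_map(cla_kv_cache_map):
    # Condensed version of the helper above: True = this layer computes and
    # stores fresh KV, False = it re-uses an earlier layer's entry.
    if cla_kv_cache_map is None:
        return {}
    compute_map, seen = {}, set()
    for layer_idx, cache_idx in cla_kv_cache_map.items():
        compute_map[layer_idx] = cache_idx not in seen
        seen.add(cache_idx)
    return compute_map

# Eight-layer example from the comment block above.
print(compute_new_kv_map({0: 0, 1: 1, 2: 1, 3: 0, 4: 2, 5: 3, 6: 3, 7: 2}))
# {0: True, 1: True, 2: False, 3: False, 4: True, 5: True, 6: False, 7: False}
# Layers 0, 1, 4, 5 append entries 0-3 of the shared list; layers 3, 2, 7, 6
# read them back via cla_kv_cache_map[layer_idx].
```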
+ if self.palu_kv_compression_enabled: + logger.warning_once("MLRD PALU is enabled.") + self.palu_k_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + self.palu_v_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + + self.palu_k_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=True) + self.palu_v_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=True) def forward( self, @@ -280,6 +348,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -288,12 +357,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -301,16 +377,48 @@ def forward( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) + ) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + # 1. PALU KV down projection. + # 2. KV fp8 quantization. + if self.use_fp8_kv_scale and self.compute_new_kv: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + # 3. PALU KV up projection. + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. 
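The fp8 path above divides by a learned per-tensor scale, clamps to the `float8_e4m3fn` range, round-trips through fp8, and scales back, so the model trains against the quantization error the fp8 cache would introduce at inference. A small self-contained sketch of that round trip (mirroring `fp8_quant_dequant` minus the NaN assert; assumes a PyTorch build that provides `torch.float8_e4m3fn`):

```python
import torch

def fp8_quant_dequant(x, scale):
    # Per-tensor scale, clamp to the representable e4m3 range, quantize, dequantize.
    f8 = torch.finfo(torch.float8_e4m3fn)
    x_fp32 = x.to(torch.float32) / scale
    x_fp8 = torch.clamp(x_fp32, f8.min, f8.max).to(torch.float8_e4m3fn)
    return (x_fp8.to(torch.float32) * scale).to(x.dtype)

key_states = torch.randn(1, 4, 32, 64, dtype=torch.bfloat16)  # (bsz, kv_heads, seq_len, head_dim)
k_scale = torch.tensor([0.1])  # initialized to 0.1 in the patch and learned during training
roundtrip = fp8_quant_dequant(key_states, k_scale)
print((roundtrip - key_states).abs().max())  # small quant-dequant error for a reasonable scale
```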
+ if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + #### END: KV Compression #### + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if cla_key_value is not None: + if self.compute_new_kv: + # update + cla_key_value.append((key_states, value_states)) + else: + # re-use + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -334,12 +442,13 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class Qwen2FlashAttention2(Qwen2Attention): @@ -366,6 +475,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -374,12 +484,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -388,10 +505,37 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." 
) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. + if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + #### END: KV Compression #### if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute @@ -423,6 +567,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if cla_key_value is not None: + if self.compute_new_kv: + # update + cla_key_value.append((key_states, value_states)) + else: + # re-use + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -479,12 +631,14 @@ def forward( ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class Qwen2SdpaAttention(Qwen2Attention): @@ -501,6 +655,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -517,6 +672,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, ) @@ -524,12 +680,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, 
q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -538,15 +701,50 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. + if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. 
+ if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + #### END: KV Compression #### + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if cla_key_value is not None: + if self.compute_new_kv: + # update + cla_key_value.append((key_states, value_states)) + else: + # re-use + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -577,10 +775,11 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) - + + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, past_key_value, cla_key_value QWEN2_ATTENTION_CLASSES = { @@ -612,6 +811,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -645,11 +845,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, present_key_value, cla_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -670,6 +871,9 @@ def forward( if use_cache: outputs += (present_key_value,) + + if cla_key_value is not None: + outputs += (cla_key_value,) return outputs @@ -819,6 +1023,8 @@ def __init__(self, config: Qwen2Config): self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -899,6 +1105,12 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + # cla + if self.cla_kv_cache_map is not None: + cla_key_value = [] + else: + cla_key_value = None for decoder_layer in self.layers: if output_hidden_states: @@ -911,6 +1123,7 @@ def forward( causal_mask, position_ids, past_key_values, + cla_key_value, output_attentions, use_cache, cache_position, @@ -922,6 +1135,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -935,6 +1149,9 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) + + if cla_key_value is not None: + cla_key_value = layer_outputs[-1] hidden_states = 
self.norm(hidden_states) diff --git a/tests/answerdotai/test_cla.py b/tests/answerdotai/test_cla.py new file mode 100644 index 00000000000000..3ea261b231efe2 --- /dev/null +++ b/tests/answerdotai/test_cla.py @@ -0,0 +1,107 @@ +import unittest + +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.testing_utils import ( + is_torch_available, + require_torch, +) + + +if is_torch_available(): + import torch + + +@require_torch +class CrossLayerAttentionTest(unittest.TestCase): + + test_parameters = [ + # Model, attention implementation, fp8 kv enabled + ("Qwen/Qwen2.5-32B-Instruct", "eager", True), + ("Qwen/Qwen2.5-32B-Instruct", "sdpa", True), + ("Qwen/Qwen2.5-32B-Instruct", "eager", False), + ("Qwen/Qwen2.5-32B-Instruct", "sdpa", False), + ] + + def test_naive_cla(self): + "Compare a base model without CLA with a model with CLA where each layer computes its own KV cache." + + for model_name, attn_impl, fp8_kv_enabled in self.test_parameters: + with self.subTest(model_id=model_name, attn_impl=attn_impl, fp8_kv_enabled=fp8_kv_enabled): + cfg = AutoConfig.from_pretrained(model_name) + cfg.num_hidden_layers = 4 + cfg.hidden_size //= 8 + cfg.intermediate_size //= 8 + cfg.num_attention_heads //= 2 + cfg.num_key_value_heads //= 2 + cfg._attn_implementation = attn_impl + cfg.use_fp8_kv_scale = fp8_kv_enabled + cfg.palu_kv_compression_enabled = False + cfg.use_cache = False + cfg.debug_kv_sharing = False + cfg.output_attentions = False + + x = torch.arange(32, device="cuda").view(1,-1) + + cfg.cla_kv_cache_map = {0:0, 1:1, 2:2, 3:3} + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16) + test_output = model(x) + + model_state_dict = model.state_dict() + + cfg.cla_kv_cache_map = None + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16) + model.load_state_dict(model_state_dict); + base_output = model(x) + + assert torch.equal(test_output.logits, base_output.logits) + + def test_cla_2(self): + """ + Test CLA with a custom KV cache map. 
+ + LLM with 4 layers: + layer 3: oooooo <---| + layer 2: oooooo <-| | + layer 1: oooooo --| | + layer 0: oooooo ----| + """ + + for model_name, attn_impl, fp8_kv_enabled in self.test_parameters: + with self.subTest(model_id=model_name, attn_impl=attn_impl, fp8_kv_enabled=fp8_kv_enabled): + cfg = AutoConfig.from_pretrained(model_name) + cfg.num_hidden_layers = 4 + cfg.hidden_size //= 8 + cfg.intermediate_size //= 8 + cfg.num_attention_heads //= 2 + cfg.num_key_value_heads //= 2 + cfg._attn_implementation = attn_impl + cfg.use_fp8_kv_scale = fp8_kv_enabled + cfg.cla_kv_cache_map = {0:0, 1:1, 2:1, 3:0} + cfg.palu_kv_compression_enabled = False + cfg.use_cache = False + cfg.debug_kv_sharing = True + cfg.output_attentions = attn_impl == "eager" + + x = torch.arange(32, device="cuda").view(1,-1) + + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16); + + assert model.config.use_fp8_kv_scale == fp8_kv_enabled + assert model.config.cla_kv_cache_map == {0:0, 1:1, 2:1, 3:0} + assert model.config.use_cache == False + + out_eager = model(x) + + if attn_impl == "eager": + assert torch.equal(out_eager.attentions[0], out_eager.attentions[3]) + assert torch.equal(out_eager.attentions[1], out_eager.attentions[2]) + assert not torch.equal(out_eager.attentions[0], out_eager.attentions[1]) + + attn_outputs = [l.self_attn.debug_cla_attn_output for l in model.model.layers] + assert len(attn_outputs) == 4 + assert torch.equal(attn_outputs[0], attn_outputs[3]) + assert torch.equal(attn_outputs[1], attn_outputs[2]) + assert not torch.equal(attn_outputs[0], attn_outputs[1]) \ No newline at end of file From 9748c746f147c91ae78bb56e07216170e7d22e31 Mon Sep 17 00:00:00 2001 From: Kerem Turgutlu Date: Thu, 24 Oct 2024 08:14:56 +0000 Subject: [PATCH 2/6] query only rotary emb --- .../models/qwen2/modeling_qwen2.py | 38 +++++++++++++++---- tests/answerdotai/test_cla.py | 4 +- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index b9d2a2b6119576..f799f1c4ca08c4 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -204,9 +204,33 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None + k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed +def apply_rotary_pos_emb_query_only(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query tensor. + + Args: + q (`torch.Tensor`): The query tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 class Qwen2MLP(nn.Module): @@ -388,7 +412,7 @@ def forward( if self.compute_new_kv: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: - query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) # 1. PALU KV down projection. # 2. KV fp8 quantization. if self.use_fp8_kv_scale and self.compute_new_kv: @@ -404,7 +428,7 @@ def forward( if self.compute_new_kv: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: - query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) #### END: KV Compression #### if past_key_value is not None: @@ -515,7 +539,7 @@ def forward( if self.compute_new_kv: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: - query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) # 1. PALU KV down projection. # 2. KV fp8 quantization. @@ -534,7 +558,7 @@ def forward( if self.compute_new_kv: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: - query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) #### END: KV Compression #### if past_key_value is not None: @@ -711,7 +735,7 @@ def forward( if self.compute_new_kv: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: - query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) # 1. PALU KV down projection. # 2. KV fp8 quantization. 
@@ -730,7 +754,7 @@ def forward( if self.compute_new_kv: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: - query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin) + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) #### END: KV Compression #### if past_key_value is not None: diff --git a/tests/answerdotai/test_cla.py b/tests/answerdotai/test_cla.py index 3ea261b231efe2..86e30a594ec45d 100644 --- a/tests/answerdotai/test_cla.py +++ b/tests/answerdotai/test_cla.py @@ -52,7 +52,7 @@ def test_naive_cla(self): cfg.cla_kv_cache_map = None model = AutoModelForCausalLM.from_config(cfg) model.to(device="cuda", dtype=torch.bfloat16) - model.load_state_dict(model_state_dict); + model.load_state_dict(model_state_dict) base_output = model(x) assert torch.equal(test_output.logits, base_output.logits) @@ -87,7 +87,7 @@ def test_cla_2(self): x = torch.arange(32, device="cuda").view(1,-1) model = AutoModelForCausalLM.from_config(cfg) - model.to(device="cuda", dtype=torch.bfloat16); + model.to(device="cuda", dtype=torch.bfloat16) assert model.config.use_fp8_kv_scale == fp8_kv_enabled assert model.config.cla_kv_cache_map == {0:0, 1:1, 2:1, 3:0} From c4d6d554592a67272e74164e52bd3e65b9bc7b06 Mon Sep 17 00:00:00 2001 From: Kerem Turgutlu Date: Thu, 24 Oct 2024 08:49:20 +0000 Subject: [PATCH 3/6] llama cla --- .../models/llama/modeling_llama.py | 308 ++++++++++++++++-- .../models/qwen2/modeling_qwen2.py | 16 +- tests/answerdotai/test_cla.py | 10 +- 3 files changed, 293 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e9064ff3ae5b22..d9b74606aa2808 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -142,7 +142,7 @@ def _dynamic_frequency_update(self, position_ids, device): self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, x, position_ids, device_type=None, dtype=None): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) @@ -150,7 +150,7 @@ def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type + device_type = x.device.type if x is not None else device_type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -162,7 +162,8 @@ def forward(self, x, position_ids): cos = cos * self.attention_scaling sin = sin * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + dtype = x.dtype if x is not None else dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): @@ -224,6 +225,31 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def apply_rotary_pos_emb_query_only(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query tensor. + + Args: + q (`torch.Tensor`): The query tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. 
+ sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -270,6 +296,34 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def fp8_quant_dequant(x, scale): + # assert not x.isnan().any(), "key or value states contain NaN before fp8 quantization" + # Get min/max values for float8_e4m3fn + f8_min, f8_max = torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max + # Clamp x/scale to float8 range before conversion + x_fp32 = x.to(torch.float32) / scale + x_fp32_clamped = torch.clamp(x_fp32, f8_min, f8_max) + x_fp8 = x_fp32_clamped.to(torch.float8_e4m3fn) + # assert not x_fp8.isnan().any(), "key or value states contain NaN after fp8 quantization" + x_dequant = (x_fp8.to(torch.float32) * scale).to(x.dtype) + assert not x_dequant.isnan().any(), "key or value states contain NaN after fp8 dequantization" + return x_dequant + + +def compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: + "Returns a dict of decoder layer idxs and whether KV needs to be computed at that layer to be cached." 
+ if cla_kv_cache_map is None: return {} + comput_new_kv_map = {} + is_seen = set() + for k,v in cla_kv_cache_map.items(): + if v not in is_seen: + comput_new_kv_map[k] = True + is_seen.add(v) + else: + comput_new_kv_map[k] = False + return comput_new_kv_map + + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -295,12 +349,52 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.is_causal = True self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.palu_kv_compression_enabled = config.__dict__.get("palu_kv_compression_enabled", False) + if not self.palu_kv_compression_enabled: + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + # KV fp8 Quantization. + self.use_fp8_kv_scale = config.__dict__.get("use_fp8_kv_scale", None) + + if self.use_fp8_kv_scale: + logger.warning_once("KV fp8 quantization is enabled.") + self.k_scale = torch.nn.Parameter(torch.tensor(0.1)) + self.v_scale = torch.nn.Parameter(torch.tensor(0.1)) + + # Cross Layer Attention (CLA). + # Example of cla_kv_cache_map with 8 layers: + # Index of cached kv is enumerated as it is stored in a list. + # Every new index value corresponds to the layer_idx of the layer where the cached kv was created. + # layer_idx -> cache_idx map: {0:0, 1:1, 2:1, 3:0, 4:2, 5:3, 6:3, 7:2} + # layer 7: oooooo <---| + # layer 6: oooooo <-| | + # layer 5: oooooo --| | + # layer 4: oooooo ----| + # layer 3: oooooo <---| + # layer 2: oooooo <-| | + # layer 1: oooooo --| | + # layer 0: oooooo ----| + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) + if self.cla_kv_cache_map is not None: + logger.warning_once("Cross Layer Attention (CLA) is enabled.") + self.compute_new_kv = compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] + else: + self.compute_new_kv = True + self.debug_kv_sharing = config.__dict__.get("debug_kv_sharing", False) + + # TODO: MLRD PALU. 
+ if self.palu_kv_compression_enabled: + logger.warning_once("MLRD PALU is enabled.") + self.palu_k_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + self.palu_v_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + + self.palu_k_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.palu_v_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) def forward( self, @@ -308,6 +402,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -326,21 +421,28 @@ def forward( query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) + if self.compute_new_kv: + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -349,15 +451,47 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. 
+ # 2. KV fp8 quantization. + if self.use_fp8_kv_scale and self.compute_new_kv: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + # 3. PALU KV up projection. + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if cla_key_value is not None: + if self.compute_new_kv: + # update + cla_key_value.append((key_states, value_states)) + else: + # re-use + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -382,6 +516,8 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output + if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) @@ -392,7 +528,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class LlamaFlashAttention2(LlamaAttention): @@ -416,6 +552,7 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -432,15 +569,22 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if 
position_embeddings is None: logger.warning_once( @@ -449,16 +593,51 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. + if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if cla_key_value is not None: + if self.compute_new_kv: + # update + cla_key_value.append((key_states, value_states)) + else: + # re-use + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. 
query_states = query_states.transpose(1, 2) @@ -507,12 +686,13 @@ def forward( ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class LlamaSdpaAttention(LlamaAttention): @@ -529,6 +709,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -546,6 +727,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -555,12 +737,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -569,16 +758,51 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. + if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. 
+ if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if cla_key_value is not None: + if self.compute_new_kv: + # update + cla_key_value.append((key_states, value_states)) + else: + # re-use + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -608,10 +832,10 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, past_key_value, cla_key_value LLAMA_ATTENTION_CLASSES = { @@ -638,6 +862,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -671,11 +896,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, present_key_value, cla_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -698,6 +924,9 @@ def forward( if use_cache: outputs += (present_key_value,) + if cla_key_value is not None: + outputs += (cla_key_value,) + return outputs @@ -845,6 +1074,8 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False + + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) # Initialize weights and apply final processing self.post_init() @@ -922,6 +1153,12 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + # cla + if self.cla_kv_cache_map is not None: + cla_key_value = [] + else: + cla_key_value = None for decoder_layer in self.layers: if output_hidden_states: @@ -934,6 +1171,7 @@ def forward( causal_mask, position_ids, past_key_values, + cla_key_value, output_attentions, use_cache, cache_position, @@ -945,6 +1183,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -959,6 +1198,9 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) + if cla_key_value is not None: + cla_key_value = layer_outputs[-1] 
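At the model level the sharing works through a plain Python list that is reset at the start of each forward pass: producer layers append their (key, value) pair, consumer layers index into it with `cla_kv_cache_map[layer_idx]`. A toy simulation of one pass with the four-layer map used in the tests, {0: 0, 1: 1, 2: 1, 3: 0}; the tensors only stand in for the real key/value states:

```python
import torch

cla_kv_cache_map = {0: 0, 1: 1, 2: 1, 3: 0}
compute_new_kv = {0: True, 1: True, 2: False, 3: False}  # from compute_new_kv_map

cla_key_value = []  # reset once per forward pass, as in LlamaModel/Qwen2Model above
for layer_idx in range(4):
    if compute_new_kv[layer_idx]:
        kv = (torch.full((1,), float(layer_idx)), torch.full((1,), float(layer_idx)))
        cla_key_value.append(kv)                         # producer: store fresh KV
    else:
        kv = cla_key_value[cla_kv_cache_map[layer_idx]]  # consumer: re-use shared KV
    print(layer_idx, kv[0].item())
# 0 0.0 / 1 1.0 / 2 1.0 / 3 0.0 -> layers 2 and 3 attend with layer 1's and layer 0's KV.
```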
+ hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f799f1c4ca08c4..879b77ebf73904 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -207,6 +207,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + def apply_rotary_pos_emb_query_only(q, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query tensor. @@ -259,6 +260,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + def fp8_quant_dequant(x, scale): # assert not x.isnan().any(), "key or value states contain NaN before fp8 quantization" # Get min/max values for float8_e4m3fn @@ -391,7 +393,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) else: value_states = None - + if self.debug_kv_sharing: query_states = torch.ones_like(query_states) @@ -401,12 +403,12 @@ def forward( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." - ) + ) device_type, dtype = hidden_states.device.type, hidden_states.dtype cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - + #### BEGIN: KV Compression #### if not self.palu_kv_compression_enabled: if self.compute_new_kv: @@ -442,7 +444,7 @@ def forward( else: # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] - + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -756,7 +758,7 @@ def forward( else: query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) #### END: KV Compression #### - + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -895,7 +897,7 @@ def forward( if use_cache: outputs += (present_key_value,) - + if cla_key_value is not None: outputs += (cla_key_value,) @@ -1173,7 +1175,7 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - + if cla_key_value is not None: cla_key_value = layer_outputs[-1] diff --git a/tests/answerdotai/test_cla.py b/tests/answerdotai/test_cla.py index 86e30a594ec45d..06749f21c578bb 100644 --- a/tests/answerdotai/test_cla.py +++ b/tests/answerdotai/test_cla.py @@ -18,8 +18,16 @@ class CrossLayerAttentionTest(unittest.TestCase): # Model, attention implementation, fp8 kv enabled ("Qwen/Qwen2.5-32B-Instruct", "eager", True), ("Qwen/Qwen2.5-32B-Instruct", "sdpa", True), + ("Qwen/Qwen2.5-32B-Instruct", "flash_attention_2", True), ("Qwen/Qwen2.5-32B-Instruct", "eager", False), - ("Qwen/Qwen2.5-32B-Instruct", "sdpa", False), + ("Qwen/Qwen2.5-32B-Instruct", "sdpa", False), + ("Qwen/Qwen2.5-32B-Instruct", 
"flash_attention_2", False), + ("meta-llama/Llama-3.1-8B-Instruct", "eager", True), + ("meta-llama/Llama-3.1-8B-Instruct", "sdpa", True), + ("meta-llama/Llama-3.1-8B-Instruct", "flash_attention_2", True), + ("meta-llama/Llama-3.1-8B-Instruct", "eager", False), + ("meta-llama/Llama-3.1-8B-Instruct", "sdpa", False), + ("meta-llama/Llama-3.1-8B-Instruct", "flash_attention_2", False), ] def test_naive_cla(self): From ca2659613151314dad83a3715d93435dedb581fd Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 25 Oct 2024 20:22:54 +0000 Subject: [PATCH 4/6] scalar not supported by FSDP, add CLA training test --- .../models/llama/modeling_llama.py | 4 +- .../models/qwen2/modeling_qwen2.py | 4 +- tests/answerdotai/test_cla.py | 52 ++++++++++++++++++- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d9b74606aa2808..297a861d99ffcc 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -363,8 +363,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): if self.use_fp8_kv_scale: logger.warning_once("KV fp8 quantization is enabled.") - self.k_scale = torch.nn.Parameter(torch.tensor(0.1)) - self.v_scale = torch.nn.Parameter(torch.tensor(0.1)) + self.k_scale = torch.nn.Parameter(torch.tensor([0.1])) + self.v_scale = torch.nn.Parameter(torch.tensor([0.1])) # Cross Layer Attention (CLA). # Example of cla_kv_cache_map with 8 layers: diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 879b77ebf73904..fa43569d7ede2e 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -335,8 +335,8 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): if self.use_fp8_kv_scale: logger.warning_once("KV fp8 quantization is enabled.") - self.k_scale = torch.nn.Parameter(torch.tensor(0.1)) - self.v_scale = torch.nn.Parameter(torch.tensor(0.1)) + self.k_scale = torch.nn.Parameter(torch.tensor([0.1])) + self.v_scale = torch.nn.Parameter(torch.tensor([0.1])) # Cross Layer Attention (CLA). 
# Example of cla_kv_cache_map with 8 layers: diff --git a/tests/answerdotai/test_cla.py b/tests/answerdotai/test_cla.py index 06749f21c578bb..a30c908ed5996d 100644 --- a/tests/answerdotai/test_cla.py +++ b/tests/answerdotai/test_cla.py @@ -112,4 +112,54 @@ def test_cla_2(self): assert len(attn_outputs) == 4 assert torch.equal(attn_outputs[0], attn_outputs[3]) assert torch.equal(attn_outputs[1], attn_outputs[2]) - assert not torch.equal(attn_outputs[0], attn_outputs[1]) \ No newline at end of file + assert not torch.equal(attn_outputs[0], attn_outputs[1]) + + def test_training_with_cla(self): + """ + Test that parameters are updated correctly: + + layer 0 k_proj, v_proj, k_scale, v_scale should be updated + layer 1 k_proj, v_proj shouldn't be updated + """ + model_name = "Qwen/Qwen2.5-32B-Instruct" + attn_impl = "eager" + fp8_kv_enabled = True + + cfg = AutoConfig.from_pretrained(model_name) + cfg.num_hidden_layers = 2 + cfg.hidden_size //= 8 + cfg.intermediate_size //= 8 + cfg.num_attention_heads //= 2 + cfg.num_key_value_heads //= 2 + cfg._attn_implementation = attn_impl + cfg.use_fp8_kv_scale = fp8_kv_enabled + cfg.cla_kv_cache_map = {0: 0, 1: 0} + cfg.palu_kv_compression_enabled = False + cfg.use_cache = False + cfg.debug_kv_sharing = True + + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16) + + model.train() + + # Perform a forward and backward pass + x = torch.arange(32, device="cuda").view(1, -1) + output = model(x) + loss = output.logits.sum() + loss.backward() + + # Check which parameters have been updated + for name, param in model.named_parameters(): + if param.grad is not None: + updated = torch.any(param.grad != 0) + if "layers.0" in name: + if any(proj in name for proj in ["k_proj", "v_proj"]): + self.assertTrue(updated, f"{name} should be updated") + if fp8_kv_enabled and any(scale in name for scale in ["k_scale", "v_scale"]): + self.assertTrue(updated, f"{name} should be updated") + elif "layers.1" in name: + if any(proj in name for proj in ["k_proj", "v_proj"]): + self.assertFalse(updated, f"{name} should not be updated") + else: + self.assertTrue(updated, f"{name} should be updated") \ No newline at end of file From b1af4ca3dfad50e26530b76d6243cae5dd6e33c7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 25 Oct 2024 22:03:03 +0000 Subject: [PATCH 5/6] detach shared kv by default --- src/transformers/models/llama/modeling_llama.py | 16 +++++++++++++--- src/transformers/models/qwen2/modeling_qwen2.py | 16 +++++++++++++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 297a861d99ffcc..d73e5f9e60278d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -385,6 +385,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.compute_new_kv = compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] else: self.compute_new_kv = True + self.cla_kv_detached = config.__dict__.get("cla_kv_detached", True) self.debug_kv_sharing = config.__dict__.get("debug_kv_sharing", False) # TODO: MLRD PALU. 
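The training test above checks which projections receive gradients, and the cla_kv_detached flag introduced in this commit (default True) determines whether the shared KV carries gradients from consumer layers back into the producer's projections. A minimal sketch of that effect, with a toy linear projection standing in for k_proj rather than the patched module itself:

import torch

k_proj = torch.nn.Linear(4, 4)
hidden = torch.randn(1, 4)
key_states = k_proj(hidden)

# The producer layer attends with its own key_states; a consumer layer sees a
# detached copy when cla_kv_detached is True.
producer_loss = key_states.sum()
consumer_loss = key_states.detach().sum()
(producer_loss + consumer_loss).backward()

# The gradient reflects only the producer path; with cla_kv_detached=False the
# consumer path would contribute as well (exactly doubling it in this toy case).
print(k_proj.weight.grad)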
@@ -488,7 +489,10 @@ def forward( if cla_key_value is not None: if self.compute_new_kv: # update - cla_key_value.append((key_states, value_states)) + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) else: # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] @@ -633,7 +637,10 @@ def forward( if cla_key_value is not None: if self.compute_new_kv: # update - cla_key_value.append((key_states, value_states)) + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) else: # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] @@ -798,7 +805,10 @@ def forward( if cla_key_value is not None: if self.compute_new_kv: # update - cla_key_value.append((key_states, value_states)) + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) else: # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index fa43569d7ede2e..528400b130d03e 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -357,6 +357,7 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): self.compute_new_kv = compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] else: self.compute_new_kv = True + self.cla_kv_detached = config.__dict__.get("cla_kv_detached", True) self.debug_kv_sharing = config.__dict__.get("debug_kv_sharing", False) # TODO: MLRD PALU. 
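For reference, enabling the features above only requires setting plain attributes on the config before building the model, since the patched attention modules read them via config.__dict__.get(...); this mirrors what tests/answerdotai/test_cla.py does, and the model name, layer count, and map below are placeholders rather than values from the patch.

from transformers import AutoConfig, AutoModelForCausalLM

cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
cfg.num_hidden_layers = 4                        # shrink for illustration, as the tests do
cfg.cla_kv_cache_map = {0: 0, 1: 0, 2: 1, 3: 1}  # layers 1 and 3 re-use KV from layers 0 and 2
cfg.use_fp8_kv_scale = True                      # adds learnable k_scale / v_scale per attention layer
cfg.cla_kv_detached = True                       # default: consumers see detached shared KV
cfg.palu_kv_compression_enabled = False
model = AutoModelForCausalLM.from_config(cfg)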
@@ -440,7 +441,10 @@ def forward( if cla_key_value is not None: if self.compute_new_kv: # update - cla_key_value.append((key_states, value_states)) + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) else: # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] @@ -596,7 +600,10 @@ def forward( if cla_key_value is not None: if self.compute_new_kv: # update - cla_key_value.append((key_states, value_states)) + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) else: # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] @@ -766,7 +773,10 @@ def forward( if cla_key_value is not None: if self.compute_new_kv: # update - cla_key_value.append((key_states, value_states)) + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) else: # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] From c4fd5cd49f1f9bcafd25b2bd98d7af07ad1c7ac0 Mon Sep 17 00:00:00 2001 From: Kerem Turgutlu Date: Mon, 28 Oct 2024 15:46:40 +0000 Subject: [PATCH 6/6] gradual unfreeze --- .../models/llama/modeling_llama.py | 31 +++++++++---------- .../models/qwen2/modeling_qwen2.py | 31 +++++++++---------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d73e5f9e60278d..36770200a370ac 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -310,18 +310,20 @@ def fp8_quant_dequant(x, scale): return x_dequant -def compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: +def create_compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: "Returns a dict of decoder layer idxs and whether KV needs to be computed at that layer to be cached." if cla_kv_cache_map is None: return {} - comput_new_kv_map = {} + compute_new_kv_map = {} is_seen = set() for k,v in cla_kv_cache_map.items(): - if v not in is_seen: - comput_new_kv_map[k] = True + if v == -1: + compute_new_kv_map[k] = True + elif v not in is_seen: + compute_new_kv_map[k] = True is_seen.add(v) else: - comput_new_kv_map[k] = False - return comput_new_kv_map + compute_new_kv_map[k] = False + return compute_new_kv_map class LlamaAttention(nn.Module): @@ -382,7 +384,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) if self.cla_kv_cache_map is not None: logger.warning_once("Cross Layer Attention (CLA) is enabled.") - self.compute_new_kv = compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] + self.compute_new_kv = create_compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] else: self.compute_new_kv = True self.cla_kv_detached = config.__dict__.get("cla_kv_detached", True) @@ -486,15 +488,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cla_key_value is not None: + # update or re-use kv. 
+ if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: if self.compute_new_kv: - # update if self.cla_kv_detached: cla_key_value.append((key_states.detach(), value_states.detach())) else: cla_key_value.append((key_states, value_states)) else: - # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -634,15 +635,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cla_key_value is not None: + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: if self.compute_new_kv: - # update if self.cla_kv_detached: cla_key_value.append((key_states.detach(), value_states.detach())) else: cla_key_value.append((key_states, value_states)) else: - # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -802,15 +802,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cla_key_value is not None: + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: if self.compute_new_kv: - # update if self.cla_kv_detached: cla_key_value.append((key_states.detach(), value_states.detach())) else: cla_key_value.append((key_states, value_states)) else: - # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] key_states = repeat_kv(key_states, self.num_key_value_groups) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 528400b130d03e..b4a89f08682b06 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -275,18 +275,20 @@ def fp8_quant_dequant(x, scale): return x_dequant -def compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: +def create_compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: "Returns a dict of decoder layer idxs and whether KV needs to be computed at that layer to be cached." 
if cla_kv_cache_map is None: return {} - comput_new_kv_map = {} + compute_new_kv_map = {} is_seen = set() for k,v in cla_kv_cache_map.items(): - if v not in is_seen: - comput_new_kv_map[k] = True + if v == -1: + compute_new_kv_map[k] = True + elif v not in is_seen: + compute_new_kv_map[k] = True is_seen.add(v) else: - comput_new_kv_map[k] = False - return comput_new_kv_map + compute_new_kv_map[k] = False + return compute_new_kv_map class Qwen2Attention(nn.Module): @@ -354,7 +356,7 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) if self.cla_kv_cache_map is not None: logger.warning_once("Cross Layer Attention (CLA) is enabled.") - self.compute_new_kv = compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] + self.compute_new_kv = create_compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] else: self.compute_new_kv = True self.cla_kv_detached = config.__dict__.get("cla_kv_detached", True) @@ -438,15 +440,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cla_key_value is not None: + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: if self.compute_new_kv: - # update if self.cla_kv_detached: cla_key_value.append((key_states.detach(), value_states.detach())) else: cla_key_value.append((key_states, value_states)) else: - # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] # repeat k/v heads if n_kv_heads < n_heads @@ -597,15 +598,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cla_key_value is not None: + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: if self.compute_new_kv: - # update if self.cla_kv_detached: cla_key_value.append((key_states.detach(), value_states.detach())) else: cla_key_value.append((key_states, value_states)) else: - # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] # repeat k/v heads if n_kv_heads < n_heads @@ -770,15 +770,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cla_key_value is not None: + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: if self.compute_new_kv: - # update if self.cla_kv_detached: cla_key_value.append((key_states.detach(), value_states.detach())) else: cla_key_value.append((key_states, value_states)) else: - # re-use key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] key_states = repeat_kv(key_states, self.num_key_value_groups)
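A worked example of the -1 sentinel introduced in this final commit for gradual unfreezing: a layer mapped to -1 always computes its own KV (compute_new_kv is True) and is skipped by the append/re-use branch above, so it stays independent of the shared cache. The helper below is a local restatement of the mapping rule for illustration, not the patched create_compute_new_kv_map itself.

def _expected_compute_new_kv(cla_kv_cache_map):
    out, seen = {}, set()
    for layer_idx, cache_idx in cla_kv_cache_map.items():
        if cache_idx == -1:
            out[layer_idx] = True          # never shared: always computes its own KV
        elif cache_idx not in seen:
            out[layer_idx] = True          # first layer for this cache slot: produces KV
            seen.add(cache_idx)
        else:
            out[layer_idx] = False         # later layer for the slot: re-uses cached KV
    return out

# Layer 0 produces cache entry 0 and layer 2 re-uses it; layers 1 and 3 keep independent KV.
assert _expected_compute_new_kv({0: 0, 1: -1, 2: 0, 3: -1}) == {0: True, 1: True, 2: False, 3: True}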