diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e9064ff3ae5b22..36770200a370ac 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -142,7 +142,7 @@ def _dynamic_frequency_update(self, position_ids, device): self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, x, position_ids, device_type=None, dtype=None): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) @@ -150,7 +150,7 @@ def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type + device_type = x.device.type if x is not None else device_type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -162,7 +162,8 @@ def forward(self, x, position_ids): cos = cos * self.attention_scaling sin = sin * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + dtype = x.dtype if x is not None else dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): @@ -224,6 +225,31 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def apply_rotary_pos_emb_query_only(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query tensor. + + Args: + q (`torch.Tensor`): The query tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -270,6 +296,36 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def fp8_quant_dequant(x, scale): + # assert not x.isnan().any(), "key or value states contain NaN before fp8 quantization" + # Get min/max values for float8_e4m3fn + f8_min, f8_max = torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max + # Clamp x/scale to float8 range before conversion + x_fp32 = x.to(torch.float32) / scale + x_fp32_clamped = torch.clamp(x_fp32, f8_min, f8_max) + x_fp8 = x_fp32_clamped.to(torch.float8_e4m3fn) + # assert not x_fp8.isnan().any(), "key or value states contain NaN after fp8 quantization" + x_dequant = (x_fp8.to(torch.float32) * scale).to(x.dtype) + assert not x_dequant.isnan().any(), "key or value states contain NaN after fp8 dequantization" + return x_dequant + + +def create_compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: + "Returns a dict of decoder layer idxs and whether KV needs to be computed at that layer to be cached." + if cla_kv_cache_map is None: return {} + compute_new_kv_map = {} + is_seen = set() + for k,v in cla_kv_cache_map.items(): + if v == -1: + compute_new_kv_map[k] = True + elif v not in is_seen: + compute_new_kv_map[k] = True + is_seen.add(v) + else: + compute_new_kv_map[k] = False + return compute_new_kv_map + + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -295,12 +351,53 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.is_causal = True self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.palu_kv_compression_enabled = config.__dict__.get("palu_kv_compression_enabled", False) + if not self.palu_kv_compression_enabled: + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + # KV fp8 Quantization. + self.use_fp8_kv_scale = config.__dict__.get("use_fp8_kv_scale", None) + + if self.use_fp8_kv_scale: + logger.warning_once("KV fp8 quantization is enabled.") + self.k_scale = torch.nn.Parameter(torch.tensor([0.1])) + self.v_scale = torch.nn.Parameter(torch.tensor([0.1])) + + # Cross Layer Attention (CLA). + # Example of cla_kv_cache_map with 8 layers: + # Index of cached kv is enumerated as it is stored in a list. + # Every new index value corresponds to the layer_idx of the layer where the cached kv was created. 
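+        # The cache index is the entry's position in the cla_key_value list, so new indices must
+        # appear in increasing order of first use. A value of -1 opts a layer out of sharing: it
+        # always computes its own KV and nothing is stored in the list for it.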
+ # layer_idx -> cache_idx map: {0:0, 1:1, 2:1, 3:0, 4:2, 5:3, 6:3, 7:2} + # layer 7: oooooo <---| + # layer 6: oooooo <-| | + # layer 5: oooooo --| | + # layer 4: oooooo ----| + # layer 3: oooooo <---| + # layer 2: oooooo <-| | + # layer 1: oooooo --| | + # layer 0: oooooo ----| + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) + if self.cla_kv_cache_map is not None: + logger.warning_once("Cross Layer Attention (CLA) is enabled.") + self.compute_new_kv = create_compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] + else: + self.compute_new_kv = True + self.cla_kv_detached = config.__dict__.get("cla_kv_detached", True) + self.debug_kv_sharing = config.__dict__.get("debug_kv_sharing", False) + + # TODO: MLRD PALU. + if self.palu_kv_compression_enabled: + logger.warning_once("MLRD PALU is enabled.") + self.palu_k_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + self.palu_v_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + + self.palu_k_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.palu_v_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) def forward( self, @@ -308,6 +405,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -326,21 +424,28 @@ def forward( query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) + if self.compute_new_kv: + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -349,15 +454,49 @@ def forward( "`position_embeddings` (Tuple of 
tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + # 2. KV fp8 quantization. + if self.use_fp8_kv_scale and self.compute_new_kv: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + # 3. PALU KV up projection. + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # update or re-use kv. 
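+        # Layers that compute new KV append it to cla_key_value (detached by default, so gradients
+        # from re-using layers do not flow back into the producing layer); re-using layers read the
+        # entry at their cache index instead.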
+ if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: + if self.compute_new_kv: + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) + else: + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -382,6 +521,8 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output + if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) @@ -392,7 +533,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class LlamaFlashAttention2(LlamaAttention): @@ -416,6 +557,7 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -432,15 +574,22 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -449,16 +598,53 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. 
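+        # fp8_quant_dequant is a fake quantization: KV is scaled, clamped to the float8_e4m3fn
+        # range, cast to fp8 and immediately dequantized, so downstream compute stays in the
+        # original dtype.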
+ if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: + if self.compute_new_kv: + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) + else: + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) @@ -507,12 +693,13 @@ def forward( ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class LlamaSdpaAttention(LlamaAttention): @@ -529,6 +716,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -546,6 +734,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -555,12 +744,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, 
self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -569,16 +765,53 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. + if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # update or re-use kv. 
+ if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: + if self.compute_new_kv: + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) + else: + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -608,10 +841,10 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, past_key_value, cla_key_value LLAMA_ATTENTION_CLASSES = { @@ -638,6 +871,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -671,11 +905,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, present_key_value, cla_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -698,6 +933,9 @@ def forward( if use_cache: outputs += (present_key_value,) + if cla_key_value is not None: + outputs += (cla_key_value,) + return outputs @@ -845,6 +1083,8 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False + + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) # Initialize weights and apply final processing self.post_init() @@ -922,6 +1162,12 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + # cla + if self.cla_kv_cache_map is not None: + cla_key_value = [] + else: + cla_key_value = None for decoder_layer in self.layers: if output_hidden_states: @@ -934,6 +1180,7 @@ def forward( causal_mask, position_ids, past_key_values, + cla_key_value, output_attentions, use_cache, cache_position, @@ -945,6 +1192,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -959,6 +1207,9 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) + if cla_key_value is not None: + cla_key_value = layer_outputs[-1] + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1941bca17add08..b4a89f08682b06 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -148,7 +148,7 @@ def _dynamic_frequency_update(self, position_ids, 
device): self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, x, position_ids, device_type=None, dtype=None): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) @@ -156,7 +156,7 @@ def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type + device_type = x.device.type if x is not None else device_type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -168,7 +168,8 @@ def forward(self, x, position_ids): cos = cos * self.attention_scaling sin = sin * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + dtype = x.dtype if x is not None else dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -207,6 +208,31 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def apply_rotary_pos_emb_query_only(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query tensor. + + Args: + q (`torch.Tensor`): The query tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 class Qwen2MLP(nn.Module): def __init__(self, config): @@ -235,6 +261,36 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def fp8_quant_dequant(x, scale): + # assert not x.isnan().any(), "key or value states contain NaN before fp8 quantization" + # Get min/max values for float8_e4m3fn + f8_min, f8_max = torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max + # Clamp x/scale to float8 range before conversion + x_fp32 = x.to(torch.float32) / scale + x_fp32_clamped = torch.clamp(x_fp32, f8_min, f8_max) + x_fp8 = x_fp32_clamped.to(torch.float8_e4m3fn) + # assert not x_fp8.isnan().any(), "key or value states contain NaN after fp8 quantization" + x_dequant = (x_fp8.to(torch.float32) * scale).to(x.dtype) + assert not x_dequant.isnan().any(), "key or value states contain NaN after fp8 dequantization" + return x_dequant + + +def create_compute_new_kv_map(cla_kv_cache_map) -> dict[int, bool]: + "Returns a dict of decoder layer idxs and whether KV needs to be computed at that layer to be cached." + if cla_kv_cache_map is None: return {} + compute_new_kv_map = {} + is_seen = set() + for k,v in cla_kv_cache_map.items(): + if v == -1: + compute_new_kv_map[k] = True + elif v not in is_seen: + compute_new_kv_map[k] = True + is_seen.add(v) + else: + compute_new_kv_map[k] = False + return compute_new_kv_map + + class Qwen2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -268,11 +324,52 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.palu_kv_compression_enabled = config.__dict__.get("palu_kv_compression_enabled", False) + if not self.palu_kv_compression_enabled: + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + # KV fp8 Quantization. + self.use_fp8_kv_scale = config.__dict__.get("use_fp8_kv_scale", None) + + if self.use_fp8_kv_scale: + logger.warning_once("KV fp8 quantization is enabled.") + self.k_scale = torch.nn.Parameter(torch.tensor([0.1])) + self.v_scale = torch.nn.Parameter(torch.tensor([0.1])) + + # Cross Layer Attention (CLA). + # Example of cla_kv_cache_map with 8 layers: + # Index of cached kv is enumerated as it is stored in a list. + # Every new index value corresponds to the layer_idx of the layer where the cached kv was created. 
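+        # The cache index is the entry's position in the cla_key_value list, so new indices must
+        # appear in increasing order of first use. A value of -1 opts a layer out of sharing: it
+        # always computes its own KV and nothing is stored in the list for it.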
+ # layer_idx -> cache_idx map: {0:0, 1:1, 2:1, 3:0, 4:2, 5:3, 6:3, 7:2} + # layer 7: oooooo <---| + # layer 6: oooooo <-| | + # layer 5: oooooo --| | + # layer 4: oooooo ----| + # layer 3: oooooo <---| + # layer 2: oooooo <-| | + # layer 1: oooooo --| | + # layer 0: oooooo ----| + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) + if self.cla_kv_cache_map is not None: + logger.warning_once("Cross Layer Attention (CLA) is enabled.") + self.compute_new_kv = create_compute_new_kv_map(self.cla_kv_cache_map)[self.layer_idx] + else: + self.compute_new_kv = True + self.cla_kv_detached = config.__dict__.get("cla_kv_detached", True) + self.debug_kv_sharing = config.__dict__.get("debug_kv_sharing", False) + + # TODO: MLRD PALU. + if self.palu_kv_compression_enabled: + logger.warning_once("MLRD PALU is enabled.") + self.palu_k_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + self.palu_v_down_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * config.palu_head_dim, bias=False) + + self.palu_k_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=True) + self.palu_v_up_proj = nn.Linear(self.num_key_value_heads * config.palu_head_dim, self.num_key_value_heads * self.head_dim, bias=True) def forward( self, @@ -280,6 +377,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -288,12 +386,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -302,14 +407,48 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + # 2. 
KV fp8 quantization. + if self.use_fp8_kv_scale and self.compute_new_kv: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + # 3. PALU KV up projection. + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: + if self.compute_new_kv: + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) + else: + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -334,12 +473,13 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class Qwen2FlashAttention2(Qwen2Attention): @@ -366,6 +506,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -374,12 +515,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -388,10 +536,37 @@ def forward( "`position_embeddings` (Tuple of tensors, 
containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. + if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute @@ -423,6 +598,16 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # update or re-use kv. 
+ if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: + if self.compute_new_kv: + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) + else: + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -479,12 +664,14 @@ def forward( ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value, cla_key_value class Qwen2SdpaAttention(Qwen2Attention): @@ -501,6 +688,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -517,6 +705,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, ) @@ -524,12 +713,19 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if self.compute_new_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.compute_new_kv: + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + else: + value_states = None + + if self.debug_kv_sharing: + query_states = torch.ones_like(query_states) if position_embeddings is None: logger.warning_once( @@ -538,15 +734,52 @@ def forward( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) - cos, sin = self.rotary_emb(value_states, position_ids) + device_type, dtype = hidden_states.device.type, hidden_states.dtype + cos, sin = self.rotary_emb(value_states, position_ids, device_type, dtype) else: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + #### BEGIN: KV Compression #### + if not self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + # 1. PALU KV down projection. + + # 2. KV fp8 quantization. + if self.compute_new_kv and self.use_fp8_kv_scale: + key_states = fp8_quant_dequant(key_states, self.k_scale) + value_states = fp8_quant_dequant(value_states, self.v_scale) + + # 3. 
PALU KV up projection. + + # 4. PALU ROPE. + # If PALU is enabled, key rotary embeddings are applied after the up projection. + # This is because palu key down projection is fused to k_proj at inference time. + # Order at inference will be down projection -> fp8 quantization -> write cache. + # read cache -> fp8 dequantization -> up projection -> rope. + if self.palu_kv_compression_enabled: + if self.compute_new_kv: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states = apply_rotary_pos_emb_query_only(query_states, cos, sin) + #### END: KV Compression #### if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # update or re-use kv. + if cla_key_value is not None and self.cla_kv_cache_map[self.layer_idx] != -1: + if self.compute_new_kv: + if self.cla_kv_detached: + cla_key_value.append((key_states.detach(), value_states.detach())) + else: + cla_key_value.append((key_states, value_states)) + else: + key_states, value_states = cla_key_value[self.cla_kv_cache_map[self.layer_idx]] + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -577,10 +810,11 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) - + + if self.debug_kv_sharing: self.debug_cla_attn_output = attn_output attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, past_key_value, cla_key_value QWEN2_ATTENTION_CLASSES = { @@ -612,6 +846,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + cla_key_value: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -645,11 +880,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, present_key_value, cla_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -671,6 +907,9 @@ def forward( if use_cache: outputs += (present_key_value,) + if cla_key_value is not None: + outputs += (cla_key_value,) + return outputs @@ -819,6 +1058,8 @@ def __init__(self, config: Qwen2Config): self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.cla_kv_cache_map = config.__dict__.get("cla_kv_cache_map", None) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -899,6 +1140,12 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + # cla + if self.cla_kv_cache_map is not None: + cla_key_value = [] + else: + cla_key_value = None for decoder_layer in self.layers: if output_hidden_states: @@ -911,6 +1158,7 @@ def forward( causal_mask, position_ids, 
past_key_values, + cla_key_value, output_attentions, use_cache, cache_position, @@ -922,6 +1170,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + cla_key_value=cla_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -936,6 +1185,9 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) + if cla_key_value is not None: + cla_key_value = layer_outputs[-1] + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer diff --git a/tests/answerdotai/test_cla.py b/tests/answerdotai/test_cla.py new file mode 100644 index 00000000000000..a30c908ed5996d --- /dev/null +++ b/tests/answerdotai/test_cla.py @@ -0,0 +1,165 @@ +import unittest + +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.testing_utils import ( + is_torch_available, + require_torch, +) + + +if is_torch_available(): + import torch + + +@require_torch +class CrossLayerAttentionTest(unittest.TestCase): + + test_parameters = [ + # Model, attention implementation, fp8 kv enabled + ("Qwen/Qwen2.5-32B-Instruct", "eager", True), + ("Qwen/Qwen2.5-32B-Instruct", "sdpa", True), + ("Qwen/Qwen2.5-32B-Instruct", "flash_attention_2", True), + ("Qwen/Qwen2.5-32B-Instruct", "eager", False), + ("Qwen/Qwen2.5-32B-Instruct", "sdpa", False), + ("Qwen/Qwen2.5-32B-Instruct", "flash_attention_2", False), + ("meta-llama/Llama-3.1-8B-Instruct", "eager", True), + ("meta-llama/Llama-3.1-8B-Instruct", "sdpa", True), + ("meta-llama/Llama-3.1-8B-Instruct", "flash_attention_2", True), + ("meta-llama/Llama-3.1-8B-Instruct", "eager", False), + ("meta-llama/Llama-3.1-8B-Instruct", "sdpa", False), + ("meta-llama/Llama-3.1-8B-Instruct", "flash_attention_2", False), + ] + + def test_naive_cla(self): + "Compare a base model without CLA with a model with CLA where each layer computes its own KV cache." + + for model_name, attn_impl, fp8_kv_enabled in self.test_parameters: + with self.subTest(model_id=model_name, attn_impl=attn_impl, fp8_kv_enabled=fp8_kv_enabled): + cfg = AutoConfig.from_pretrained(model_name) + cfg.num_hidden_layers = 4 + cfg.hidden_size //= 8 + cfg.intermediate_size //= 8 + cfg.num_attention_heads //= 2 + cfg.num_key_value_heads //= 2 + cfg._attn_implementation = attn_impl + cfg.use_fp8_kv_scale = fp8_kv_enabled + cfg.palu_kv_compression_enabled = False + cfg.use_cache = False + cfg.debug_kv_sharing = False + cfg.output_attentions = False + + x = torch.arange(32, device="cuda").view(1,-1) + + cfg.cla_kv_cache_map = {0:0, 1:1, 2:2, 3:3} + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16) + test_output = model(x) + + model_state_dict = model.state_dict() + + cfg.cla_kv_cache_map = None + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16) + model.load_state_dict(model_state_dict) + base_output = model(x) + + assert torch.equal(test_output.logits, base_output.logits) + + def test_cla_2(self): + """ + Test CLA with a custom KV cache map. 
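+        With cla_kv_cache_map = {0:0, 1:1, 2:1, 3:0}, layers 2 and 3 re-use the KV produced at
+        layers 1 and 0 respectively.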
+ + LLM with 4 layers: + layer 3: oooooo <---| + layer 2: oooooo <-| | + layer 1: oooooo --| | + layer 0: oooooo ----| + """ + + for model_name, attn_impl, fp8_kv_enabled in self.test_parameters: + with self.subTest(model_id=model_name, attn_impl=attn_impl, fp8_kv_enabled=fp8_kv_enabled): + cfg = AutoConfig.from_pretrained(model_name) + cfg.num_hidden_layers = 4 + cfg.hidden_size //= 8 + cfg.intermediate_size //= 8 + cfg.num_attention_heads //= 2 + cfg.num_key_value_heads //= 2 + cfg._attn_implementation = attn_impl + cfg.use_fp8_kv_scale = fp8_kv_enabled + cfg.cla_kv_cache_map = {0:0, 1:1, 2:1, 3:0} + cfg.palu_kv_compression_enabled = False + cfg.use_cache = False + cfg.debug_kv_sharing = True + cfg.output_attentions = attn_impl == "eager" + + x = torch.arange(32, device="cuda").view(1,-1) + + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16) + + assert model.config.use_fp8_kv_scale == fp8_kv_enabled + assert model.config.cla_kv_cache_map == {0:0, 1:1, 2:1, 3:0} + assert model.config.use_cache == False + + out_eager = model(x) + + if attn_impl == "eager": + assert torch.equal(out_eager.attentions[0], out_eager.attentions[3]) + assert torch.equal(out_eager.attentions[1], out_eager.attentions[2]) + assert not torch.equal(out_eager.attentions[0], out_eager.attentions[1]) + + attn_outputs = [l.self_attn.debug_cla_attn_output for l in model.model.layers] + assert len(attn_outputs) == 4 + assert torch.equal(attn_outputs[0], attn_outputs[3]) + assert torch.equal(attn_outputs[1], attn_outputs[2]) + assert not torch.equal(attn_outputs[0], attn_outputs[1]) + + def test_training_with_cla(self): + """ + Test that parameters are updated correctly: + + layer 0 k_proj, v_proj, k_scale, v_scale should be updated + layer 1 k_proj, v_proj shouldn't be updated + """ + model_name = "Qwen/Qwen2.5-32B-Instruct" + attn_impl = "eager" + fp8_kv_enabled = True + + cfg = AutoConfig.from_pretrained(model_name) + cfg.num_hidden_layers = 2 + cfg.hidden_size //= 8 + cfg.intermediate_size //= 8 + cfg.num_attention_heads //= 2 + cfg.num_key_value_heads //= 2 + cfg._attn_implementation = attn_impl + cfg.use_fp8_kv_scale = fp8_kv_enabled + cfg.cla_kv_cache_map = {0: 0, 1: 0} + cfg.palu_kv_compression_enabled = False + cfg.use_cache = False + cfg.debug_kv_sharing = True + + model = AutoModelForCausalLM.from_config(cfg) + model.to(device="cuda", dtype=torch.bfloat16) + + model.train() + + # Perform a forward and backward pass + x = torch.arange(32, device="cuda").view(1, -1) + output = model(x) + loss = output.logits.sum() + loss.backward() + + # Check which parameters have been updated + for name, param in model.named_parameters(): + if param.grad is not None: + updated = torch.any(param.grad != 0) + if "layers.0" in name: + if any(proj in name for proj in ["k_proj", "v_proj"]): + self.assertTrue(updated, f"{name} should be updated") + if fp8_kv_enabled and any(scale in name for scale in ["k_scale", "v_scale"]): + self.assertTrue(updated, f"{name} should be updated") + elif "layers.1" in name: + if any(proj in name for proj in ["k_proj", "v_proj"]): + self.assertFalse(updated, f"{name} should not be updated") + else: + self.assertTrue(updated, f"{name} should be updated") \ No newline at end of file
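A minimal usage sketch (not part of the patch) of the two module-level helpers the diff adds, `create_compute_new_kv_map` and `fp8_quant_dequant`. It assumes the patched `modeling_llama` module is importable and a torch build with `float8_e4m3fn` support (>= 2.1):

```python
import torch

from transformers.models.llama.modeling_llama import (
    create_compute_new_kv_map,
    fp8_quant_dequant,
)

# The 8-layer CLA example from the comments: layers 2, 3, 6 and 7 re-use earlier KV.
cla_kv_cache_map = {0: 0, 1: 1, 2: 1, 3: 0, 4: 2, 5: 3, 6: 3, 7: 2}
print(create_compute_new_kv_map(cla_kv_cache_map))
# -> {0: True, 1: True, 2: False, 3: False, 4: True, 5: True, 6: False, 7: False}

# Fake fp8 quantization: round-trip a KV-shaped tensor through float8_e4m3fn.
kv = torch.randn(1, 8, 32, 128, dtype=torch.bfloat16)
scale = torch.tensor([0.1])
kv_dq = fp8_quant_dequant(kv, scale)
print(kv_dq.dtype, (kv - kv_dq).abs().max())  # stays bfloat16; small quantization error
```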