diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py
index a39263e4..078a8202 100644
--- a/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py
@@ -120,15 +120,14 @@ class GPTBigCodeAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
         self.mask_value = None

+        assert config.multi_query
+        assert config.attention_softmax_in_fp32
+        assert config.scale_attention_softmax_in_fp32
-        self.multi_query = config.multi_query
-        self.seq_dim = -2 if self.multi_query else -1
         self.flash_attention = config.flash_attention
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.kv_heads = 1 if self.multi_query else self.num_heads
-        self.kv_dim = self.kv_heads * self.head_dim
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
@@ -140,10 +139,6 @@ class GPTBigCodeAttention(nn.Module):
         self.is_cross_attention = is_cross_attention
         self.layer_idx = layer_idx
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.scale_attention_softmax_in_fp32 = (
-            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
-        )
         self.fused_softmax = config.fused_softmax

         # KV caching and padding
@@ -155,7 +150,7 @@ class GPTBigCodeAttention(nn.Module):
         if self.is_cross_attention:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")

-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
+        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)

         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

@@ -168,9 +163,6 @@ class GPTBigCodeAttention(nn.Module):
                     "Flash Attention requires `flash_attn` and `einops`. "
                     "To install, run `pip install flash-attn einops`."
                 )
-            if not self.multi_query:
-                # TODO: Flash Attention is implemented but not tested for MHA
-                raise ValueError("Flash Attention is not supported with multi-head attention.")

     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
@@ -178,41 +170,29 @@ class GPTBigCodeAttention(nn.Module):
             self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
         return self.mask_value

-    def _attn(self, query, key, value, attention_mask, head_mask=None):
+    def _attn(self, query, key, value, attention_mask):
         dtype = query.dtype
-        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
+        softmax_dtype = torch.float32
         upcast = dtype != softmax_dtype

-        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
+        unscale = self.layer_idx + 1 if upcast else 1
         scale_factor = unscale**-1
         if self.scale_attn_weights:
             scale_factor /= self.head_dim**0.5

-        # MQA models: (batch_size, query_length, num_heads * head_dim)
-        # MHA models: (batch_size, num_heads, query_length, head_dim)
+        # (batch_size, query_length, num_heads * head_dim)
         query_shape = query.shape
         batch_size = query_shape[0]
         key_length = key.size(-2)
         key = key.transpose(-1, -2)
-        if self.multi_query:
-            # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
-            # -> (batch_size, query_length, num_heads, key_length)
-            query_length = query_shape[1]
-            attn_shape = (batch_size, query_length, self.num_heads, key_length)
-            attn_view = (batch_size, query_length * self.num_heads, key_length)
-            # No copy needed for MQA 2, or when layer_past is provided.
-            query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
-        else:
-            # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
-            # -> (batch_size, num_heads, query_length, key_length)
-            query_length = query_shape[2]
-            attn_shape = (batch_size, self.num_heads, query_length, key_length)
-            attn_view = (batch_size * self.num_heads, query_length, key_length)
-            # Always copies
-            query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
-            # No copy when layer_past is provided.
-            key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
+        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
+        # -> (batch_size, query_length, num_heads, key_length)
+        query_length = query_shape[1]
+        attn_shape = (batch_size, query_length, self.num_heads, key_length)
+        attn_view = (batch_size, query_length * self.num_heads, key_length)
+        # No copy needed for MQA 2, or when layer_past is provided.
+        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)

         attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
         if query.device.type == "cpu":
@@ -237,32 +217,17 @@ class GPTBigCodeAttention(nn.Module):
         attn_weights = self.attn_dropout(attn_weights)

-        # Mask heads if we want to
-        if head_mask is not None:
-            if self.multi_query:
-                head_mask = head_mask.transpose(1, 2)
-            attn_weights = attn_weights * head_mask
-
-        if self.multi_query:
-            attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
-        else:
-            attn_output = torch.matmul(attn_weights, value)
+        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)

         return attn_output, attn_weights

-    def _attn_flash(self, query, key, value, attention_mask, head_mask=None):
-        if head_mask is not None:
-            raise NotImplementedError("Head mask is not supported with flash attention.")
+    def _attn_flash(self, query, key, value, attention_mask):

         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
-        if self.multi_query:
-            key = key.unsqueeze(1).expand(attn_shape)
-            value = value.unsqueeze(1).expand(attn_shape)
-        else:
-            key = key.view(attn_shape)
-            value = value.view(attn_shape)
+        key = key.unsqueeze(1).expand(attn_shape)
+        value = value.unsqueeze(1).expand(attn_shape)

         sequence_lengths, padding_index, _, max_sequence_length = attention_mask
@@ -285,20 +250,11 @@ class GPTBigCodeAttention(nn.Module):
     def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length):
         batch_size = kv_cache.size(-1)
         assert not self.training
-        if self.multi_query:
-            allocated_kv_cache = torch.empty(
-                [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
-            )
-            allocated_kv_cache[:, :key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-        else:
-            allocated_kv_cache = torch.empty(
-                [batch_size, self.num_heads, allocate_key_length, self.head_dim],
-                dtype=kv_cache.dtype,
-                device=kv_cache.device,
-            )
-            allocated_kv_cache[:, :, key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+        allocated_kv_cache = torch.empty(
+            [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
+        )
+        allocated_kv_cache[:, :key_length].copy_(kv_cache)
+        padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
         return allocated_kv_cache, padded_kv_cache

     def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask):
@@ -308,16 +264,12 @@ class GPTBigCodeAttention(nn.Module):
         if flash_attention and use_cache:
             _, padding_index, batch_size, max_sequence_length = attention_mask
             current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
-            if not self.multi_query:
-                current_kv_cache = current_kv_cache.view(
-                    batch_size, max_sequence_length, self.num_heads, 2 * self.head_dim
-                ).transpose(1, 2)
         else:
             current_kv_cache = key_value

         # Calculate dimensions and recover layer_past
         batch_size = current_kv_cache.size(0)
-        query_length = current_kv_cache.size(self.seq_dim)
+        query_length = current_kv_cache.size(-2)
         if layer_past is None:
             allocated_kv_cache, last_key_length = None, 0
             last_kv_cache = None
@@ -325,50 +277,31 @@ class GPTBigCodeAttention(nn.Module):
             allocated_key_length = key_length
         else:
             allocated_kv_cache, last_key_length = layer_past
-            last_kv_cache = (
-                allocated_kv_cache[:, :last_key_length]
-                if self.multi_query
-                else allocated_kv_cache[:, :, :last_key_length]
-            )
+            last_kv_cache = allocated_kv_cache[:, :last_key_length]
             key_length = query_length + last_key_length
-            allocated_key_length = allocated_kv_cache.size(self.seq_dim)
+            allocated_key_length = allocated_kv_cache.size(-2)

         padded_key_length = key_length if flash_attention else attention_mask.size(-1)
         allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length)

         # Re-allocate kv cache and copy last value
         if allocate_key_length > allocated_key_length:
-            if self.multi_query:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    # Nans in `value` can propagate through the matrix multiplication,
-                    # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
-                    allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
-            else:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, self.num_heads, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    allocated_kv_cache[:, :, key_length:, self.head_dim :].zero_()
+            allocated_kv_cache = torch.empty(
+                [batch_size, allocate_key_length, 2 * self.head_dim],
+                dtype=current_kv_cache.dtype,
+                device=current_kv_cache.device,
+            )
+            if layer_past is not None:
+                allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
+            if allocate_key_length > key_length:
+                # Nans in `value` can propagate through the matrix multiplication,
+                # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
+                allocated_kv_cache[:, key_length:, self.head_dim :].zero_()

         # Copy the new values.
         if allocate_key_length > allocated_key_length or layer_past is not None:
-            if self.multi_query:
-                allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-            else:
-                allocated_kv_cache[:, :, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+            allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
+            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]

         if not flash_attention:
             # Use the merged KV cache.
             # Not needed when layer_past is None but frees some memory.
@@ -387,7 +320,6 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Union[
@@ -395,18 +327,7 @@ class GPTBigCodeAttention(nn.Module):
         Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
     ]:
         flash_attention = self.flash_attention and layer_past is None
-        if self.multi_query or flash_attention:
-            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1)
-        else:
-            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
-            # i.e., the memory layout is not the same as GPT2.
-            # This makes the concatenation with past_key_value more efficient.
-            query, key_value = (
-                self.c_attn(hidden_states)
-                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
-                .transpose(1, 2)
-                .split((self.head_dim, 2 * self.head_dim), dim=3)
-            )
+        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)

         if self._tuple_cache_format:
             # present = (allocated_kv_cache, key_length)
@@ -420,11 +341,9 @@ class GPTBigCodeAttention(nn.Module):
             key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

         attn_output, attn_weights = (self._attn_flash if flash_attention else self._attn)(
-            query, key, value, attention_mask, head_mask
+            query, key, value, attention_mask
         )

-        if not self.multi_query:
-            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -432,9 +351,8 @@ class GPTBigCodeAttention(nn.Module):
         if output_attentions:
             if flash_attention:
                 raise ValueError("`output_attentions` is not supported with Flash Attention.")
-            if self.multi_query:
-                # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
-                attn_weights = attn_weights.transpose(1, 2)
+            # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+            attn_weights = attn_weights.transpose(1, 2)
             outputs += (attn_weights,)

         return outputs  # a, present, (attentions)
@@ -478,7 +396,6 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: Optional[Tuple[torch.Tensor]],
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
@@ -495,7 +412,6 @@ class GPTBigCodeBlock(nn.Module):
             hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
-            head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
@@ -570,7 +486,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-        self.multi_query = config.multi_query
+        assert config.multi_query
         self.embed_dim = config.hidden_size

         if config.add_cross_attention:
@@ -624,7 +540,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # MQA models: (batch_size, query_length, n_heads, key_length)
         # MHA models: (batch_size, n_heads, query_length, key_length)
-        return attention_mask.unsqueeze(2 if self.multi_query else 1)
+        return attention_mask.unsqueeze(2)

     def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device):
         if position_ids is not None:
@@ -646,7 +562,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -662,7 +577,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         assert attention_mask is not None
         assert token_type_ids is None
         assert position_ids is not None
-        assert head_mask is None
         assert inputs_embeds is None
         assert encoder_hidden_states is None
         assert encoder_attention_mask is None
@@ -729,12 +643,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if encoder_hidden_states is not None or encoder_attention_mask is not None:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")

-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
@@ -776,14 +684,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                     hidden_states,
                     None,
                     attention_mask,
-                    head_mask[i],
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
-                    head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                 )
@@ -874,7 +780,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -898,7 +803,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
-            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
@@ -936,15 +840,3 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attentions=transformer_outputs.attentions,
             cross_attentions=transformer_outputs.cross_attentions,
         )
-
-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
-