Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit a6dd19b042: Specialize code
Parent: 7c11ceba6c
@@ -120,15 +120,14 @@ class GPTBigCodeAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
         self.mask_value = None
+        assert config.multi_query
+        assert config.attention_softmax_in_fp32
+        assert config.scale_attention_softmax_in_fp32

-        self.multi_query = config.multi_query
-        self.seq_dim = -2 if self.multi_query else -1
         self.flash_attention = config.flash_attention
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.kv_heads = 1 if self.multi_query else self.num_heads
-        self.kv_dim = self.kv_heads * self.head_dim
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
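Note: the hunk above pins the layer to multi-query attention (MQA), so the per-head key/value bookkeeping (kv_heads, kv_dim, seq_dim) disappears. A minimal standalone sketch of what MQA changes, with invented sizes and outside the repository's code:

import torch
import torch.nn as nn

embed_dim, num_heads = 2048, 16
head_dim = embed_dim // num_heads  # 128

# MQA: all query heads share one key/value head, so the fused QKV projection needs
# embed_dim + 2 * head_dim outputs instead of the 3 * embed_dim an MHA layer needs.
mqa_c_attn = nn.Linear(embed_dim, embed_dim + 2 * head_dim)
mha_c_attn = nn.Linear(embed_dim, 3 * embed_dim)

x = torch.randn(4, 10, embed_dim)  # (batch, sequence, hidden)
query, key, value = mqa_c_attn(x).split((embed_dim, head_dim, head_dim), dim=-1)
print(query.shape, key.shape, value.shape)
# torch.Size([4, 10, 2048]) torch.Size([4, 10, 128]) torch.Size([4, 10, 128])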
@@ -140,10 +139,6 @@ class GPTBigCodeAttention(nn.Module):
         self.is_cross_attention = is_cross_attention

         self.layer_idx = layer_idx
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.scale_attention_softmax_in_fp32 = (
-            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
-        )
         self.fused_softmax = config.fused_softmax

         # KV caching and padding
@@ -155,7 +150,7 @@ class GPTBigCodeAttention(nn.Module):

         if self.is_cross_attention:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
+        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)

         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

@@ -168,9 +163,6 @@ class GPTBigCodeAttention(nn.Module):
                     "Flash Attention requires `flash_attn` and `einops`. "
                     "To install, run `pip install flash-attn einops`."
                 )
-            if not self.multi_query:
-                # TODO: Flash Attention is implemented but not tested for MHA
-                raise ValueError("Flash Attention is not supported with multi-head attention.")

     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
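Note: the file keeps Flash Attention optional and raises the ImportError above when the packages are missing. A hedged sketch of that optional-dependency pattern (the flag and helper names here are illustrative, not the file's):

# Assumed structure, not the repository's exact import logic.
try:
    import flash_attn  # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

def require_flash_attn():
    # Fail late, with the same installation hint the diff shows.
    if not HAS_FLASH_ATTN:
        raise ImportError(
            "Flash Attention requires `flash_attn` and `einops`. "
            "To install, run `pip install flash-attn einops`."
        )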
@@ -178,41 +170,29 @@ class GPTBigCodeAttention(nn.Module):
             self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
         return self.mask_value

-    def _attn(self, query, key, value, attention_mask, head_mask=None):
+    def _attn(self, query, key, value, attention_mask):
         dtype = query.dtype
-        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
+        softmax_dtype = torch.float32
         upcast = dtype != softmax_dtype

-        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
+        unscale = self.layer_idx + 1 if upcast else 1
         scale_factor = unscale**-1
         if self.scale_attn_weights:
             scale_factor /= self.head_dim**0.5

-        # MQA models: (batch_size, query_length, num_heads * head_dim)
-        # MHA models: (batch_size, num_heads, query_length, head_dim)
+        # (batch_size, query_length, num_heads * head_dim)
         query_shape = query.shape
         batch_size = query_shape[0]
         key_length = key.size(-2)

         key = key.transpose(-1, -2)
-        if self.multi_query:
-            # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
-            # -> (batch_size, query_length, num_heads, key_length)
-            query_length = query_shape[1]
-            attn_shape = (batch_size, query_length, self.num_heads, key_length)
-            attn_view = (batch_size, query_length * self.num_heads, key_length)
-            # No copy needed for MQA 2, or when layer_past is provided.
-            query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
-        else:
-            # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
-            # -> (batch_size, num_heads, query_length, key_length)
-            query_length = query_shape[2]
-            attn_shape = (batch_size, self.num_heads, query_length, key_length)
-            attn_view = (batch_size * self.num_heads, query_length, key_length)
-            # Always copies
-            query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
-            # No copy when layer_past is provided.
-            key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
+        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
+        # -> (batch_size, query_length, num_heads, key_length)
+        query_length = query_shape[1]
+        attn_shape = (batch_size, query_length, self.num_heads, key_length)
+        attn_view = (batch_size, query_length * self.num_heads, key_length)
+        # No copy needed for MQA 2, or when layer_past is provided.
+        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)

         attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
         if query.device.type == "cpu":
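Note: in the specialized `_attn`, queries for all heads are folded into one batched dimension so a single `torch.bmm` against the shared key/value works for every head. A self-contained sketch of that reshape with invented sizes (scaling, masking, and the fused/CPU softmax paths of the real method are omitted):

import torch

batch_size, query_length, num_heads, head_dim, key_length = 2, 5, 8, 64, 7

# MQA layout: query is (batch, q_len, num_heads * head_dim); key/value have one shared head.
query = torch.randn(batch_size, query_length, num_heads * head_dim)
key = torch.randn(batch_size, key_length, head_dim)
value = torch.randn(batch_size, key_length, head_dim)

# Fold the head dimension into the rows of one batched matmul per sequence.
q = query.reshape(batch_size, query_length * num_heads, head_dim)
attn_weights = torch.bmm(q, key.transpose(-1, -2))              # (batch, q_len * heads, k_len)
probs = torch.softmax(attn_weights.view(batch_size, query_length, num_heads, key_length), dim=-1)

attn_view = (batch_size, query_length * num_heads, key_length)
attn_output = torch.bmm(probs.view(attn_view), value)           # value is shared by all heads
attn_output = attn_output.view(batch_size, query_length, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 5, 512])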
@@ -237,32 +217,17 @@ class GPTBigCodeAttention(nn.Module):

         attn_weights = self.attn_dropout(attn_weights)

-        # Mask heads if we want to
-        if head_mask is not None:
-            if self.multi_query:
-                head_mask = head_mask.transpose(1, 2)
-            attn_weights = attn_weights * head_mask
-
-        if self.multi_query:
-            attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
-        else:
-            attn_output = torch.matmul(attn_weights, value)
+        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)

         return attn_output, attn_weights

-    def _attn_flash(self, query, key, value, attention_mask, head_mask=None):
-        if head_mask is not None:
-            raise NotImplementedError("Head mask is not supported with flash attention.")
+    def _attn_flash(self, query, key, value, attention_mask):

         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
-        if self.multi_query:
-            key = key.unsqueeze(1).expand(attn_shape)
-            value = value.unsqueeze(1).expand(attn_shape)
-        else:
-            key = key.view(attn_shape)
-            value = value.view(attn_shape)
+        key = key.unsqueeze(1).expand(attn_shape)
+        value = value.unsqueeze(1).expand(attn_shape)

         sequence_lengths, padding_index, _, max_sequence_length = attention_mask

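Note: `_attn_flash` now always broadcasts the single key/value head across all query heads with `unsqueeze(1).expand(...)`, which creates a zero-copy view rather than materializing num_heads copies. A small illustration (sizes invented):

import torch

total_tokens, num_heads, head_dim = 11, 8, 64   # the flash path works on unpadded tokens

key = torch.randn(total_tokens, head_dim)
value = torch.randn(total_tokens, head_dim)

# expand() returns a broadcast view: every head reads the same underlying storage.
key_exp = key.unsqueeze(1).expand(total_tokens, num_heads, head_dim)
value_exp = value.unsqueeze(1).expand(total_tokens, num_heads, head_dim)

print(key_exp.shape)                          # torch.Size([11, 8, 64])
print(key_exp.data_ptr() == key.data_ptr())   # True: no copy was made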
@@ -285,20 +250,11 @@ class GPTBigCodeAttention(nn.Module):
     def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length):
         batch_size = kv_cache.size(-1)
         assert not self.training
-        if self.multi_query:
-            allocated_kv_cache = torch.empty(
-                [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
-            )
-            allocated_kv_cache[:, :key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-        else:
-            allocated_kv_cache = torch.empty(
-                [batch_size, self.num_heads, allocate_key_length, self.head_dim],
-                dtype=kv_cache.dtype,
-                device=kv_cache.device,
-            )
-            allocated_kv_cache[:, :, key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+        allocated_kv_cache = torch.empty(
+            [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
+        )
+        allocated_kv_cache[:, :key_length].copy_(kv_cache)
+        padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
         return allocated_kv_cache, padded_kv_cache

     def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask):
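Note: with the MQA-only cache of shape (batch, key_length, head_dim), re-allocation reduces to "allocate a larger buffer and copy the old prefix". A hypothetical helper (not the repository's code) showing that pattern:

import torch

def grow_kv_cache(kv_cache: torch.Tensor, new_length: int) -> torch.Tensor:
    # Illustrative only: copy an MQA cache (batch, key_length, head_dim)
    # into a larger, freshly allocated buffer.
    batch_size, key_length, head_dim = kv_cache.shape
    bigger = torch.empty(
        batch_size, new_length, head_dim, dtype=kv_cache.dtype, device=kv_cache.device
    )
    bigger[:, :key_length].copy_(kv_cache)
    return bigger

cache = torch.zeros(2, 16, 64)
cache = grow_kv_cache(cache, 48)
print(cache.shape)  # torch.Size([2, 48, 64])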
@@ -308,16 +264,12 @@ class GPTBigCodeAttention(nn.Module):
         if flash_attention and use_cache:
             _, padding_index, batch_size, max_sequence_length = attention_mask
             current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
-            if not self.multi_query:
-                current_kv_cache = current_kv_cache.view(
-                    batch_size, max_sequence_length, self.num_heads, 2 * self.head_dim
-                ).transpose(1, 2)
         else:
             current_kv_cache = key_value

         # Calculate dimensions and recover layer_past
         batch_size = current_kv_cache.size(0)
-        query_length = current_kv_cache.size(self.seq_dim)
+        query_length = current_kv_cache.size(-2)
         if layer_past is None:
             allocated_kv_cache, last_key_length = None, 0
             last_kv_cache = None
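Note: on the flash-attention path the new key/value tokens arrive unpadded (packed), and `pad_input` scatters them back into a padded (batch, max_seq_len, ...) tensor before caching. A pure-PyTorch stand-in written for illustration only (the real call comes from the flash-attn utilities):

import torch

def pad_input_sketch(unpadded, padding_index, batch_size, max_sequence_length):
    # Illustrative stand-in: scatter a packed (total_tokens, dim) tensor back into
    # (batch, max_seq_len, dim) using the flat indices of the non-padding positions.
    dim = unpadded.size(-1)
    out = unpadded.new_zeros(batch_size * max_sequence_length, dim)
    out[padding_index] = unpadded
    return out.view(batch_size, max_sequence_length, dim)

# Two sequences of lengths 3 and 2, packed into 5 tokens of width 4.
unpadded = torch.arange(20.0).view(5, 4)
padding_index = torch.tensor([0, 1, 2, 4, 5])   # flat positions inside a (2, 4) grid
print(pad_input_sketch(unpadded, padding_index, 2, 4).shape)  # torch.Size([2, 4, 4])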
@@ -325,50 +277,31 @@ class GPTBigCodeAttention(nn.Module):
             allocated_key_length = key_length
         else:
             allocated_kv_cache, last_key_length = layer_past
-            last_kv_cache = (
-                allocated_kv_cache[:, :last_key_length]
-                if self.multi_query
-                else allocated_kv_cache[:, :, :last_key_length]
-            )
+            last_kv_cache = allocated_kv_cache[:, :last_key_length]
             key_length = query_length + last_key_length
-            allocated_key_length = allocated_kv_cache.size(self.seq_dim)
+            allocated_key_length = allocated_kv_cache.size(-2)

         padded_key_length = key_length if flash_attention else attention_mask.size(-1)
         allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length)

         # Re-allocate kv cache and copy last value
         if allocate_key_length > allocated_key_length:
-            if self.multi_query:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    # Nans in `value` can propagate through the matrix multiplication,
-                    # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
-                    allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
-            else:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, self.num_heads, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    allocated_kv_cache[:, :, key_length:, self.head_dim :].zero_()
+            allocated_kv_cache = torch.empty(
+                [batch_size, allocate_key_length, 2 * self.head_dim],
+                dtype=current_kv_cache.dtype,
+                device=current_kv_cache.device,
+            )
+            if layer_past is not None:
+                allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
+            if allocate_key_length > key_length:
+                # Nans in `value` can propagate through the matrix multiplication,
+                # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
+                allocated_kv_cache[:, key_length:, self.head_dim :].zero_()

         # Copy the new values.
         if allocate_key_length > allocated_key_length or layer_past is not None:
-            if self.multi_query:
-                allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-            else:
-                allocated_kv_cache[:, :, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+            allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
+            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
         if not flash_attention:
             # Use the merged KV cache.
             # Not needed when layer_past is None but frees some memory.
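Note: the cache stores key and value side by side in the last dimension, the new step's tokens are copied into slots last_key_length:key_length, and the value half of any still-unused pre-allocated slots is zeroed so NaNs in uninitialized memory cannot leak through the matmul. A toy walkthrough of that copy pattern (sizes invented; earlier slots would already hold the past cache):

import torch

batch_size, head_dim = 2, 4
last_key_length, query_length, allocate_key_length = 3, 2, 8
key_length = last_key_length + query_length

# Packed cache: (batch, allocate_key_length, 2 * head_dim), key then value.
allocated = torch.zeros(batch_size, allocate_key_length, 2 * head_dim)
current = torch.randn(batch_size, query_length, 2 * head_dim)   # this step's new tokens

allocated[:, last_key_length:key_length].copy_(current)         # append the new tokens
allocated[:, key_length:, head_dim:].zero_()                    # scrub unused value slots
key, value = allocated[:, :key_length].split((head_dim, head_dim), dim=-1)
print(key.shape, value.shape)  # torch.Size([2, 5, 4]) torch.Size([2, 5, 4])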
@@ -387,7 +320,6 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Union[
@@ -395,18 +327,7 @@ class GPTBigCodeAttention(nn.Module):
         Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
     ]:
         flash_attention = self.flash_attention and layer_past is None
-        if self.multi_query or flash_attention:
-            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1)
-        else:
-            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
-            # i.e., the memory layout is not the same as GPT2.
-            # This makes the concatenation with past_key_value more efficient.
-            query, key_value = (
-                self.c_attn(hidden_states)
-                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
-                .transpose(1, 2)
-                .split((self.head_dim, 2 * self.head_dim), dim=3)
-            )
+        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)

         if self._tuple_cache_format:
             # present = (allocated_kv_cache, key_length)
@@ -420,11 +341,9 @@ class GPTBigCodeAttention(nn.Module):
         key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

         attn_output, attn_weights = (self._attn_flash if flash_attention else self._attn)(
-            query, key, value, attention_mask, head_mask
+            query, key, value, attention_mask
         )

-        if not self.multi_query:
-            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)

@@ -432,9 +351,8 @@ class GPTBigCodeAttention(nn.Module):
         if output_attentions:
             if flash_attention:
                 raise ValueError("`output_attentions` is not supported with Flash Attention.")
-            if self.multi_query:
-                # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
-                attn_weights = attn_weights.transpose(1, 2)
+            # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+            attn_weights = attn_weights.transpose(1, 2)
             outputs += (attn_weights,)

         return outputs  # a, present, (attentions)
@@ -478,7 +396,6 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: Optional[Tuple[torch.Tensor]],
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
@@ -495,7 +412,6 @@ class GPTBigCodeBlock(nn.Module):
             hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
-            head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
@@ -570,7 +486,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):

     def __init__(self, config):
         super().__init__(config)
-        self.multi_query = config.multi_query
+        assert config.multi_query
         self.embed_dim = config.hidden_size

         if config.add_cross_attention:
@@ -624,7 +540,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):

         # MQA models: (batch_size, query_length, n_heads, key_length)
         # MHA models: (batch_size, n_heads, query_length, key_length)
-        return attention_mask.unsqueeze(2 if self.multi_query else 1)
+        return attention_mask.unsqueeze(2)

     def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device):
         if position_ids is not None:
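Note: with the MQA weight layout (batch, query_length, num_heads, key_length), a padding mask of shape (batch, query_length, key_length) only needs `unsqueeze(2)` to broadcast over the head dimension. A short sketch with invented sizes:

import torch

batch_size, query_length, num_heads, key_length = 2, 3, 8, 5

attn_weights = torch.randn(batch_size, query_length, num_heads, key_length)
attention_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool)
attention_mask[:, :, -1] = False                 # pretend the last key position is padding

mask = attention_mask.unsqueeze(2)               # (batch, q_len, 1, k_len) broadcasts over heads
mask_value = torch.full([], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)
masked = torch.where(mask, attn_weights, mask_value)
print(masked.shape)  # torch.Size([2, 3, 8, 5])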
@@ -646,7 +562,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -662,7 +577,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         assert attention_mask is not None
         assert token_type_ids is None
         assert position_ids is not None
-        assert head_mask is None
         assert inputs_embeds is None
         assert encoder_hidden_states is None
         assert encoder_attention_mask is None
@@ -729,12 +643,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if encoder_hidden_states is not None or encoder_attention_mask is not None:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")

-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
@@ -776,14 +684,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                     hidden_states,
                     None,
                     attention_mask,
-                    head_mask[i],
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
-                    head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                 )
@@ -874,7 +780,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -898,7 +803,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
@@ -936,15 +840,3 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attentions=transformer_outputs.attentions,
             cross_attentions=transformer_outputs.cross_attentions,
         )
-
-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
-
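Note: the removed `_reorder_cache` re-selected cache rows by beam index during beam search. A standalone illustration of that `index_select` pattern (shapes invented):

import torch

past_key_values = tuple(torch.randn(4, 10, 128) for _ in range(2))   # 2 layers, 4 beams
beam_idx = torch.tensor([0, 0, 2, 3])                                # beams 0 and 1 both continue beam 0

reordered = tuple(
    layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values
)
print(reordered[0].shape)  # torch.Size([4, 10, 128])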