Specialize code

Joel Lamy-Poirier 2023-05-18 17:04:03 -04:00
parent 7c11ceba6c
commit a6dd19b042
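This commit specializes the GPTBigCode model to multi-query attention (MQA): the `multi_query`/MHA branches and `head_mask` support are removed, config flags such as `multi_query` and `attention_softmax_in_fp32` are asserted rather than branched on, and the fused `c_attn` projection shrinks from `embed_dim + 2 * kv_dim` to `embed_dim + 2 * head_dim`. A minimal sketch, with illustrative dimensions not taken from the diff, of why the MQA projection is smaller:

import torch.nn as nn

embed_dim, num_heads = 2048, 16
head_dim = embed_dim // num_heads

# MHA projects a separate K and V per head: 3 * embed_dim outputs in total.
mha_c_attn = nn.Linear(embed_dim, 3 * embed_dim)
# MQA shares a single K/V head: only embed_dim + 2 * head_dim outputs.
mqa_c_attn = nn.Linear(embed_dim, embed_dim + 2 * head_dim)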


@@ -120,15 +120,14 @@ class GPTBigCodeAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
         self.mask_value = None
+        assert config.multi_query
+        assert config.attention_softmax_in_fp32
+        assert config.scale_attention_softmax_in_fp32
-        self.multi_query = config.multi_query
-        self.seq_dim = -2 if self.multi_query else -1
         self.flash_attention = config.flash_attention
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.kv_heads = 1 if self.multi_query else self.num_heads
-        self.kv_dim = self.kv_heads * self.head_dim
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
@@ -140,10 +139,6 @@ class GPTBigCodeAttention(nn.Module):
         self.is_cross_attention = is_cross_attention
         self.layer_idx = layer_idx
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.scale_attention_softmax_in_fp32 = (
-            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
-        )
         self.fused_softmax = config.fused_softmax

         # KV caching and padding
@@ -155,7 +150,7 @@ class GPTBigCodeAttention(nn.Module):
         if self.is_cross_attention:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")

-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
+        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)

         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
@@ -168,9 +163,6 @@ class GPTBigCodeAttention(nn.Module):
                 "Flash Attention requires `flash_attn` and `einops`. "
                 "To install, run `pip install flash-attn einops`."
             )
-            if not self.multi_query:
-                # TODO: Flash Attention is implemented but not tested for MHA
-                raise ValueError("Flash Attention is not supported with multi-head attention.")

     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
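As the comment in `_get_mask_value` notes, `torch.where` expects a tensor operand, so the module caches a 0-d tensor holding the dtype's most negative finite value instead of rebuilding it on every call. A standalone sketch of the same pattern (values illustrative):

import torch

mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=torch.float16)
scores = torch.randn(2, 4, dtype=torch.float16)
keep = torch.rand(2, 4) > 0.5
# Masked positions get the large negative value and vanish after softmax.
masked = torch.where(keep, scores, mask_value)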
@@ -178,41 +170,29 @@ class GPTBigCodeAttention(nn.Module):
             self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
         return self.mask_value

-    def _attn(self, query, key, value, attention_mask, head_mask=None):
+    def _attn(self, query, key, value, attention_mask):
         dtype = query.dtype
-        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
+        softmax_dtype = torch.float32
         upcast = dtype != softmax_dtype
-        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
+        unscale = self.layer_idx + 1 if upcast else 1
         scale_factor = unscale**-1
         if self.scale_attn_weights:
             scale_factor /= self.head_dim**0.5

-        # MQA models: (batch_size, query_length, num_heads * head_dim)
-        # MHA models: (batch_size, num_heads, query_length, head_dim)
+        # (batch_size, query_length, num_heads * head_dim)
         query_shape = query.shape
         batch_size = query_shape[0]
         key_length = key.size(-2)
         key = key.transpose(-1, -2)
-        if self.multi_query:
-            # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
-            # -> (batch_size, query_length, num_heads, key_length)
-            query_length = query_shape[1]
-            attn_shape = (batch_size, query_length, self.num_heads, key_length)
-            attn_view = (batch_size, query_length * self.num_heads, key_length)
-            # No copy needed for MQA 2, or when layer_past is provided.
-            query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
-        else:
-            # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
-            # -> (batch_size, num_heads, query_length, key_length)
-            query_length = query_shape[2]
-            attn_shape = (batch_size, self.num_heads, query_length, key_length)
-            attn_view = (batch_size * self.num_heads, query_length, key_length)
-            # Always copies
-            query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
-            # No copy when layer_past is provided.
-            key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
+        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
+        # -> (batch_size, query_length, num_heads, key_length)
+        query_length = query_shape[1]
+        attn_shape = (batch_size, query_length, self.num_heads, key_length)
+        attn_view = (batch_size, query_length * self.num_heads, key_length)
+        # No copy needed for MQA 2, or when layer_past is provided.
+        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)

         attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
         if query.device.type == "cpu":
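The surviving `_attn` branch relies on the MQA layout: folding the head dimension into the query-length dimension lets a single `torch.bmm` score every head against the one shared key head, with no key replication. A minimal sketch of that reshape, assuming illustrative sizes:

import torch

batch, q_len, k_len, heads, head_dim = 2, 5, 7, 16, 64
query = torch.randn(batch, q_len, heads * head_dim)
key = torch.randn(batch, k_len, head_dim)  # one K head shared by all query heads

q = query.reshape(batch, q_len * heads, head_dim)  # a view; no copy
scores = torch.bmm(q, key.transpose(-1, -2))       # (batch, q_len * heads, k_len)
scores = scores.view(batch, q_len, heads, k_len)   # per-head attention logits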
@@ -237,32 +217,17 @@ class GPTBigCodeAttention(nn.Module):
         attn_weights = self.attn_dropout(attn_weights)

-        # Mask heads if we want to
-        if head_mask is not None:
-            if self.multi_query:
-                head_mask = head_mask.transpose(1, 2)
-            attn_weights = attn_weights * head_mask

-        if self.multi_query:
-            attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
-        else:
-            attn_output = torch.matmul(attn_weights, value)
+        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)

         return attn_output, attn_weights

-    def _attn_flash(self, query, key, value, attention_mask, head_mask=None):
-        if head_mask is not None:
-            raise NotImplementedError("Head mask is not supported with flash attention.")
+    def _attn_flash(self, query, key, value, attention_mask):

         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
-        if self.multi_query:
-            key = key.unsqueeze(1).expand(attn_shape)
-            value = value.unsqueeze(1).expand(attn_shape)
-        else:
-            key = key.view(attn_shape)
-            value = value.view(attn_shape)
+        key = key.unsqueeze(1).expand(attn_shape)
+        value = value.unsqueeze(1).expand(attn_shape)

         sequence_lengths, padding_index, _, max_sequence_length = attention_mask
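In `_attn_flash`, the shared key and value are broadcast across heads with `unsqueeze(1).expand(...)`, which produces views rather than copies. A small sketch demonstrating that no new memory is allocated (shapes assumed):

import torch

tokens, heads, head_dim = 10, 16, 64
key = torch.randn(tokens, head_dim)
key_expanded = key.unsqueeze(1).expand(tokens, heads, head_dim)
assert key_expanded.data_ptr() == key.data_ptr()  # a broadcast view, not a copy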
@@ -285,20 +250,11 @@ class GPTBigCodeAttention(nn.Module):
     def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length):
         batch_size = kv_cache.size(-1)
         assert not self.training
-        if self.multi_query:
-            allocated_kv_cache = torch.empty(
-                [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
-            )
-            allocated_kv_cache[:, :key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-        else:
-            allocated_kv_cache = torch.empty(
-                [batch_size, self.num_heads, allocate_key_length, self.head_dim],
-                dtype=kv_cache.dtype,
-                device=kv_cache.device,
-            )
-            allocated_kv_cache[:, :, :key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+        allocated_kv_cache = torch.empty(
+            [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
+        )
+        allocated_kv_cache[:, :key_length].copy_(kv_cache)
+        padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
         return allocated_kv_cache, padded_kv_cache

     def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask):
@@ -308,16 +264,12 @@ class GPTBigCodeAttention(nn.Module):
         if flash_attention and use_cache:
             _, padding_index, batch_size, max_sequence_length = attention_mask
             current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
-            if not self.multi_query:
-                current_kv_cache = current_kv_cache.view(
-                    batch_size, max_sequence_length, self.num_heads, 2 * self.head_dim
-                ).transpose(1, 2)
         else:
             current_kv_cache = key_value

         # Calculate dimensions and recover layer_past
         batch_size = current_kv_cache.size(0)
-        query_length = current_kv_cache.size(self.seq_dim)
+        query_length = current_kv_cache.size(-2)

         if layer_past is None:
             allocated_kv_cache, last_key_length = None, 0
             last_kv_cache = None
@@ -325,50 +277,31 @@ class GPTBigCodeAttention(nn.Module):
             allocated_key_length = key_length
         else:
             allocated_kv_cache, last_key_length = layer_past
-            last_kv_cache = (
-                allocated_kv_cache[:, :last_key_length]
-                if self.multi_query
-                else allocated_kv_cache[:, :, :last_key_length]
-            )
+            last_kv_cache = allocated_kv_cache[:, :last_key_length]
             key_length = query_length + last_key_length
-            allocated_key_length = allocated_kv_cache.size(self.seq_dim)
+            allocated_key_length = allocated_kv_cache.size(-2)

         padded_key_length = key_length if flash_attention else attention_mask.size(-1)
         allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length)

         # Re-allocate kv cache and copy last value
         if allocate_key_length > allocated_key_length:
-            if self.multi_query:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    # Nans in `value` can propagate through the matrix multiplication,
-                    # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
-                    allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
-            else:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, self.num_heads, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    allocated_kv_cache[:, :, key_length:, self.head_dim :].zero_()
+            allocated_kv_cache = torch.empty(
+                [batch_size, allocate_key_length, 2 * self.head_dim],
+                dtype=current_kv_cache.dtype,
+                device=current_kv_cache.device,
+            )
+            if layer_past is not None:
+                allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
+            if allocate_key_length > key_length:
+                # Nans in `value` can propagate through the matrix multiplication,
+                # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
+                allocated_kv_cache[:, key_length:, self.head_dim :].zero_()

         # Copy the new values.
         if allocate_key_length > allocated_key_length or layer_past is not None:
-            if self.multi_query:
-                allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-            else:
-                allocated_kv_cache[:, :, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+            allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
+            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]

         if not flash_attention:
             # Use the merged KV cache.
             # Not needed when layer_past is None but frees some memory.
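The NaN comment above explains the `zero_()` call: `torch.empty` returns uninitialized memory, and a NaN left in an unused value slot would still propagate through the attention matmul. A minimal sketch of the pre-allocation pattern, with assumed sizes:

import torch

batch, old_len, new_len, head_dim = 2, 3, 8, 64
old_cache = torch.randn(batch, old_len, 2 * head_dim)  # packed K|V per position
cache = torch.empty(batch, new_len, 2 * head_dim)      # uninitialized memory
cache[:, :old_len].copy_(old_cache)                    # keep the old entries
cache[:, old_len:, head_dim:].zero_()                  # zero only the value half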
@@ -387,7 +320,6 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Union[
@@ -395,18 +327,7 @@ class GPTBigCodeAttention(nn.Module):
         Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
     ]:
         flash_attention = self.flash_attention and layer_past is None
-        if self.multi_query or flash_attention:
-            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1)
-        else:
-            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
-            # i.e., the memory layout is not the same as GPT2.
-            # This makes the concatenation with past_key_value more efficient.
-            query, key_value = (
-                self.c_attn(hidden_states)
-                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
-                .transpose(1, 2)
-                .split((self.head_dim, 2 * self.head_dim), dim=3)
-            )
+        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)

         if self._tuple_cache_format:
             # present = (allocated_kv_cache, key_length)
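The specialized `c_attn` output is split once into the query and a packed key/value tensor; the key/value halves are separated later, just before attention. A self-contained sketch of the two splits (dimensions illustrative):

import torch

batch, seq, embed_dim, head_dim = 2, 5, 2048, 128
fused = torch.randn(batch, seq, embed_dim + 2 * head_dim)       # c_attn output
query, key_value = fused.split((embed_dim, 2 * head_dim), dim=-1)
key, value = key_value.split((head_dim, head_dim), dim=-1)      # each (batch, seq, head_dim)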
@@ -420,11 +341,9 @@ class GPTBigCodeAttention(nn.Module):
             key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

         attn_output, attn_weights = (self._attn_flash if flash_attention else self._attn)(
-            query, key, value, attention_mask, head_mask
+            query, key, value, attention_mask
         )

-        if not self.multi_query:
-            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -432,9 +351,8 @@ class GPTBigCodeAttention(nn.Module):
         if output_attentions:
             if flash_attention:
                 raise ValueError("`output_attentions` is not supported with Flash Attention.")
-            if self.multi_query:
-                # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
-                attn_weights = attn_weights.transpose(1, 2)
+            # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+            attn_weights = attn_weights.transpose(1, 2)
             outputs += (attn_weights,)

         return outputs  # a, present, (attentions)
@@ -478,7 +396,6 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: Optional[Tuple[torch.Tensor]],
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
@@ -495,7 +412,6 @@ class GPTBigCodeBlock(nn.Module):
             hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
-            head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
@@ -570,7 +486,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-        self.multi_query = config.multi_query
+        assert config.multi_query
         self.embed_dim = config.hidden_size

         if config.add_cross_attention:
@@ -624,7 +540,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # MQA models: (batch_size, query_length, n_heads, key_length)
-        # MHA models: (batch_size, n_heads, query_length, key_length)
-        return attention_mask.unsqueeze(2 if self.multi_query else 1)
+        return attention_mask.unsqueeze(2)

     def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device):
         if position_ids is not None:
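With the MQA weight layout `(batch_size, query_length, n_heads, key_length)`, `unsqueeze(2)` inserts the singleton head axis so one mask broadcasts over all heads. A one-line sketch:

import torch

mask = torch.ones(2, 5, 7, dtype=torch.bool)  # (batch, query_length, key_length)
print(mask.unsqueeze(2).shape)                # torch.Size([2, 5, 1, 7])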
@@ -646,7 +562,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -662,7 +577,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
            assert attention_mask is not None
            assert token_type_ids is None
            assert position_ids is not None
-           assert head_mask is None
            assert inputs_embeds is None
            assert encoder_hidden_states is None
            assert encoder_attention_mask is None
@@ -729,12 +643,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if encoder_hidden_states is not None or encoder_attention_mask is not None:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")

-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
@@ -776,14 +684,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                     hidden_states,
                     None,
                     attention_mask,
-                    head_mask[i],
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
-                    head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                 )
@@ -874,7 +780,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -898,7 +803,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
@@ -936,15 +840,3 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attentions=transformer_outputs.attentions,
             cross_attentions=transformer_outputs.cross_attentions,
         )

-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
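The removed `_reorder_cache` reordered each layer's flat cache along the beam axis with `index_select` during beam search. A minimal sketch of what that call did, with assumed shapes:

import torch

layer_past = torch.randn(4, 10, 256)    # (num_beams, key_length, 2 * head_dim)
beam_idx = torch.tensor([2, 2, 0, 1])   # beams selected at this generation step
reordered = layer_past.index_select(0, beam_idx)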