Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)

Commit: a6dd19b042 ("Specialize code")
Parent: 7c11ceba6c
@@ -120,15 +120,14 @@ class GPTBigCodeAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
         self.mask_value = None
+        assert config.multi_query
+        assert config.attention_softmax_in_fp32
+        assert config.scale_attention_softmax_in_fp32
 
-        self.multi_query = config.multi_query
-        self.seq_dim = -2 if self.multi_query else -1
         self.flash_attention = config.flash_attention
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.kv_heads = 1 if self.multi_query else self.num_heads
-        self.kv_dim = self.kv_heads * self.head_dim
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
@@ -140,10 +139,6 @@ class GPTBigCodeAttention(nn.Module):
         self.is_cross_attention = is_cross_attention
 
         self.layer_idx = layer_idx
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.scale_attention_softmax_in_fp32 = (
-            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
-        )
         self.fused_softmax = config.fused_softmax
 
         # KV caching and padding
@@ -155,7 +150,7 @@ class GPTBigCodeAttention(nn.Module):
 
         if self.is_cross_attention:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
+        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)
 
         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
@@ -168,9 +163,6 @@ class GPTBigCodeAttention(nn.Module):
                 "Flash Attention requires `flash_attn` and `einops`. "
                 "To install, run `pip install flash-attn einops`."
             )
-            if not self.multi_query:
-                # TODO: Flash Attention is implemented but not tested for MHA
-                raise ValueError("Flash Attention is not supported with multi-head attention.")
 
     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
@@ -178,41 +170,29 @@ class GPTBigCodeAttention(nn.Module):
             self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
         return self.mask_value
 
-    def _attn(self, query, key, value, attention_mask, head_mask=None):
+    def _attn(self, query, key, value, attention_mask):
         dtype = query.dtype
-        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
+        softmax_dtype = torch.float32
         upcast = dtype != softmax_dtype
 
-        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
+        unscale = self.layer_idx + 1 if upcast else 1
         scale_factor = unscale**-1
         if self.scale_attn_weights:
             scale_factor /= self.head_dim**0.5
 
-        # MQA models: (batch_size, query_length, num_heads * head_dim)
-        # MHA models: (batch_size, num_heads, query_length, head_dim)
+        # (batch_size, query_length, num_heads * head_dim)
         query_shape = query.shape
         batch_size = query_shape[0]
         key_length = key.size(-2)
 
         key = key.transpose(-1, -2)
-        if self.multi_query:
-            # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
-            # -> (batch_size, query_length, num_heads, key_length)
-            query_length = query_shape[1]
-            attn_shape = (batch_size, query_length, self.num_heads, key_length)
-            attn_view = (batch_size, query_length * self.num_heads, key_length)
-            # No copy needed for MQA 2, or when layer_past is provided.
-            query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
-        else:
-            # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
-            # -> (batch_size, num_heads, query_length, key_length)
-            query_length = query_shape[2]
-            attn_shape = (batch_size, self.num_heads, query_length, key_length)
-            attn_view = (batch_size * self.num_heads, query_length, key_length)
-            # Always copies
-            query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
-            # No copy when layer_past is provided.
-            key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
+        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
+        # -> (batch_size, query_length, num_heads, key_length)
+        query_length = query_shape[1]
+        attn_shape = (batch_size, query_length, self.num_heads, key_length)
+        attn_view = (batch_size, query_length * self.num_heads, key_length)
+        # No copy needed for MQA 2, or when layer_past is provided.
+        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
 
         attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
        if query.device.type == "cpu":
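The `_attn` hunk above keeps only the multi-query layout, where every query head shares one key/value head: the query is flattened to (batch, query_length * num_heads, head_dim) so a single batched matmul against the shared key produces the scores for all heads at once. A minimal standalone sketch of that layout (shapes, names, and the plain `torch.bmm` call are illustrative, not taken from the repository):

```python
import torch

def mqa_scores(query, key, scale):
    # query: (batch, query_length, num_heads * head_dim) -- MQA layout
    # key:   (batch, key_length, head_dim)               -- single shared key head
    batch, query_length, _ = query.shape
    key_length, head_dim = key.shape[1], key.shape[2]
    num_heads = query.shape[-1] // head_dim

    # Flatten the heads into the row dimension so one bmm covers every head:
    # (batch, query_length * num_heads, head_dim) x (batch, head_dim, key_length)
    q = query.reshape(batch, query_length * num_heads, head_dim)
    scores = torch.bmm(q, key.transpose(-1, -2)) * scale

    # Back to (batch, query_length, num_heads, key_length) for masking/softmax.
    return scores.view(batch, query_length, num_heads, key_length)

# Example: batch=2, 4 query tokens, 6 key tokens, 8 heads of size 16.
q = torch.randn(2, 4, 8 * 16)
k = torch.randn(2, 6, 16)
print(mqa_scores(q, k, 16**-0.5).shape)  # torch.Size([2, 4, 8, 6])
```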
@@ -237,32 +217,17 @@ class GPTBigCodeAttention(nn.Module):
 
         attn_weights = self.attn_dropout(attn_weights)
 
-        # Mask heads if we want to
-        if head_mask is not None:
-            if self.multi_query:
-                head_mask = head_mask.transpose(1, 2)
-            attn_weights = attn_weights * head_mask
-
-        if self.multi_query:
-            attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
-        else:
-            attn_output = torch.matmul(attn_weights, value)
+        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
 
         return attn_output, attn_weights
 
-    def _attn_flash(self, query, key, value, attention_mask, head_mask=None):
-        if head_mask is not None:
-            raise NotImplementedError("Head mask is not supported with flash attention.")
+    def _attn_flash(self, query, key, value, attention_mask):
 
         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
-        if self.multi_query:
-            key = key.unsqueeze(1).expand(attn_shape)
-            value = value.unsqueeze(1).expand(attn_shape)
-        else:
-            key = key.view(attn_shape)
-            value = value.view(attn_shape)
+        key = key.unsqueeze(1).expand(attn_shape)
+        value = value.unsqueeze(1).expand(attn_shape)
 
         sequence_lengths, padding_index, _, max_sequence_length = attention_mask
 
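In the flash-attention path above, the single shared key/value head is broadcast to every query head with `unsqueeze(1).expand(...)`, which returns a stride-0 view rather than copying memory. A small self-contained sketch of that broadcast (the shapes are made up for the example):

```python
import torch

total_tokens, num_heads, head_dim = 10, 8, 16

# One query per head, but a single shared key head (multi-query attention).
query = torch.randn(total_tokens, num_heads, head_dim)
key = torch.randn(total_tokens, head_dim)

# unsqueeze + expand broadcasts the shared head across all query heads
# without allocating new memory: the expanded dimension has stride 0.
key_expanded = key.unsqueeze(1).expand(total_tokens, num_heads, head_dim)

print(key_expanded.shape)                         # torch.Size([10, 8, 16])
print(key_expanded.stride())                      # (16, 0, 1)
print(key.data_ptr() == key_expanded.data_ptr())  # True: same underlying storage
```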
@@ -285,20 +250,11 @@ class GPTBigCodeAttention(nn.Module):
     def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length):
         batch_size = kv_cache.size(-1)
         assert not self.training
-        if self.multi_query:
-            allocated_kv_cache = torch.empty(
-                [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
-            )
-            allocated_kv_cache[:, :key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-        else:
-            allocated_kv_cache = torch.empty(
-                [batch_size, self.num_heads, allocate_key_length, self.head_dim],
-                dtype=kv_cache.dtype,
-                device=kv_cache.device,
-            )
-            allocated_kv_cache[:, :, key_length].copy_(kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+        allocated_kv_cache = torch.empty(
+            [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
+        )
+        allocated_kv_cache[:, :key_length].copy_(kv_cache)
+        padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
         return allocated_kv_cache, padded_kv_cache
 
     def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask):
@@ -308,16 +264,12 @@ class GPTBigCodeAttention(nn.Module):
         if flash_attention and use_cache:
             _, padding_index, batch_size, max_sequence_length = attention_mask
             current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
-            if not self.multi_query:
-                current_kv_cache = current_kv_cache.view(
-                    batch_size, max_sequence_length, self.num_heads, 2 * self.head_dim
-                ).transpose(1, 2)
         else:
             current_kv_cache = key_value
 
         # Calculate dimensions and recover layer_past
         batch_size = current_kv_cache.size(0)
-        query_length = current_kv_cache.size(self.seq_dim)
+        query_length = current_kv_cache.size(-2)
         if layer_past is None:
             allocated_kv_cache, last_key_length = None, 0
             last_kv_cache = None
@@ -325,50 +277,31 @@ class GPTBigCodeAttention(nn.Module):
             allocated_key_length = key_length
         else:
             allocated_kv_cache, last_key_length = layer_past
-            last_kv_cache = (
-                allocated_kv_cache[:, :last_key_length]
-                if self.multi_query
-                else allocated_kv_cache[:, :, :last_key_length]
-            )
+            last_kv_cache = allocated_kv_cache[:, :last_key_length]
             key_length = query_length + last_key_length
-            allocated_key_length = allocated_kv_cache.size(self.seq_dim)
+            allocated_key_length = allocated_kv_cache.size(-2)
 
         padded_key_length = key_length if flash_attention else attention_mask.size(-1)
         allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length)
 
         # Re-allocate kv cache and copy last value
         if allocate_key_length > allocated_key_length:
-            if self.multi_query:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    # Nans in `value` can propagate through the matrix multiplication,
-                    # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
-                    allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
-            else:
-                allocated_kv_cache = torch.empty(
-                    [batch_size, self.num_heads, allocate_key_length, 2 * self.head_dim],
-                    dtype=current_kv_cache.dtype,
-                    device=current_kv_cache.device,
-                )
-                if layer_past is not None:
-                    allocated_kv_cache[:, :, :last_key_length].copy_(last_kv_cache)
-                if allocate_key_length > key_length:
-                    allocated_kv_cache[:, :, key_length:, self.head_dim :].zero_()
+            allocated_kv_cache = torch.empty(
+                [batch_size, allocate_key_length, 2 * self.head_dim],
+                dtype=current_kv_cache.dtype,
+                device=current_kv_cache.device,
+            )
+            if layer_past is not None:
+                allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
+            if allocate_key_length > key_length:
+                # Nans in `value` can propagate through the matrix multiplication,
+                # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
+                allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
 
         # Copy the new values.
         if allocate_key_length > allocated_key_length or layer_past is not None:
-            if self.multi_query:
-                allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-            else:
-                allocated_kv_cache[:, :, last_key_length:key_length].copy_(current_kv_cache)
-                padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
+            allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
+            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
         if not flash_attention:
             # Use the merged KV cache.
             # Not needed when layer_past is None but frees some memory.
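The merge logic above keeps one pre-allocated (batch, key_length, 2 * head_dim) buffer per layer, re-allocates it with head-room when it is too small, and copies the previous and current key/value slices into place. A simplified sketch of that grow-and-copy pattern under the same MQA-only assumption (the helper and its parameters are hypothetical, not the repository's API):

```python
import torch

def append_to_kv_cache(allocated, used_length, new_kv, pre_allocate=0):
    # allocated: (batch, capacity, 2 * head_dim) buffer, or None on the first step
    # new_kv:    (batch, new_length, 2 * head_dim) key/value pairs for the new tokens
    batch, new_length, kv_width = new_kv.shape
    needed = used_length + new_length
    capacity = 0 if allocated is None else allocated.size(1)

    if needed > capacity:
        # Re-allocate with head-room and copy the tokens cached so far.
        grown = torch.empty(
            batch, max(needed, pre_allocate), kv_width, dtype=new_kv.dtype, device=new_kv.device
        )
        if allocated is not None:
            grown[:, :used_length].copy_(allocated[:, :used_length])
        allocated = grown

    # Write the new tokens right after the existing ones.
    allocated[:, used_length:needed].copy_(new_kv)
    return allocated, needed

cache, length = None, 0
cache, length = append_to_kv_cache(cache, length, torch.randn(2, 5, 32), pre_allocate=16)
cache, length = append_to_kv_cache(cache, length, torch.randn(2, 1, 32))
print(cache.shape, length)  # torch.Size([2, 16, 32]) 6
```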
@@ -387,7 +320,6 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Union[
@@ -395,18 +327,7 @@ class GPTBigCodeAttention(nn.Module):
         Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
     ]:
         flash_attention = self.flash_attention and layer_past is None
-        if self.multi_query or flash_attention:
-            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1)
-        else:
-            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
-            # i.e., the memory layout is not the same as GPT2.
-            # This makes the concatenation with past_key_value more efficient.
-            query, key_value = (
-                self.c_attn(hidden_states)
-                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
-                .transpose(1, 2)
-                .split((self.head_dim, 2 * self.head_dim), dim=3)
-            )
+        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
 
         if self._tuple_cache_format:
             # present = (allocated_kv_cache, key_length)
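With multi-query attention hard-wired, the fused `c_attn` projection in the hunk above always yields `embed_dim` query channels plus one shared `2 * head_dim` key/value slice, so a single `split` separates them. A toy version of that split (dimensions invented for the example):

```python
import torch
from torch import nn

embed_dim, num_heads = 128, 8
head_dim = embed_dim // num_heads

# One linear layer emits queries for every head plus a single shared key/value head.
c_attn = nn.Linear(embed_dim, embed_dim + 2 * head_dim)

hidden_states = torch.randn(2, 4, embed_dim)  # (batch, sequence, embed_dim)
query, key_value = c_attn(hidden_states).split((embed_dim, 2 * head_dim), dim=-1)
key, value = key_value.split((head_dim, head_dim), dim=-1)

print(query.shape, key.shape, value.shape)
# torch.Size([2, 4, 128]) torch.Size([2, 4, 16]) torch.Size([2, 4, 16])
```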
@@ -420,11 +341,9 @@ class GPTBigCodeAttention(nn.Module):
         key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
 
         attn_output, attn_weights = (self._attn_flash if flash_attention else self._attn)(
-            query, key, value, attention_mask, head_mask
+            query, key, value, attention_mask
         )
 
-        if not self.multi_query:
-            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
@@ -432,9 +351,8 @@ class GPTBigCodeAttention(nn.Module):
         if output_attentions:
             if flash_attention:
                 raise ValueError("`output_attentions` is not supported with Flash Attention.")
-            if self.multi_query:
-                # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
-                attn_weights = attn_weights.transpose(1, 2)
+            # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+            attn_weights = attn_weights.transpose(1, 2)
             outputs += (attn_weights,)
 
         return outputs  # a, present, (attentions)
@@ -478,7 +396,6 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: Optional[Tuple[torch.Tensor]],
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
@@ -495,7 +412,6 @@ class GPTBigCodeBlock(nn.Module):
             hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
-            head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
@@ -570,7 +486,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.multi_query = config.multi_query
+        assert config.multi_query
         self.embed_dim = config.hidden_size
 
         if config.add_cross_attention:
@@ -624,7 +540,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
 
         # MQA models: (batch_size, query_length, n_heads, key_length)
         # MHA models: (batch_size, n_heads, query_length, key_length)
-        return attention_mask.unsqueeze(2 if self.multi_query else 1)
+        return attention_mask.unsqueeze(2)
 
     def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device):
         if position_ids is not None:
@@ -646,7 +562,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -662,7 +577,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         assert attention_mask is not None
         assert token_type_ids is None
         assert position_ids is not None
-        assert head_mask is None
         assert inputs_embeds is None
         assert encoder_hidden_states is None
         assert encoder_attention_mask is None
@@ -729,12 +643,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if encoder_hidden_states is not None or encoder_attention_mask is not None:
             raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
 
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
@@ -776,14 +684,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                     hidden_states,
                     None,
                     attention_mask,
-                    head_mask[i],
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
-                    head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                 )
@@ -874,7 +780,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -898,7 +803,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
            position_ids=position_ids,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
@@ -936,15 +840,3 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             attentions=transformer_outputs.attentions,
             cross_attentions=transformer_outputs.cross_attentions,
         )
-
-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
-
|