Fixes, cpu-optimized model, misc

Joel Lamy-Poirier 2023-05-26 15:22:20 -04:00
parent 72eefa3612
commit 7a94845eba
No known key found for this signature in database
GPG Key ID: 82EE2141E842DFCF
7 changed files with 1068 additions and 511 deletions

View File

@ -100,6 +100,8 @@ model_type_map={
    "vector":VectorizedCausalLM,
    "bigcode":BigcodeCausalLM,
    "bigcode2":Bigcode2CausalLM,
    "bigcode3":Bigcode2CausalLM,
    "bigcode4":Bigcode2CausalLM,
}

View File

@ -508,6 +508,63 @@ class CausalLM(Model):
        )
        return outputs.logits, outputs.past_key_values
def fast_forward(
self,
batch: CausalLMBatch,
max_input_length: int,
cache_dtype: Optional[torch.dtype],
):
diff = max_input_length - batch.max_input_length
batch.max_input_length+=diff
fill_value=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
for i in range(len(batch)):
batch.all_input_ids[i]=torch.nn.functional.pad(batch.all_input_ids[i], (0, 0, 0, diff), value=fill_value)[:batch.max_input_length]
batch.input_lengths[i]+=diff
batch.prefix_offsets[i] = 0
batch.read_offsets[i] = 0
batch.attention_mask[:, -batch.padding_right_offset:batch.padding_right_offset+diff] = 1
batch.padding_right_offset -= diff
if cache_dtype is None:
assert batch.input_ids.shape==(len(batch),batch.max_input_length-diff)
batch.input_ids=torch.nn.functional.pad(batch.input_ids, (0,diff), value=fill_value)[:, :batch.max_input_length]
batch.position_ids = batch.attention_mask[:, :batch.max_input_length].long().cumsum(-1) - 1
batch.past_key_values=None
else:
from transformers import GPTBigCodeForCausalLM
batch.input_ids=batch.input_ids[:,:1].fill_(fill_value)
batch.position_ids = batch.position_ids[:, -1:] + diff
if isinstance(self.model, GPTBigCodeForCausalLM):
batch.past_key_values=[
torch.randn(
[
len(batch),
batch.max_input_length - 1,
2 * self.model.config.n_embd // self.model.config.n_head,
],
dtype=cache_dtype,
device=batch.input_ids.device,
)
for _ in range(self.model.config.n_layer)
]
else:
batch.past_key_values=[
[torch.randn(
[
len(batch),
batch.max_input_length - 1,
self.model.config.n_embd // self.model.config.n_head,
],
dtype=cache_dtype,
device=batch.input_ids.device,
)
for _ in range(2)] for _ in range(self.model.config.n_layer)
]
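The fast_forward above skips real prefill so that decode speed at a given sequence length can be measured directly: inputs are padded with the pad/eos token and, when a cache_dtype is given, the KV cache is filled with random values of the right shape. A minimal, self-contained sketch of that cache-faking idea; the sizes and names below are illustrative, not the server's API:

import torch

# Illustrative dimensions only; the real values come from the model config.
batch_size, n_layer, head_dim = 4, 24, 64
target_length = 512  # sequence length to "fast forward" to

# One merged key/value tensor per layer, matching the MQA cache layout used above:
# (batch, key_length, 2 * head_dim) holds keys and values side by side.
fake_past = [
    torch.randn(batch_size, target_length - 1, 2 * head_dim)
    for _ in range(n_layer)
]

# A single dummy token per sequence then drives the next decode step exactly as if
# target_length - 1 real tokens had already been processed.
input_ids = torch.zeros(batch_size, 1, dtype=torch.long)
position_ids = torch.full((batch_size, 1), target_length - 1, dtype=torch.long)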
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
@ -515,6 +572,9 @@ class CausalLM(Model):
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
        #print("AAA", batch.max_input_length, batch.input_ids, batch.position_ids, batch.all_input_ids, batch.all_input_ids[0].shape)
        #print("BBB", batch.max_input_length, batch.input_lengths, batch.padding_right_offset, batch.past_key_values is None)
        logits, past = self.forward(
            batch.input_ids,
            attention_mask,

View File

@ -146,9 +146,10 @@ class GPTBigCodeAttention(nn.Module):
        hidden_states: torch.Tensor,
        layer_past: torch.Tensor,
        attention_mask: torch.Tensor,
        batch_size: int,
        key_length: int,
    ) -> Tuple[torch.Tensor, Any]:
        batch_size = hidden_states.size(0)
        query, key_value = self.c_attn.forward(hidden_states).split(
            (self.embed_dim, 2 * self.head_dim), dim=-1
        )
@ -183,7 +184,6 @@ class GPTBigCodeAttention(nn.Module):
        unscale = self.layer_idx + 1 if upcast else 1
        scale_factor = unscale**-1 / self.head_dim**0.5
        # TODO: No need to unsqueeze?
        hidden_states = torch.baddbmm(
            torch.empty(
                (batch_size, self.num_heads, padded_key_length),
@ -194,7 +194,7 @@ class GPTBigCodeAttention(nn.Module):
            key.transpose(-1, -2),
            beta=0,
            alpha=scale_factor,
        ).unsqueeze_(1)
        )
        if upcast:
            hidden_states = upcast_masked_softmax(
@ -205,7 +205,7 @@ class GPTBigCodeAttention(nn.Module):
                hidden_states, attention_mask, self.mask_value
            )
        hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)
        hidden_states = torch.bmm(hidden_states, value).view(query.shape)
        hidden_states = self.c_proj.forward(hidden_states)
@ -265,7 +265,6 @@ class GPTBigCodeBlock(nn.Module):
        residual: Optional[torch.Tensor],
        layer_past: torch.Tensor,
        attention_mask: torch.Tensor,
        batch_size: int,
        key_length: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, Any]:
        hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
@ -273,7 +272,6 @@ class GPTBigCodeBlock(nn.Module):
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            batch_size=batch_size,
            key_length=key_length,
        )
        hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
@ -370,12 +368,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
        )
        # Causal mask
        self.causal_mask = torch.ones(
            (config.max_position_embeddings, config.max_position_embeddings),
            dtype=torch.bool,
            device=device,
        ).tril_()
        self.mask_value = torch.full(
            (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device
        )
@ -459,9 +451,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
        position_ids: torch.Tensor,
        key_length: int,
    ) -> Tuple:
        batch_size, query_length = input_ids.shape
        assert query_length == 1
        hidden_states = self.transformer.wte.forward(
            input_ids
        ) + self.transformer.wpe.forward(position_ids)
@ -469,13 +458,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
        # Standardize shape to (batch_size, hidden_size)
        hidden_states.squeeze_(1)
        # Self-attention mask (padding + causal).
        # TODO: Avoid unsqueeze
        attention_mask = self.causal_mask[
            None, key_length - 1 : key_length, : attention_mask.size(-1)
        ] * attention_mask.unsqueeze(1)
        attention_mask.unsqueeze_(2)
        residual = None
        block: GPTBigCodeBlock
        for i, (block, layer_past) in enumerate(
@ -486,11 +468,10 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
            residual=residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            batch_size=batch_size,
            key_length=key_length,
        )
        hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
        hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
        hidden_states = self.lm_head.forward(hidden_states)
        return hidden_states, past_key_values
        return hidden_states.unsqueeze_(1), past_key_values

View File

@ -13,12 +13,15 @@
# limitations under the License.
"""PyTorch GPTBigCode model."""
import math
from typing import Optional, Tuple, Any, Union, List
from typing import Optional, Tuple, Any, List
from enum import IntEnum
import torch
from torch import where, addmm, mm, float32, dtype, Tensor, baddbmm, empty, device, bmm, full, ones, finfo, jit
import torch.utils.checkpoint
from torch.nn import Linear, Embedding, Module, LayerNorm, ModuleList
from torch import nn
from torch.nn.functional import gelu, softmax, embedding
from dropout_layer_norm import dropout_add_ln_fwd
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
@ -27,357 +30,307 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
)
class InferenceRunnerType(IntEnum):
NO_RUNNER = 0
# Use the inference runner without cuda graphs.
BASE_RUNNER = 1
# Use cuda graphs in the inference runner. Leave out the attention which has a variable shape.
# This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models.
PARTIAL_GRAPH = 2
# Turn the whole model into a cuda graph. One graph for each sequence length.
# Note: only useful for small batches and models, graphs take some time to generate, flaky.
# Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
FULL_GRAPH = 3
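The PARTIAL_GRAPH / FULL_GRAPH comments above refer to CUDA graph capture; the actual runner (GPTBigCodeInferenceRunner, imported further down) is not shown in this diff. A minimal sketch of the capture/replay pattern with a toy module, assuming a CUDA device is available; the module and shapes are placeholders:

import torch

@torch.no_grad()
def capture_example():
    model = torch.nn.Linear(256, 256).cuda().eval()
    static_input = torch.zeros(8, 256, device="cuda")

    # Warm up on a side stream before capture (required by CUDA graphs).
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(3):
            static_output = model(static_input)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture one forward pass; tensor addresses are frozen inside the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Replay with new data: copy into the static buffer, then replay the graph.
    static_input.copy_(torch.randn(8, 256, device="cuda"))
    graph.replay()
    return static_output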
try:
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None
pad_input = None
unpad_input = None
logger = logging.get_logger(__name__)

@torch.jit.script
@jit.script
def upcast_masked_softmax(
    x: torch.Tensor,
    mask: torch.Tensor,
    mask_value: torch.Tensor,
    scale: float,
    softmax_dtype: torch.dtype,
    x: Tensor, mask: Tensor, mask_value: Tensor, scale: float
):
    input_dtype = x.dtype
    x = x.to(softmax_dtype) * scale
    x = x.to(float32) * scale
    x = torch.where(mask, x, mask_value)
    x = where(mask, x, mask_value)
    x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
    x = softmax(x, dim=-1).to(input_dtype)
    return x
@torch.jit.script class GPTBigCodeAttention(Module):
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): mask_value: Tensor
input_dtype = x.dtype
x = x.to(softmax_dtype) * scale
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
return x
def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: dtype):
@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1)
return x
@torch.profiler.record_function("softmax_function")
def softmax_function(
x: torch.Tensor,
mask: torch.Tensor,
mask_value: torch.Tensor,
scale: float,
softmax_dtype: torch.dtype,
upcast: bool = True,
):
"""
This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case
needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting
and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely
inefficient.
"""
# assert x.size(-1) % 8 == 0
if upcast:
if mask is None:
return upcast_softmax(x, scale, softmax_dtype)
else:
return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
else:
if mask is None:
return torch.nn.functional.softmax(x, dim=-1)
else:
return masked_softmax(x, mask, mask_value)
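The docstring above is why key lengths get padded to a multiple of 8 elsewhere in this commit (`pad = -key_length % 8`, `pad_key_length_to_multiple`). A small, self-contained illustration of that rounding and of masking the padded columns out; the sizes are arbitrary:

import torch

key_length = 27
padded_key_length = key_length + -key_length % 8   # rounds 27 up to 32

scores = torch.randn(2, 4, key_length)
# Pad the key dimension, then mask the padded columns out before the softmax,
# so they contribute nothing while the kernel sees a multiple-of-8 width.
scores = torch.nn.functional.pad(scores, (0, padded_key_length - key_length))
mask = torch.zeros(2, 4, padded_key_length, dtype=torch.bool)
mask[..., :key_length] = True
mask_value = torch.full((), torch.finfo(scores.dtype).min)
probs = torch.softmax(torch.where(mask, scores, mask_value), dim=-1)
assert probs[..., key_length:].abs().max() == 0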
class GPTBigCodeAttention(nn.Module):
def __init__(
self,
config,
layer_idx=None,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
):
super().__init__() super().__init__()
self.mask_value = None
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads self.head_dim = self.embed_dim // self.num_heads
self.layer_idx = layer_idx self.layer_idx = layer_idx
# KV caching and padding self.c_attn = Linear(
self.c_attn = nn.Linear(
self.embed_dim, self.embed_dim,
self.embed_dim + 2 * self.head_dim, self.embed_dim + 2 * self.head_dim,
dtype=dtype, dtype=dtype,
device=device, device="meta",
) )
self.c_proj = nn.Linear( self.c_proj = Linear(
self.embed_dim, self.embed_dim, dtype=dtype, device=device self.embed_dim, self.embed_dim, dtype=dtype, device="meta"
) )
@torch.profiler.record_function("GPTBigCodeAttention._get_mask_value") class GPTBigCodeMLP(Module):
def _get_mask_value(self, device, dtype): # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
# torch.where expects a tensor. We use a cache to avoid recreating it every time. def __init__(self, config: GPTBigCodeConfig, dtype: dtype):
if ( super().__init__()
self.mask_value is None embed_dim = config.hidden_size
or self.mask_value.dtype != dtype inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
or self.mask_value.device != device self.c_fc = Linear(embed_dim, inner_dim, dtype=dtype, device="meta")
): self.c_proj = Linear(inner_dim, embed_dim, dtype=dtype, device="meta")
self.mask_value = torch.full(
[], torch.finfo(dtype).min, dtype=dtype, device=device
class GPTBigCodeBlock(Module):
def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: dtype):
super().__init__()
self.ln_1 = LayerNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
dtype=dtype,
device="meta",
)
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
self.ln_2 = LayerNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
dtype=dtype,
device="meta",
)
self.mlp = GPTBigCodeMLP(config, dtype=dtype)
def post_load_weights(self, mask_value):
self.attn.mask_value = mask_value
self.mask_value = mask_value
self.hd=self.attn.head_dim
self.split0=(self.attn.embed_dim, 2*self.hd)
self.split1=(self.hd, self.hd)
self.aaw=self.attn.c_attn.weight.t_()
self.aab=self.attn.c_attn.bias
self.apw=self.attn.c_proj.weight.t_()
self.apb=self.attn.c_proj.bias
self.unscale=self.attn.layer_idx + 1
self.ps=self.hd**-0.5
self.ds=self.unscale**-1 * self.ps
self.l1w=self.ln_1.weight
self.l1b=self.ln_1.bias
self.e1=self.ln_1.eps
self.l2w=self.ln_2.weight
self.l2b=self.ln_2.bias
self.e2=self.ln_2.eps
self.mfb=self.mlp.c_fc.bias
self.mfw=self.mlp.c_fc.weight.t_()
self.mfb=self.mlp.c_fc.bias
self.mpw=self.mlp.c_proj.weight.t_()
self.mpb=self.mlp.c_proj.bias
def prefill(
self,
hidden_states: Tensor,
residual: Optional[Tensor],
sequence_lengths,
key_length: int,
) -> Tuple[Tensor, Tensor, Any]:
if residual is None: # First layer
residual = hidden_states
hidden_states, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l1w,
self.l1b,
None,
None,
None,
None,
0.0,
self.e1,
1.0,
0,
None,
False,
False,
) )
return self.mask_value
@torch.profiler.record_function("GPTBigCodeAttention._attn")
def _attn(self, query, key, value, attention_mask):
softmax_dtype = torch.float32
upcast = query.dtype != softmax_dtype
unscale = self.layer_idx + 1 if upcast else 1
scale_factor = unscale**-1 / self.head_dim**0.5
# (batch_size, query_length, num_heads * head_dim)
query_shape = query.shape
batch_size = query_shape[0]
key_length = key.size(-2)
key = key.transpose(-1, -2)
# (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
# -> (batch_size, query_length, num_heads, key_length)
query_length = query_shape[1]
attn_shape = (batch_size, query_length, self.num_heads, key_length)
attn_view = (batch_size, query_length * self.num_heads, key_length)
# No copy needed for MQA 2, or when layer_past is provided.
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
if query.device.type == "cpu":
# This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
# The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
# but the fix has not been released as of pytorch version 2.0.0.
attn_weights.zero_()
beta = 1
else: else:
beta = 0 hidden_states, residual, *_ = dropout_add_ln_fwd(
attn_weights = torch.baddbmm( hidden_states,
attn_weights, query, key, beta=beta, alpha=scale_factor residual,
).view(attn_shape) self.l1w,
self.l1b,
attn_weights = softmax_function( None,
attn_weights, None,
attention_mask, None,
None None,
if attention_mask is None 0.0,
else self._get_mask_value(attn_weights.device, softmax_dtype), self.e1,
unscale, 1.0,
softmax_dtype, 0,
upcast, None,
False,
False,
)
hidden_shape = hidden_states.shape
query, key_value = addmm(self.aab, hidden_states, self.aaw).split(self.split0, dim=-1)
query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
key, value = (
key_value.unsqueeze(1)
.expand(hidden_shape[0], self.num_heads, 2 * self.head_dim)
.split(self.split1, dim=-1)
) )
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
return attn_output
@torch.profiler.record_function("GPTBigCodeAttention._attn_flash")
def _attn_flash(self, query, key, value, flash_params):
query_shape = query.shape
attn_shape = query_shape[0], self.num_heads, self.head_dim
query = query.view(attn_shape)
key = key.unsqueeze(1).expand(attn_shape)
value = value.unsqueeze(1).expand(attn_shape)
sequence_lengths, padding_index, _, max_sequence_length = flash_params
# attn_output: (sum_seq_len, num_heads * head_dim) # attn_output: (sum_seq_len, num_heads * head_dim)
attn_output = flash_attn_unpadded_func( hidden_states = flash_attn_unpadded_func(
query, query,
key, key,
value, value,
sequence_lengths, sequence_lengths,
sequence_lengths, sequence_lengths,
max_sequence_length, key_length,
max_sequence_length, key_length,
0.0, 0.0,
softmax_scale=self.head_dim**-0.5, softmax_scale=self.ps,
causal=True, causal=True,
).view(query_shape) ).view(hidden_shape)
hidden_states = addmm(self.apb, hidden_states, self.apw, out=query)
return attn_output hidden_states, residual, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l2w,
self.l2b,
None,
None,
None,
None,
0.0,
self.e2,
1.0,
0,
None,
False,
False,
)
# TODO: Find an inplace and/or fused (with addmm) gelu kernel?
hidden_states = addmm(self.mpb, gelu(addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states)
return hidden_states, residual, key_value
@torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches") def decode(
def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params): self,
# Convert to standard KV cache format. hidden_states: Tensor,
if flash_params is not None: residual: Optional[Tensor],
_, padding_index, batch_size, max_sequence_length = flash_params layer_past: Tensor,
current_kv_cache = pad_input( attention_mask: Tensor,
key_value, padding_index, batch_size, max_sequence_length key_length: int,
) -> Tuple[Tensor, Tensor, Any]:
batch_size=hidden_states.size(0)
if residual is None: # First layer
residual = hidden_states
hidden_states, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l1w,
self.l1b,
None,
None,
None,
None,
0.0,
self.e1,
1.0,
0,
None,
False,
False,
)
else:
hidden_states, residual, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l1w,
self.l1b,
None,
None,
None,
None,
0.0,
self.e1,
1.0,
0,
None,
False,
False,
) )
return key_value, (current_kv_cache, max_sequence_length)
current_kv_cache = key_value query, key_value = addmm(self.aab, hidden_states, self.aaw).split(self.split0, dim=-1)
query_view = query.view(batch_size, self.num_heads, self.head_dim)
# Calculate dimensions and recover layer_past # Calculate dimensions and recover layer_past
batch_size = current_kv_cache.size(0)
query_length = current_kv_cache.size(-2)
if layer_past is None:
allocated_kv_cache, last_key_length = None, 0
last_kv_cache = None
key_length = query_length
allocated_key_length = key_length
else:
allocated_kv_cache, last_key_length = layer_past
last_kv_cache = allocated_kv_cache[:, :last_key_length]
key_length = query_length + last_key_length
allocated_key_length = allocated_kv_cache.size(-2)
padded_key_length = attention_mask.size(-1) padded_key_length = attention_mask.size(-1)
allocated_key_length = layer_past.size(-2)
# Re-allocate kv cache and copy last value # TODO: Allow pre-allocation with size > padded_key_length
if padded_key_length > allocated_key_length: if padded_key_length > allocated_key_length:
allocated_kv_cache = torch.empty( # Re-allocate kv cache and copy last value
[batch_size, padded_key_length, 2 * self.head_dim], allocated_kv_cache = empty(
dtype=current_kv_cache.dtype, [batch_size, padded_key_length, 2*self.hd],
device=current_kv_cache.device, dtype=key_value.dtype,
device=key_value.device,
) )
if layer_past is not None: allocated_kv_cache[:, : key_length - 1].copy_(
allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache) layer_past[:, : key_length - 1]
if padded_key_length > key_length: )
# Nans in `value` can propagate through the matrix multiplication, # Nans in `value` can propagate through the matrix multiplication,
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.) # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
allocated_kv_cache[:, key_length:, self.head_dim :].zero_() allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
layer_past = allocated_kv_cache
# Copy the new values. # Copy the new values.
if padded_key_length > allocated_key_length or layer_past is not None: layer_past[:, key_length - 1].copy_(key_value)
allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
# Use the merged KV cache.
# Not needed when layer_past is None but frees some memory.
key_value = padded_kv_cache
if allocated_kv_cache is None: key, value = layer_past.split(self.split1, dim=-1)
allocated_kv_cache = current_kv_cache
present = allocated_kv_cache, key_length
return key_value, present
@torch.profiler.record_function("GPTBigCodeAttention.forward") # Assume we always upcast (optimized for fp16/bf16)
def forward( # TODO: Upcasting needed for bf16?
self, hidden_states = baddbmm(
hidden_states: torch.Tensor, empty(
layer_past: Optional[torch.Tensor] = None, (batch_size, self.num_heads, padded_key_length),
attention_mask: Optional[torch.Tensor] = None, device=query.device,
flash_params: Optional[Tuple] = None, dtype=query.dtype,
) -> Tuple[torch.Tensor, Any]: ),
query, key_value = self.c_attn(hidden_states).split( query_view,
(self.embed_dim, 2 * self.head_dim), dim=-1 key.transpose(-1, -2),
beta=0,
alpha=self.ds,
)
hidden_states = upcast_masked_softmax(
hidden_states, attention_mask, self.mask_value, self.unscale
) )
# present = (allocated_kv_cache, key_length) # TODO: Write attn output directly into query, avoids both allocation and view.
key_value, present = self._merge_kv_caches( bmm(hidden_states.squeeze_(1), value, out=query_view)
key_value, layer_past, attention_mask, flash_params # TODO: Reuse attn weight tensor for c_proj output?
hidden_states = addmm(self.apb, query, self.apw)
hidden_states, residual, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l2w,
self.l2b,
None,
None,
None,
None,
0.0,
self.e2,
1.0,
0,
None,
False,
False,
) )
# TODO: Reuse attn weight tensor for c_fc output? (ok if padded_key_length>=4*head_dim, otherwise need to allocate a bigger one).
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) # TODO: Find an inplace and/or fused (with addmm) gelu kernel?
hidden_states = addmm(self.mpb, gelu(addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states)
if flash_params is None: return hidden_states, residual, layer_past
attn_output = self._attn(query, key, value, attention_mask)
else:
attn_output = self._attn_flash(query, key, value, flash_params)
attn_output = self.c_proj(attn_output)
return attn_output, present
class GPTBigCodeMLP(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
self.c_fc = nn.Linear(embed_dim, inner_dim)
self.c_proj = nn.Linear(inner_dim, embed_dim)
@torch.profiler.record_function("GPTBigCodeMLP.forward")
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
return self.c_proj(
nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")
)
class GPTBigCodeBlock(nn.Module):
def __init__(
self,
config,
layer_idx=None,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.ln_1 = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
dtype=dtype,
device=device,
)
self.attn = GPTBigCodeAttention(
config, layer_idx=layer_idx, dtype=dtype, device=device
)
self.ln_2 = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
dtype=dtype,
device=device,
)
self.mlp = GPTBigCodeMLP(config)
@torch.profiler.record_function("GPTBigCodeBlock.forward")
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
flash_params: Optional[Tuple] = None,
) -> Tuple[torch.Tensor, Any]:
with torch.profiler.record_function("GPTBigCodeAttention.ln"):
ai = self.ln_1(hidden_states)
attn_output, present = self.attn(
ai,
layer_past=layer_past,
attention_mask=attention_mask,
flash_params=flash_params,
)
with torch.profiler.record_function("GPTBigCodeAttention.residual"):
hidden_states.add_(attn_output)
with torch.profiler.record_function("GPTBigCodeAttention.dummy"):
pass
with torch.profiler.record_function("GPTBigCodeAttention.ln"):
ai = self.ln_2(hidden_states)
ai = self.mlp(ai)
with torch.profiler.record_function("GPTBigCodeAttention.residual"):
hidden_states.add_(ai)
return hidden_states, present
class GPTBigCodePreTrainedModel(PreTrainedModel):
@ -396,9 +349,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, GPTBigCodeModel):
        if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
            module.bias.fill_(True).tril_()
        elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
@ -412,215 +363,188 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
                ),
            )
            module.c_proj._is_hf_initialized = True
        elif isinstance(module, nn.Linear):
        elif isinstance(module, Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
        elif isinstance(module, Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
    # TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict)
    def __init__(
        self,
        config,
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ):
    def __init__(self, config: GPTBigCodeConfig, dtype: dtype):
        super().__init__(config)
        self.wte = nn.Embedding(
        self.wte = Embedding(
            config.vocab_size, config.hidden_size, dtype=dtype, device=device
            config.vocab_size, config.hidden_size, dtype=dtype, device="meta"
        )
        self.wpe = nn.Embedding(
        self.wpe = Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            dtype=dtype,
            device=device,
            device="meta",
        )
        self.h = nn.ModuleList(
        self.h = ModuleList(
            [
                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device)
                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(
        self.ln_f = LayerNorm(
            config.hidden_size,
            dtype=dtype,
            device=device,
            device="meta",
            eps=config.layer_norm_epsilon,
        )
self.inference_runner_type = (
InferenceRunnerType.NO_RUNNER
) # InferenceRunnerType(config.inference_runner)
self.flash_attention = True # config.flash_attention
if self.flash_attention:
if flash_attn_unpadded_func is None:
raise RuntimeError(
"Flash Attention requires `flash_attn` and `einops`. "
"To install, run `pip install flash-attn einops`."
)
if self.inference_runner_type == InferenceRunnerType.NO_RUNNER:
self.inference_runner = None
else:
from .inference_runner import GPTBigCodeInferenceRunner
self.inference_runner = GPTBigCodeInferenceRunner(config, self)
# Causal mask
self.register_buffer(
"bias",
torch.empty(
(config.max_position_embeddings, config.max_position_embeddings),
dtype=torch.bool,
device=device,
),
)
# @torch.profiler.record_function("GPTBigCodeModel._get_causal_mask")
def _get_causal_mask(self, padding_mask, query_length, key_length):
# Self-attention mask.
attention_mask = self.bias[
None, key_length - query_length : key_length, :key_length
]
if padding_mask is not None:
attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
dtype=torch.bool, device=attention_mask.device
)
pad = -key_length % 8
if pad > 0:
attention_mask = torch.nn.functional.pad(
attention_mask, (0, pad), mode="constant", value=False
)
# (batch_size, query_length, n_heads, key_length)
return attention_mask.unsqueeze(2)
# @torch.profiler.record_function("GPTBigCodeModel.forward")
def forward(
self,
*,
input_ids: torch.Tensor,
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: torch.Tensor,
) -> Tuple:
if self.inference_runner is not None and past_key_values is not None:
if self.config.validate_runner_input:
assert past_key_values is not None
return self.inference_runner.forward(
input_ids, attention_mask, position_ids, past_key_values
)
batch_size, query_length = input_ids.shape
flash_attention = self.flash_attention and past_key_values is None
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][1]
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
# TODO: Unpad earlier (input ids), support unpadded input?
if flash_attention:
(
hidden_states,
padding_index,
sequence_lengths,
max_sequence_length,
) = unpad_input(hidden_states, attention_mask)
flash_params = (
sequence_lengths,
padding_index,
batch_size,
max_sequence_length,
)
attention_mask = None
else:
key_length = past_length + query_length
# Self-attention mask (padding + causal).
attention_mask = self._get_causal_mask(
attention_mask, query_length, key_length
)
flash_params = None
presents = []
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
flash_params=flash_params,
)
hidden_states = outputs[0]
presents.append(outputs[1])
hidden_states = self.ln_f(hidden_states)
if flash_attention:
hidden_states = pad_input(
hidden_states, padding_index, batch_size, query_length
)
return hidden_states, presents
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
def __init__( def __init__(
self, self, config, dtype: dtype, device: device = device("cuda")
config,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
): ):
super().__init__(config) super().__init__(config)
meta = torch.device("meta") if device.type != "cuda":
self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta) raise NotImplementedError(f"Device {device} not supported")
self.lm_head = nn.Linear(
config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta self.transformer = GPTBigCodeModel(config, dtype=dtype)
self.lm_head = Linear(
config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
)
self.mask_value = full(
(), finfo(float32).min, dtype=float32, device=device
) )
self.to_empty(device=device) self.to_empty(device=device)
# Initialize weights and apply final processing # Initialize weights and apply final processing
# TODO: Skip?
self.post_init() self.post_init()
# @torch.profiler.record_function("GPTBigCodeForCausalLM.forward") def post_load_weights(self):
def forward( layer: GPTBigCodeBlock
for layer in self.transformer.h:
layer.post_load_weights(self.mask_value)
self.tw=self.transformer.wte.weight
self.pw=self.transformer.wpe.weight
self.hw=self.lm_head.weight.t_()
self.lw=self.transformer.ln_f.weight
self.lb=self.transformer.ln_f.bias
self.le=self.transformer.ln_f.eps
def prefill(
self, self,
*, *,
input_ids: torch.Tensor, input_ids: Tensor,
past_key_values: Optional[Union[List[torch.Tensor], int]] = None, attention_mask: Tensor = None,
attention_mask: Optional[torch.Tensor] = None, position_ids: Tensor,
position_ids: torch.Tensor,
predict_all_tokens: bool = True, predict_all_tokens: bool = True,
) -> Tuple: ) -> Tuple:
hidden_states, presents = self.transformer( batch_size, query_length = input_ids.shape
input_ids=input_ids,
past_key_values=past_key_values, hidden_states = embedding(input_ids, self.tw).add_(embedding(position_ids, self.pw))
attention_mask=attention_mask,
position_ids=position_ids, # Prefill (flash attn)
# TODO: Unpad earlier (input ids)?
hidden_states, padding_index, sequence_lengths, key_length = unpad_input(
hidden_states, attention_mask
) )
# with torch.profiler.record_function("GPTBigCodeForCausalLM.head"): assert key_length == query_length
if not predict_all_tokens:
# We only care about the last token.
hidden_states = hidden_states[:, -1:]
lm_logits = self.lm_head(hidden_states) residual = None
past_key_values = []
block: GPTBigCodeBlock
for block in self.transformer.h:
hidden_states, residual, key_value = block.prefill(
hidden_states,
residual,
sequence_lengths,
key_length,
)
past_key_values.append(
pad_input(key_value, padding_index, batch_size, query_length)
)
return lm_logits, presents hidden_states, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.lw,
self.lb,
None,
None,
None,
None,
0.0,
self.le,
1.0,
0,
None,
False,
False,
)
# Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
del residual
if predict_all_tokens:
hidden_states = pad_input(
mm(hidden_states, self.hw), padding_index, batch_size, query_length
)
else:
# TODO: Index directly using cu_seqlens instead
hidden_states = mm(pad_input(
hidden_states, padding_index, batch_size, query_length
)[:, -1], self.hw).unsqueeze_(1)
return hidden_states, past_key_values
def decode(
self,
*,
input_ids: Tensor,
past_key_values: List[Tensor],
attention_mask: Tensor,
position_ids: Tensor,
key_length: int,
) -> Tuple:
hidden_states = embedding(input_ids, self.tw).add_(embedding(position_ids, self.pw))
residual = None
block: GPTBigCodeBlock
for i, (block, layer_past) in enumerate(
zip(self.transformer.h, past_key_values)
):
hidden_states, residual, past_key_values[i] = block.decode(
hidden_states,
residual,
layer_past,
attention_mask,
key_length,
)
hidden_states, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.lw,
self.lb,
None,
None,
None,
None,
0.0,
self.le,
1.0,
0,
None,
False,
False,
)
# TODO: Reuse residual?
return mm(hidden_states, self.hw).unsqueeze_(1), past_key_values

View File

@ -0,0 +1,565 @@
# coding=utf-8
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch GPTBigCode model."""
import math
from typing import Optional, Tuple, Any, List
import torch
import torch.utils.checkpoint
from torch import nn
from dropout_layer_norm import dropout_add_ln_fwd
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
GPTBigCodeConfig,
)
class FastLayerNorm(nn.LayerNorm):
# TODO: Validate dimension
def forward(self, hidden_states, residual=None):
out, residual, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return out, residual
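For reference, with the arguments used here (dropout probability 0, no rowscale or colscale), the fused dropout_add_ln_fwd call above should behave like the unfused sketch below, which also hands back the pre-norm sum so it can serve as the next residual; this is a rough reference, not the kernel's exact contract:

import torch
import torch.nn.functional as F

def add_layer_norm_reference(hidden_states, residual, weight, bias, eps):
    # Add the residual stream first (if any), then layer-norm the sum.
    # Both the normalized output and the pre-norm sum are returned, so the
    # caller can keep threading the residual through the next block.
    residual = hidden_states if residual is None else hidden_states + residual
    out = F.layer_norm(residual, residual.shape[-1:], weight, bias, eps)
    return out, residual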
class FastLinear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.addmm(self.bias, input, self.weight)
class FastLinearNoBias(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.mm(input, self.weight)
logger = logging.get_logger(__name__)
@torch.jit.script
def upcast_masked_softmax(
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float
):
input_dtype = x.dtype
x = x.to(torch.float32) * scale
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
return x
@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1)
return x
class GPTBigCodeAttention(nn.Module):
mask_value: torch.Tensor
def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.layer_idx = layer_idx
self.c_attn = FastLinear(
self.embed_dim,
self.embed_dim + 2 * self.head_dim,
dtype=dtype,
device="meta",
)
self.c_proj = FastLinear(
self.embed_dim, self.embed_dim, dtype=dtype, device="meta"
)
@torch.profiler.record_function("GPTBigCodeAttention.prefill")
def prefill(
self,
hidden_states: torch.Tensor,
sequence_lengths,
key_length: int,
) -> Tuple[torch.Tensor, Any]:
hidden_shape = hidden_states.shape
query, key_value = self.c_attn.forward(hidden_states).split(
(self.embed_dim, 2 * self.head_dim), dim=-1
)
query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
key, value = (
key_value.unsqueeze(1)
.expand(hidden_shape[0], self.num_heads, 2 * self.head_dim)
.split((self.head_dim, self.head_dim), dim=-1)
)
# attn_output: (sum_seq_len, num_heads * head_dim)
hidden_states = flash_attn_unpadded_func(
query,
key,
value,
sequence_lengths,
sequence_lengths,
key_length,
key_length,
0.0,
softmax_scale=self.head_dim**-0.5,
causal=True,
).view(hidden_shape)
hidden_states = self.c_proj.forward(hidden_states)
return hidden_states, key_value
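In the prefill above, a single key/value head is broadcast to every query head via expand, so no copy of the shared cache is materialized. A toy illustration of that multi-query trick; the sizes are arbitrary:

import torch

tokens, num_heads, head_dim = 10, 4, 8
key_value = torch.randn(tokens, 2 * head_dim)           # one shared K/V per token
key, value = (
    key_value.unsqueeze(1)                               # (tokens, 1, 2 * head_dim)
    .expand(tokens, num_heads, 2 * head_dim)             # zero-copy broadcasted view
    .split((head_dim, head_dim), dim=-1)
)
assert key.shape == value.shape == (tokens, num_heads, head_dim)
assert key.data_ptr() == key_value.data_ptr()            # still the same storage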
@torch.profiler.record_function("GPTBigCodeAttention.decode")
def decode(
self,
hidden_states: torch.Tensor,
layer_past: torch.Tensor,
attention_mask: torch.Tensor,
batch_size: int,
key_length: int,
) -> Tuple[torch.Tensor, Any]:
query, key_value = self.c_attn.forward(hidden_states).split(
(self.embed_dim, 2 * self.head_dim), dim=-1
)
# Calculate dimensions and recover layer_past
padded_key_length = attention_mask.size(-1)
allocated_key_length = layer_past.size(-2)
# TODO: Allow pre-allocation with size > padded_key_length
if padded_key_length > allocated_key_length:
# Re-allocate kv cache and copy last value
allocated_kv_cache = torch.empty(
[batch_size, padded_key_length, 2 * self.head_dim],
dtype=key_value.dtype,
device=key_value.device,
)
allocated_kv_cache[:, : key_length - 1].copy_(
layer_past[:, : key_length - 1]
)
# Nans in `value` can propagate through the matrix multiplication,
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
layer_past = allocated_kv_cache
# Copy the new values.
layer_past[:, key_length - 1].copy_(key_value)
key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
# TODO: Upcasting needed for bf16?
upcast = query.dtype != torch.float32
unscale = self.layer_idx + 1 if upcast else 1
scale_factor = unscale**-1 / self.head_dim**0.5
# TODO: No need to unsqueeze?
hidden_states = torch.baddbmm(
torch.empty(
(batch_size, self.num_heads, padded_key_length),
device=query.device,
dtype=query.dtype,
),
query.view(batch_size, self.num_heads, self.head_dim),
key.transpose(-1, -2),
beta=0,
alpha=scale_factor,
).unsqueeze_(1)
if upcast:
hidden_states = upcast_masked_softmax(
hidden_states, attention_mask, self.mask_value, unscale
)
else:
hidden_states = masked_softmax(
hidden_states, attention_mask, self.mask_value
)
hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)
hidden_states = self.c_proj.forward(hidden_states)
return hidden_states, layer_past
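A shape walk-through of the single-token decode attention above, with toy sizes and the causal/padding mask left out for brevity; this is illustrative only, not the model's code path:

import torch

batch, num_heads, head_dim, padded_key_length = 2, 4, 8, 16

query = torch.randn(batch, num_heads * head_dim)             # one new token per sequence
layer_past = torch.randn(batch, padded_key_length, 2 * head_dim)
key, value = layer_past.split((head_dim, head_dim), dim=-1)

# (batch, heads, head_dim) @ (batch, head_dim, key_len) -> (batch, heads, key_len)
scores = torch.baddbmm(
    torch.empty(batch, num_heads, padded_key_length),
    query.view(batch, num_heads, head_dim),
    key.transpose(-1, -2),
    beta=0,
    alpha=head_dim ** -0.5,
)
probs = torch.softmax(scores.float(), dim=-1).to(scores.dtype)
# (batch, heads, key_len) @ (batch, key_len, head_dim) -> (batch, heads, head_dim)
context = torch.bmm(probs, value).view(batch, num_heads * head_dim)
assert context.shape == query.shape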
class GPTBigCodeMLP(nn.Module):
# TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
super().__init__()
embed_dim = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta")
self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta")
class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
super().__init__()
self.ln_1 = FastLayerNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
dtype=dtype,
device="meta",
)
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
self.ln_2 = FastLayerNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
dtype=dtype,
device="meta",
)
self.mlp = GPTBigCodeMLP(config, dtype=dtype)
def post_load_weights(self, mask_value):
self.attn.mask_value = mask_value
self.attn.c_attn.weight.t_()
self.attn.c_proj.weight.t_()
self.l1w=self.ln_1.weight
self.l1b=self.ln_1.bias
self.e1=self.ln_1.eps
self.l2w=self.ln_2.weight
self.l2b=self.ln_2.bias
self.e2=self.ln_2.eps
self.mfb=self.mlp.c_fc.bias
self.mfw=self.mlp.c_fc.weight.t_()
self.mpw=self.mlp.c_proj.weight.t_()
self.mpb=self.mlp.c_proj.bias
@torch.profiler.record_function("GPTBigCodeBlock.prefill")
def prefill(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
sequence_lengths,
key_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
hidden_states, present = self.attn.prefill(
hidden_states,
sequence_lengths,
key_length,
)
hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
hidden_states = self.mlp.c_proj.forward(
nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")
)
return hidden_states, residual, present
@torch.profiler.record_function("GPTBigCodeBlock.decode")
def decode(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
layer_past: torch.Tensor,
attention_mask: torch.Tensor,
batch_size: int,
key_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
if residual is None:
residual = hidden_states
hidden_states, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l1w,
self.l1b,
None,
None,
None,
None,
0.0,
self.e1,
1.0,
0,
None,
False,
False,
)
else:
hidden_states, residual, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l1w,
self.l1b,
None,
None,
None,
None,
0.0,
self.e1,
1.0,
0,
None,
False,
False,
)
hidden_states, present = self.attn.decode(
hidden_states,
layer_past,
attention_mask,
batch_size,
key_length,
)
hidden_states, residual, *_ = dropout_add_ln_fwd(
hidden_states,
residual,
self.l2w,
self.l2b,
None,
None,
None,
None,
0.0,
self.e2,
1.0,
0,
None,
False,
False,
)
hidden_states = torch.addmm(self.mpb, nn.functional.gelu(torch.addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states)
return hidden_states, residual, present
class GPTBigCodePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = GPTBigCodeConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = False
_no_split_modules = ["GPTBigCodeBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
module.c_proj.weight.data.normal_(
mean=0.0,
std=(
self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
),
)
module.c_proj._is_hf_initialized = True
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
# TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict)
def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
super().__init__(config)
self.wte = nn.Embedding(
config.vocab_size, config.hidden_size, dtype=dtype, device="meta"
)
self.wpe = nn.Embedding(
config.max_position_embeddings,
config.hidden_size,
dtype=dtype,
device="meta",
)
self.h = nn.ModuleList(
[
GPTBigCodeBlock(config, layer_idx=i, dtype=dtype)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = FastLayerNorm(
config.hidden_size,
dtype=dtype,
device="meta",
eps=config.layer_norm_epsilon,
)
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
def __init__(
self, config, dtype: torch.dtype, device: torch.device = torch.device("cuda")
):
super().__init__(config)
if device.type != "cuda":
raise NotImplementedError(f"Device {device} not supported")
self.transformer = GPTBigCodeModel(config, dtype=dtype)
self.lm_head = FastLinearNoBias(
config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
)
# Causal mask
self.causal_mask = torch.ones(
(config.max_position_embeddings, config.max_position_embeddings),
dtype=torch.bool,
device=device,
).tril_()
self.mask_value = torch.full(
(), torch.finfo(torch.float32).min, dtype=torch.float32, device=device
)
self.to_empty(device=device)
# Initialize weights and apply final processing
# TODO: Skip?
self.post_init()
def prefill(
self,
*,
input_ids: torch.Tensor,
attention_mask: torch.Tensor = None,
position_ids: torch.Tensor,
predict_all_tokens: bool = True,
) -> Tuple:
batch_size, query_length = input_ids.shape
hidden_states = self.transformer.wte.forward(
input_ids
) + self.transformer.wpe.forward(position_ids)
# Prefill (flash attn)
# TODO: Unpad earlier (input ids)?
hidden_states, padding_index, sequence_lengths, key_length = unpad_input(
hidden_states, attention_mask
)
assert key_length == query_length
residual = None
past_key_values = []
block: GPTBigCodeBlock
for block in self.transformer.h:
hidden_states, residual, key_value = block.prefill(
hidden_states,
residual=residual,
sequence_lengths=sequence_lengths,
key_length=query_length,
)
past_key_values.append(
pad_input(key_value, padding_index, batch_size, query_length)
)
hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
# Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
del residual
if predict_all_tokens:
hidden_states = self.lm_head.forward(hidden_states)
hidden_states = pad_input(
hidden_states, padding_index, batch_size, query_length
)
else:
# TODO: Index directly instead
hidden_states = pad_input(
hidden_states, padding_index, batch_size, query_length
)[:, -1]
hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
return hidden_states, past_key_values
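The predict_all_tokens branch above matters because the prefill logits are the dominant activation. A back-of-the-envelope in fp16 with assumed, StarCoder-like sizes (the numbers below are illustrative, not measured):

# Rough size of the prefill logits tensor, the memory bottleneck mentioned above.
batch_size, seq_len, vocab_size, bytes_per_elem = 8, 2048, 49152, 2

all_tokens = batch_size * seq_len * vocab_size * bytes_per_elem
last_token_only = batch_size * vocab_size * bytes_per_elem
print(f"all tokens: {all_tokens / 2**30:.2f} GiB, last token only: {last_token_only / 2**20:.2f} MiB")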
def post_load_weights(self):
layer: GPTBigCodeBlock
for layer in self.transformer.h:
layer.post_load_weights(self.mask_value)
self.lm_head.weight.t_()
def decode(
self,
*,
input_ids: torch.Tensor,
past_key_values: List[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
key_length: int,
) -> Tuple:
batch_size, query_length = input_ids.shape
assert query_length == 1
hidden_states = self.transformer.wte.forward(
input_ids
) + self.transformer.wpe.forward(position_ids)
# Standardize shape to (batch_size, hidden_size)
hidden_states.squeeze_(1)
# Self-attention mask (padding + causal).
# TODO: Avoid unsqueeze
attention_mask = self.causal_mask[
None, key_length - 1 : key_length, : attention_mask.size(-1)
] * attention_mask.unsqueeze(1)
attention_mask.unsqueeze_(2)
residual = None
block: GPTBigCodeBlock
for i, (block, layer_past) in enumerate(
zip(self.transformer.h, past_key_values)
):
hidden_states, residual, past_key_values[i] = block.decode(
hidden_states,
residual=residual,
layer_past=layer_past,
attention_mask=attention_mask,
batch_size=batch_size,
key_length=key_length,
)
hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
return hidden_states, past_key_values
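The decode mask above combines one row of the precomputed causal triangle (the row for the new token at position key_length - 1) with the padding mask. A small, self-contained illustration with toy sizes and a hand-written padding mask; illustrative only:

import torch

max_positions, key_length = 16, 5
causal_mask = torch.ones(max_positions, max_positions, dtype=torch.bool).tril_()

# Padding mask over the (multiple-of-8) padded key positions: True = real token.
padding_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                             [0, 1, 1, 1, 1, 0, 0, 0]], dtype=torch.bool)

# One causal row for the new token, broadcast against the per-sequence padding.
row = causal_mask[None, key_length - 1 : key_length, : padding_mask.size(-1)]
attention_mask = row * padding_mask.unsqueeze(1)       # (batch, 1, padded_key_length)
attention_mask = attention_mask.unsqueeze(2)           # (batch, 1, 1, padded_key_length)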

View File

@ -451,11 +451,17 @@ class FlashCausalLM(Model):
    def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, cache_dtype: Optional[torch.dtype]):
        diff = max_input_length - max(batch.input_lengths)
        fill_value=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        for i in range(len(batch)):
            batch.input_lengths[i] += diff
            batch.prefix_offsets[i] = 0
            batch.read_offsets[i] = 0
            batch.all_input_ids[i] = batch.all_input_ids[i] + [self.tokenizer.pad_token_id] * diff if diff >= 0 else batch.all_input_ids[i][:diff]
            for _ in range(diff):
                batch.all_input_ids[i].append(fill_value)
            # For some reason just resetting the offsets makes things way slower.
            _, batch.prefix_offsets[i], batch.read_offsets[i] = self.decode_token(
                batch.all_input_ids[i],
                batch.prefix_offsets[i],
                batch.read_offsets[i],
            )
            # TODO: Bug!?!
            batch.stopping_criterias[i].current_tokens += diff
@ -473,19 +479,19 @@ class FlashCausalLM(Model):
            batch.past_key_values=None
        else:
            assert len(batch.all_input_ids_tensor)>0, "Must run prefill first"
            batch.input_ids.fill_(self.tokenizer.pad_token_id)
            batch.input_ids.fill_(fill_value)
            batch.position_ids += diff
            batch.cu_seqlens += diff * batch.cu_seqlens_q
            # TODO: Bug!?!
            batch.max_seqlen += batch.max_seqlen + diff*len(batch)
            batch.max_seqlen += diff*len(batch)
            for i in range(len(batch)):
                batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(self.tokenizer.pad_token_id)
                batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(fill_value)
            batch.past_key_values = batch.past_key_values = torch.randn(
            batch.past_key_values = torch.randn(
                (
                    batch.past_key_values.shape[0],
                    batch.past_key_values.shape[1] + len(batch.requests),
                    batch.past_key_values.shape[1] + diff*len(batch.requests),
                    *batch.past_key_values.shape[2:],
                ), device=batch.past_key_values.device, dtype= cache_dtype
            )
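Why the `batch.cu_seqlens += diff * batch.cu_seqlens_q` update above is enough: cu_seqlens holds cumulative sequence lengths, and during single-token decode cu_seqlens_q is effectively [0, 1, ..., batch_size], so growing every sequence by diff shifts the i-th boundary by i * diff. A quick self-contained check with made-up lengths:

import torch

input_lengths = torch.tensor([5, 3, 7])
diff = 4
cu_seqlens = torch.nn.functional.pad(input_lengths.cumsum(0), (1, 0))        # [0, 5, 8, 15]
cu_seqlens_q = torch.arange(len(input_lengths) + 1)                          # [0, 1, 2, 3]
grown = torch.nn.functional.pad((input_lengths + diff).cumsum(0), (1, 0))    # [0, 9, 16, 27]
assert torch.equal(cu_seqlens + diff * cu_seqlens_q, grown)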

View File

@ -13,8 +13,15 @@ from text_generation_server.models.vectorized_causal_lm import (
    VectorizedCausalLMBatch,
)
from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import (
    GPTBigCodeForCausalLM,
    GPTBigCodeForCausalLM as GPTBigCode2ForCausalLM,
)
from text_generation_server.models.custom_modeling.gpt_bigcode3_modeling import (
    GPTBigCodeForCausalLM as GPTBigCode3ForCausalLM,
)
from text_generation_server.models.custom_modeling.gpt_bigcode4_modeling import (
    GPTBigCodeForCausalLM as GPTBigCode4ForCausalLM,
)
from transformers.modeling_utils import PreTrainedModel

tracer = trace.get_tracer(__name__)
@ -96,8 +103,9 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
        return len(self.requests)

class Bigcode2CausalLM(VectorizedCausalLM):
class Bigcode2CausalLMBase(VectorizedCausalLM):
    model: GPTBigCodeForCausalLM
    #model: GPTBigCode2ForCausalLM
    _model_class: Type[PreTrainedModel]

    def __init__(
        self,
@ -118,7 +126,7 @@ class Bigcode2CausalLM(VectorizedCausalLM):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        model = GPTBigCodeForCausalLM.from_pretrained(
        model = self._model_class.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
@ -163,18 +171,18 @@ class Bigcode2CausalLM(VectorizedCausalLM):
        padded_key_length = (
            key_length + -key_length % batch.pad_key_length_to_multiple
        )
        input_ids = batch.input_ids[:, key_length - 1 : key_length]
        input_ids = batch.input_ids[:, key_length - 1]
        # Model Forward
        logits, batch.past_key_values = self.model.decode(
            input_ids=input_ids,
            attention_mask=batch.attention_mask[:, :padded_key_length],
            attention_mask=batch.attention_mask[:, None, :padded_key_length],
            position_ids=batch.position_ids[:, key_length - 1 : key_length],
            position_ids=batch.position_ids[:, key_length - 1],
            past_key_values=batch.past_key_values,
            key_length=key_length,
        )
        next_token_ids, logprobs = batch.next_token_chooser(
            input_ids, logits, batch.details
            input_ids.unsqueeze(1), logits.unsqueeze(1), batch.details
        )
        # Update batch
        # TODO: Why do we need all input ids?
@ -201,3 +209,14 @@ class Bigcode2CausalLM(VectorizedCausalLM):
            )
            for _ in range(self.model.config.n_layer)
        ]

class Bigcode2CausalLM(Bigcode2CausalLMBase):
    _model_class = GPTBigCode2ForCausalLM

class Bigcode3CausalLM(Bigcode2CausalLMBase):
    _model_class = GPTBigCode3ForCausalLM

class Bigcode4CausalLM(Bigcode2CausalLMBase):
    _model_class = GPTBigCode4ForCausalLM