Fixes, cpu-optimized model, misc
This commit is contained in:
parent 72eefa3612
commit 7a94845eba
@@ -100,6 +100,8 @@ model_type_map={
     "vector":VectorizedCausalLM,
     "bigcode":BigcodeCausalLM,
     "bigcode2":Bigcode2CausalLM,
+    "bigcode3":Bigcode2CausalLM,
+    "bigcode4":Bigcode2CausalLM,
 }
 
@@ -508,6 +508,63 @@ class CausalLM(Model):
         )
         return outputs.logits, outputs.past_key_values
 
+    def fast_forward(
+        self,
+        batch: CausalLMBatch,
+        max_input_length: int,
+        cache_dtype: Optional[torch.dtype],
+    ):
+        diff = max_input_length - batch.max_input_length
+        batch.max_input_length+=diff
+        fill_value=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
+
+        for i in range(len(batch)):
+            batch.all_input_ids[i]=torch.nn.functional.pad(batch.all_input_ids[i], (0, 0, 0, diff), value=fill_value)[:batch.max_input_length]
+            batch.input_lengths[i]+=diff
+            batch.prefix_offsets[i] = 0
+            batch.read_offsets[i] = 0
+
+        batch.attention_mask[:, -batch.padding_right_offset:batch.padding_right_offset+diff] = 1
+        batch.padding_right_offset -= diff
+
+        if cache_dtype is None:
+            assert batch.input_ids.shape==(len(batch),batch.max_input_length-diff)
+            batch.input_ids=torch.nn.functional.pad(batch.input_ids, (0,diff), value=fill_value)[:, :batch.max_input_length]
+            batch.position_ids = batch.attention_mask[:, :batch.max_input_length].long().cumsum(-1) - 1
+            batch.past_key_values=None
+        else:
+            from transformers import GPTBigCodeForCausalLM
+            batch.input_ids=batch.input_ids[:,:1].fill_(fill_value)
+            batch.position_ids = batch.position_ids[:, -1:] + diff
+            if isinstance(self.model, GPTBigCodeForCausalLM):
+                batch.past_key_values=[
+                    torch.randn(
+                        [
+                            len(batch),
+                            batch.max_input_length - 1,
+                            2 * self.model.config.n_embd // self.model.config.n_head,
+                        ],
+                        dtype=cache_dtype,
+                        device=batch.input_ids.device,
+                    )
+                    for _ in range(self.model.config.n_layer)
+                ]
+            else:
+                batch.past_key_values=[
+                    [torch.randn(
+                        [
+                            len(batch),
+                            batch.max_input_length - 1,
+                            self.model.config.n_embd // self.model.config.n_head,
+                        ],
+                        dtype=cache_dtype,
+                        device=batch.input_ids.device,
+                    )
+                    for _ in range(2)] for _ in range(self.model.config.n_layer)
+                ]
+
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
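Note on the padding calls above: `fast_forward` relies on the right-padding semantics of `torch.nn.functional.pad`, where `(0, diff)` extends the last dimension and `(0, 0, 0, diff)` extends the second-to-last one. A standalone sketch of just that behaviour, on toy tensors rather than the server's batch objects:

import torch
import torch.nn.functional as F

# 1-D case: (0, diff) appends `diff` fill values on the right.
ids = torch.tensor([10, 11, 12])
print(F.pad(ids, (0, 2), value=0))              # tensor([10, 11, 12,  0,  0])

# 2-D case: (0, 0, 0, diff) leaves the last dim alone and pads rows, which is
# how the per-request all_input_ids tensors are extended above.
col = torch.tensor([[10], [11], [12]])
print(F.pad(col, (0, 0, 0, 2), value=0).shape)  # torch.Size([5, 1])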
@@ -515,6 +572,9 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
+        #print("AAA", batch.max_input_length, batch.input_ids, batch.position_ids, batch.all_input_ids, batch.all_input_ids[0].shape)
+        #print("BBB", batch.max_input_length, batch.input_lengths, batch.padding_right_offset, batch.past_key_values is None)
+
         logits, past = self.forward(
             batch.input_ids,
             attention_mask,
@@ -146,9 +146,10 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: torch.Tensor,
         attention_mask: torch.Tensor,
-        batch_size: int,
         key_length: int,
     ) -> Tuple[torch.Tensor, Any]:
+        batch_size=hidden_states.size(0)
+
         query, key_value = self.c_attn.forward(hidden_states).split(
             (self.embed_dim, 2 * self.head_dim), dim=-1
         )
@@ -183,7 +184,6 @@ class GPTBigCodeAttention(nn.Module):
         unscale = self.layer_idx + 1 if upcast else 1
         scale_factor = unscale**-1 / self.head_dim**0.5
 
-        # TODO: No need to unsqueeze?
         hidden_states = torch.baddbmm(
             torch.empty(
                 (batch_size, self.num_heads, padded_key_length),
@@ -194,7 +194,7 @@ class GPTBigCodeAttention(nn.Module):
             key.transpose(-1, -2),
             beta=0,
             alpha=scale_factor,
-        ).unsqueeze_(1)
+        )
 
         if upcast:
             hidden_states = upcast_masked_softmax(
@@ -205,7 +205,7 @@ class GPTBigCodeAttention(nn.Module):
                 hidden_states, attention_mask, self.mask_value
             )
 
-        hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)
+        hidden_states = torch.bmm(hidden_states, value).view(query.shape)
 
         hidden_states = self.c_proj.forward(hidden_states)
 
@@ -265,7 +265,6 @@ class GPTBigCodeBlock(nn.Module):
         residual: Optional[torch.Tensor],
         layer_past: torch.Tensor,
         attention_mask: torch.Tensor,
-        batch_size: int,
         key_length: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, Any]:
         hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
@@ -273,7 +272,6 @@ class GPTBigCodeBlock(nn.Module):
             hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
-            batch_size=batch_size,
             key_length=key_length,
         )
         hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
@@ -370,12 +368,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
         )
 
-        # Causal mask
-        self.causal_mask = torch.ones(
-            (config.max_position_embeddings, config.max_position_embeddings),
-            dtype=torch.bool,
-            device=device,
-        ).tril_()
         self.mask_value = torch.full(
             (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device
         )
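The deleted block above built the causal mask once per model as a lower-triangular boolean buffer. A minimal standalone sketch of the same pattern, with a small size purely for display:

import torch

# Lower-triangular boolean mask: query position i may attend to key positions <= i.
max_positions = 6  # stands in for config.max_position_embeddings
causal_mask = torch.ones((max_positions, max_positions), dtype=torch.bool).tril_()
print(causal_mask[3])  # tensor([ True,  True,  True,  True, False, False])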
@@ -459,9 +451,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         position_ids: torch.Tensor,
         key_length: int,
     ) -> Tuple:
-        batch_size, query_length = input_ids.shape
-        assert query_length == 1
-
         hidden_states = self.transformer.wte.forward(
             input_ids
         ) + self.transformer.wpe.forward(position_ids)
@@ -469,13 +458,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         # Standardize shape to (batch_size, hidden_size)
         hidden_states.squeeze_(1)
 
-        # Self-attention mask (padding + causal).
-        # TODO: Avoid unsqueeze
-        attention_mask = self.causal_mask[
-            None, key_length - 1 : key_length, : attention_mask.size(-1)
-        ] * attention_mask.unsqueeze(1)
-        attention_mask.unsqueeze_(2)
-
         residual = None
         block: GPTBigCodeBlock
         for i, (block, layer_past) in enumerate(
@@ -486,11 +468,10 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
                 residual=residual,
                 layer_past=layer_past,
                 attention_mask=attention_mask,
-                batch_size=batch_size,
                 key_length=key_length,
             )
 
         hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
-        hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
+        hidden_states = self.lm_head.forward(hidden_states)
 
-        return hidden_states, past_key_values
+        return hidden_states.unsqueeze_(1), past_key_values
@@ -13,12 +13,15 @@
 # limitations under the License.
 """PyTorch GPTBigCode model."""
 import math
-from typing import Optional, Tuple, Any, Union, List
-from enum import IntEnum
+from typing import Optional, Tuple, Any, List
 
-import torch
-import torch.utils.checkpoint
-from torch import nn
+from torch import where, addmm, mm, float32, dtype, Tensor, baddbmm, empty, device, bmm, full, ones, finfo, jit
+from torch.nn import Linear, Embedding, Module, LayerNorm, ModuleList
+from torch.nn.functional import gelu, softmax, embedding
 
+from dropout_layer_norm import dropout_add_ln_fwd
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
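The rewritten imports above pull in `dropout_layer_norm` and `flash_attn` unconditionally. The previous version of this file (removed in the next hunk) guarded the flash-attn imports instead; that optional-dependency pattern, reproduced here as a sketch, remains an option when the fused kernels are not installed:

# Optional-dependency guard, as the pre-rewrite module did it. If flash-attn is
# missing, the names stay None and callers can fall back to the slow path.
try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None
    pad_input = None
    unpad_input = None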
@@ -27,357 +30,307 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
 )
 
 
-class InferenceRunnerType(IntEnum):
-    NO_RUNNER = 0
-    # Use the inference runner without cuda graphs.
-    BASE_RUNNER = 1
-    # Use cuda graphs in the inference runner. Leave out the attention which has a variable shape.
-    # This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models.
-    PARTIAL_GRAPH = 2
-    # Turn the whole model into a cuda graph. One graph for each sequence length.
-    # Note: only useful for small batches and models, graphs take some time to generate, flaky.
-    # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
-    FULL_GRAPH = 3
-
-
-try:
-    from flash_attn.bert_padding import pad_input, unpad_input
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-except ImportError:
-    flash_attn_unpadded_func = None
-    pad_input = None
-    unpad_input = None
-
-
 logger = logging.get_logger(__name__)
 
 
-@torch.jit.script
+@jit.script
 def upcast_masked_softmax(
-    x: torch.Tensor,
-    mask: torch.Tensor,
-    mask_value: torch.Tensor,
-    scale: float,
-    softmax_dtype: torch.dtype,
+    x: Tensor, mask: Tensor, mask_value: Tensor, scale: float
 ):
     input_dtype = x.dtype
-    x = x.to(softmax_dtype) * scale
-    x = torch.where(mask, x, mask_value)
-    x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
+    x = x.to(float32) * scale
+    x = where(mask, x, mask_value)
+    x = softmax(x, dim=-1).to(input_dtype)
     return x
 
 
-@torch.jit.script
-def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
-    input_dtype = x.dtype
-    x = x.to(softmax_dtype) * scale
-    x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
-    return x
-
-
-@torch.jit.script
-def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
-    x = torch.where(mask, x, mask_value)
-    x = torch.nn.functional.softmax(x, dim=-1)
-    return x
-
-
-@torch.profiler.record_function("softmax_function")
-def softmax_function(
-    x: torch.Tensor,
-    mask: torch.Tensor,
-    mask_value: torch.Tensor,
-    scale: float,
-    softmax_dtype: torch.dtype,
-    upcast: bool = True,
-):
-    """
-    This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case
-    needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting
-    and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely
-    inefficient.
-    """
-    # assert x.size(-1) % 8 == 0
-    if upcast:
-        if mask is None:
-            return upcast_softmax(x, scale, softmax_dtype)
-        else:
-            return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
-    else:
-        if mask is None:
-            return torch.nn.functional.softmax(x, dim=-1)
-        else:
-            return masked_softmax(x, mask, mask_value)
-
-
-class GPTBigCodeAttention(nn.Module):
-    def __init__(
-        self,
-        config,
-        layer_idx=None,
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device("cpu"),
-    ):
+class GPTBigCodeAttention(Module):
+    mask_value: Tensor
+
+    def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: dtype):
         super().__init__()
-        self.mask_value = None
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         self.layer_idx = layer_idx
 
-        # KV caching and padding
-
-        self.c_attn = nn.Linear(
+        self.c_attn = Linear(
             self.embed_dim,
             self.embed_dim + 2 * self.head_dim,
             dtype=dtype,
-            device=device,
+            device="meta",
         )
-        self.c_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, dtype=dtype, device=device
+        self.c_proj = Linear(
+            self.embed_dim, self.embed_dim, dtype=dtype, device="meta"
         )
 
-    @torch.profiler.record_function("GPTBigCodeAttention._get_mask_value")
-    def _get_mask_value(self, device, dtype):
-        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
-        if (
-            self.mask_value is None
-            or self.mask_value.dtype != dtype
-            or self.mask_value.device != device
-        ):
-            self.mask_value = torch.full(
-                [], torch.finfo(dtype).min, dtype=dtype, device=device
-            )
-        return self.mask_value
-
-    @torch.profiler.record_function("GPTBigCodeAttention._attn")
-    def _attn(self, query, key, value, attention_mask):
-        softmax_dtype = torch.float32
-        upcast = query.dtype != softmax_dtype
-
-        unscale = self.layer_idx + 1 if upcast else 1
-        scale_factor = unscale**-1 / self.head_dim**0.5
-
-        # (batch_size, query_length, num_heads * head_dim)
-        query_shape = query.shape
-        batch_size = query_shape[0]
-        key_length = key.size(-2)
-
-        key = key.transpose(-1, -2)
-        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
-        # -> (batch_size, query_length, num_heads, key_length)
-        query_length = query_shape[1]
-        attn_shape = (batch_size, query_length, self.num_heads, key_length)
-        attn_view = (batch_size, query_length * self.num_heads, key_length)
-        # No copy needed for MQA 2, or when layer_past is provided.
-        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
-
-        attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
-        if query.device.type == "cpu":
-            # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
-            # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
-            # but the fix has not been released as of pytorch version 2.0.0.
-            attn_weights.zero_()
-            beta = 1
-        else:
-            beta = 0
-        attn_weights = torch.baddbmm(
-            attn_weights, query, key, beta=beta, alpha=scale_factor
-        ).view(attn_shape)
-
-        attn_weights = softmax_function(
-            attn_weights,
-            attention_mask,
-            None
-            if attention_mask is None
-            else self._get_mask_value(attn_weights.device, softmax_dtype),
-            unscale,
-            softmax_dtype,
-            upcast,
-        )
-        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
-
-        return attn_output
-
-    @torch.profiler.record_function("GPTBigCodeAttention._attn_flash")
-    def _attn_flash(self, query, key, value, flash_params):
-        query_shape = query.shape
-        attn_shape = query_shape[0], self.num_heads, self.head_dim
-        query = query.view(attn_shape)
-        key = key.unsqueeze(1).expand(attn_shape)
-        value = value.unsqueeze(1).expand(attn_shape)
-
-        sequence_lengths, padding_index, _, max_sequence_length = flash_params
-
-        # attn_output: (sum_seq_len, num_heads * head_dim)
-        attn_output = flash_attn_unpadded_func(
-            query,
-            key,
-            value,
-            sequence_lengths,
-            sequence_lengths,
-            max_sequence_length,
-            max_sequence_length,
-            0.0,
-            softmax_scale=self.head_dim**-0.5,
-            causal=True,
-        ).view(query_shape)
-
-        return attn_output
-
-    @torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches")
-    def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
-        # Convert to standard KV cache format.
-        if flash_params is not None:
-            _, padding_index, batch_size, max_sequence_length = flash_params
-            current_kv_cache = pad_input(
-                key_value, padding_index, batch_size, max_sequence_length
-            )
-            return key_value, (current_kv_cache, max_sequence_length)
-
-        current_kv_cache = key_value
-
-        # Calculate dimensions and recover layer_past
-        batch_size = current_kv_cache.size(0)
-        query_length = current_kv_cache.size(-2)
-        if layer_past is None:
-            allocated_kv_cache, last_key_length = None, 0
-            last_kv_cache = None
-            key_length = query_length
-            allocated_key_length = key_length
-        else:
-            allocated_kv_cache, last_key_length = layer_past
-            last_kv_cache = allocated_kv_cache[:, :last_key_length]
-            key_length = query_length + last_key_length
-            allocated_key_length = allocated_kv_cache.size(-2)
-
-        padded_key_length = attention_mask.size(-1)
-
-        # Re-allocate kv cache and copy last value
-        if padded_key_length > allocated_key_length:
-            allocated_kv_cache = torch.empty(
-                [batch_size, padded_key_length, 2 * self.head_dim],
-                dtype=current_kv_cache.dtype,
-                device=current_kv_cache.device,
-            )
-            if layer_past is not None:
-                allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
-            if padded_key_length > key_length:
-                # Nans in `value` can propagate through the matrix multiplication,
-                # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
-                allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
-
-        # Copy the new values.
-        if padded_key_length > allocated_key_length or layer_past is not None:
-            allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
-            padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-            # Use the merged KV cache.
-            # Not needed when layer_past is None but frees some memory.
-            key_value = padded_kv_cache
-
-        if allocated_kv_cache is None:
-            allocated_kv_cache = current_kv_cache
-        present = allocated_kv_cache, key_length
-        return key_value, present
-
-    @torch.profiler.record_function("GPTBigCodeAttention.forward")
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        layer_past: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None,
-    ) -> Tuple[torch.Tensor, Any]:
-        query, key_value = self.c_attn(hidden_states).split(
-            (self.embed_dim, 2 * self.head_dim), dim=-1
-        )
-
-        # present = (allocated_kv_cache, key_length)
-        key_value, present = self._merge_kv_caches(
-            key_value, layer_past, attention_mask, flash_params
-        )
-
-        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
-
-        if flash_params is None:
-            attn_output = self._attn(query, key, value, attention_mask)
-        else:
-            attn_output = self._attn_flash(query, key, value, flash_params)
-
-        attn_output = self.c_proj(attn_output)
-
-        return attn_output, present
-
-
-class GPTBigCodeMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        embed_dim = config.hidden_size
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
-        self.c_fc = nn.Linear(embed_dim, inner_dim)
-        self.c_proj = nn.Linear(inner_dim, embed_dim)
-
-    @torch.profiler.record_function("GPTBigCodeMLP.forward")
-    # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
-    def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
-        return self.c_proj(
-            nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")
-        )
-
-
-class GPTBigCodeBlock(nn.Module):
-    def __init__(
-        self,
-        config,
-        layer_idx=None,
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device("cpu"),
-    ):
-        super().__init__()
-        self.ln_1 = nn.LayerNorm(
-            config.hidden_size,
-            eps=config.layer_norm_epsilon,
-            dtype=dtype,
-            device=device,
-        )
-        self.attn = GPTBigCodeAttention(
-            config, layer_idx=layer_idx, dtype=dtype, device=device
-        )
-        self.ln_2 = nn.LayerNorm(
-            config.hidden_size,
-            eps=config.layer_norm_epsilon,
-            dtype=dtype,
-            device=device,
-        )
-        self.mlp = GPTBigCodeMLP(config)
-
-    @torch.profiler.record_function("GPTBigCodeBlock.forward")
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        layer_past: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None,
-    ) -> Tuple[torch.Tensor, Any]:
-        with torch.profiler.record_function("GPTBigCodeAttention.ln"):
-            ai = self.ln_1(hidden_states)
-        attn_output, present = self.attn(
-            ai,
-            layer_past=layer_past,
-            attention_mask=attention_mask,
-            flash_params=flash_params,
-        )
-        with torch.profiler.record_function("GPTBigCodeAttention.residual"):
-            hidden_states.add_(attn_output)
-
-        with torch.profiler.record_function("GPTBigCodeAttention.dummy"):
-            pass
-        with torch.profiler.record_function("GPTBigCodeAttention.ln"):
-            ai = self.ln_2(hidden_states)
-        ai = self.mlp(ai)
-        with torch.profiler.record_function("GPTBigCodeAttention.residual"):
-            hidden_states.add_(ai)
-        return hidden_states, present
+class GPTBigCodeMLP(Module):
+    # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
+    def __init__(self, config: GPTBigCodeConfig, dtype: dtype):
+        super().__init__()
+        embed_dim = config.hidden_size
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
+        self.c_fc = Linear(embed_dim, inner_dim, dtype=dtype, device="meta")
+        self.c_proj = Linear(inner_dim, embed_dim, dtype=dtype, device="meta")
+
+
+class GPTBigCodeBlock(Module):
+    def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: dtype):
+        super().__init__()
+        self.ln_1 = LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device="meta",
+        )
+        self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
+        self.ln_2 = LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device="meta",
+        )
+        self.mlp = GPTBigCodeMLP(config, dtype=dtype)
+
+    def post_load_weights(self, mask_value):
+        self.attn.mask_value = mask_value
+
+        self.mask_value = mask_value
+        self.hd=self.attn.head_dim
+        self.split0=(self.attn.embed_dim, 2*self.hd)
+        self.split1=(self.hd, self.hd)
+
+        self.aaw=self.attn.c_attn.weight.t_()
+        self.aab=self.attn.c_attn.bias
+        self.apw=self.attn.c_proj.weight.t_()
+        self.apb=self.attn.c_proj.bias
+
+        self.unscale=self.attn.layer_idx + 1
+        self.ps=self.hd**-0.5
+        self.ds=self.unscale**-1 * self.ps
+
+        self.l1w=self.ln_1.weight
+        self.l1b=self.ln_1.bias
+        self.e1=self.ln_1.eps
+        self.l2w=self.ln_2.weight
+        self.l2b=self.ln_2.bias
+        self.e2=self.ln_2.eps
+        self.mfb=self.mlp.c_fc.bias
+        self.mfw=self.mlp.c_fc.weight.t_()
+        self.mfb=self.mlp.c_fc.bias
+        self.mpw=self.mlp.c_proj.weight.t_()
+        self.mpb=self.mlp.c_proj.bias
+
+    def prefill(
+        self,
+        hidden_states: Tensor,
+        residual: Optional[Tensor],
+        sequence_lengths,
+        key_length: int,
+    ) -> Tuple[Tensor, Tensor, Any]:
+        if residual is None:  # First layer
+            residual = hidden_states
+            hidden_states, *_ = dropout_add_ln_fwd(
+                hidden_states, residual, self.l1w, self.l1b,
+                None, None, None, None, 0.0, self.e1, 1.0, 0, None, False, False,
+            )
+        else:
+            hidden_states, residual, *_ = dropout_add_ln_fwd(
+                hidden_states, residual, self.l1w, self.l1b,
+                None, None, None, None, 0.0, self.e1, 1.0, 0, None, False, False,
+            )
+        hidden_shape = hidden_states.shape
+        query, key_value = addmm(self.aab, hidden_states, self.aaw).split(self.split0, dim=-1)
+        query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
+        key, value = (
+            key_value.unsqueeze(1)
+            .expand(hidden_shape[0], self.num_heads, 2 * self.head_dim)
+            .split(self.split1, dim=-1)
+        )
+
+        # attn_output: (sum_seq_len, num_heads * head_dim)
+        hidden_states = flash_attn_unpadded_func(
+            query,
+            key,
+            value,
+            sequence_lengths,
+            sequence_lengths,
+            key_length,
+            key_length,
+            0.0,
+            softmax_scale=self.ps,
+            causal=True,
+        ).view(hidden_shape)
+        hidden_states = addmm(self.apb, hidden_states, self.apw, out=query)
+
+        hidden_states, residual, *_ = dropout_add_ln_fwd(
+            hidden_states, residual, self.l2w, self.l2b,
+            None, None, None, None, 0.0, self.e2, 1.0, 0, None, False, False,
+        )
+        # TODO: Find an inplace and/or fused (with addmm) gelu kernel?
+        hidden_states = addmm(self.mpb, gelu(addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states)
+        return hidden_states, residual, key_value
+
+    def decode(
+        self,
+        hidden_states: Tensor,
+        residual: Optional[Tensor],
+        layer_past: Tensor,
+        attention_mask: Tensor,
+        key_length: int,
+    ) -> Tuple[Tensor, Tensor, Any]:
+
+        batch_size=hidden_states.size(0)
+
+        if residual is None:  # First layer
+            residual = hidden_states
+            hidden_states, *_ = dropout_add_ln_fwd(
+                hidden_states, residual, self.l1w, self.l1b,
+                None, None, None, None, 0.0, self.e1, 1.0, 0, None, False, False,
+            )
+        else:
+            hidden_states, residual, *_ = dropout_add_ln_fwd(
+                hidden_states, residual, self.l1w, self.l1b,
+                None, None, None, None, 0.0, self.e1, 1.0, 0, None, False, False,
+            )
+
+        query, key_value = addmm(self.aab, hidden_states, self.aaw).split(self.split0, dim=-1)
+        query_view = query.view(batch_size, self.num_heads, self.head_dim)
+
+        # Calculate dimensions and recover layer_past
+        padded_key_length = attention_mask.size(-1)
+        allocated_key_length = layer_past.size(-2)
+
+        # TODO: Allow pre-allocation with size > padded_key_length
+        if padded_key_length > allocated_key_length:
+            # Re-allocate kv cache and copy last value
+            allocated_kv_cache = empty(
+                [batch_size, padded_key_length, 2*self.hd],
+                dtype=key_value.dtype,
+                device=key_value.device,
+            )
+            allocated_kv_cache[:, : key_length - 1].copy_(
+                layer_past[:, : key_length - 1]
+            )
+            # Nans in `value` can propagate through the matrix multiplication,
+            # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
+            allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
+            layer_past = allocated_kv_cache
+
+        # Copy the new values.
+        layer_past[:, key_length - 1].copy_(key_value)
+
+        key, value = layer_past.split(self.split1, dim=-1)
+
+        # Assume we always upcast (optimized for fp16/bf16)
+        # TODO: Upcasting needed for bf16?
+        hidden_states = baddbmm(
+            empty(
+                (batch_size, self.num_heads, padded_key_length),
+                device=query.device,
+                dtype=query.dtype,
+            ),
+            query_view,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.ds,
+        )
+        hidden_states = upcast_masked_softmax(
+            hidden_states, attention_mask, self.mask_value, self.unscale
+        )
+
+        # TODO: Write attn output directly into query, avoids both allocation and view.
+        bmm(hidden_states.squeeze_(1), value, out=query_view)
+        # TODO: Reuse attn weight tensor for c_proj output?
+        hidden_states = addmm(self.apb, query, self.apw)
+
+        hidden_states, residual, *_ = dropout_add_ln_fwd(
+            hidden_states, residual, self.l2w, self.l2b,
+            None, None, None, None, 0.0, self.e2, 1.0, 0, None, False, False,
+        )
+        # TODO: Reuse attn weight tensor for c_fc output? (ok if padded_key_length>=4*head_dim, otherwise need to allocate a bigger one).
+        # TODO: Find an inplace and/or fused (with addmm) gelu kernel?
+        hidden_states = addmm(self.mpb, gelu(addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states)
+        return hidden_states, residual, layer_past
 
 
 class GPTBigCodePreTrainedModel(PreTrainedModel):
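A standalone sketch of the KV-cache update that `decode` performs above, assuming the same `[batch, allocated_length, 2 * head_dim]` layout; the helper name and toy sizes are made up for illustration:

import torch

def append_to_kv_cache(layer_past, key_value, key_length, padded_key_length, head_dim):
    # If the padded key length outgrows the allocated cache, grow it, copy the
    # old entries, and zero the tail of the value half so NaNs in uninitialized
    # memory cannot leak into the attention matmul.
    batch_size, allocated_key_length, _ = layer_past.shape
    if padded_key_length > allocated_key_length:
        grown = torch.empty(
            [batch_size, padded_key_length, 2 * head_dim],
            dtype=key_value.dtype, device=key_value.device,
        )
        grown[:, : key_length - 1].copy_(layer_past[:, : key_length - 1])
        grown[:, allocated_key_length:, head_dim:].zero_()
        layer_past = grown
    # Write the newest token's key/value at index key_length - 1.
    layer_past[:, key_length - 1].copy_(key_value)
    return layer_past

cache = torch.zeros(2, 8, 2 * 4)   # batch=2, allocated=8, head_dim=4
new_kv = torch.randn(2, 2 * 4)     # newest token's fused key/value
cache = append_to_kv_cache(cache, new_kv, key_length=9, padded_key_length=16, head_dim=4)
print(cache.shape)                 # torch.Size([2, 16, 8])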
@@ -396,9 +349,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, GPTBigCodeModel):
-            module.bias.fill_(True).tril_()
-        elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
+        if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
             # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
             # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
@@ -412,215 +363,188 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
                 ),
             )
             module.c_proj._is_hf_initialized = True
-        elif isinstance(module, nn.Linear):
+        elif isinstance(module, Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
+        elif isinstance(module, Embedding):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
+        elif isinstance(module, LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
 
 class GPTBigCodeModel(GPTBigCodePreTrainedModel):
-    def __init__(
-        self,
-        config,
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device("cpu"),
-    ):
+    # TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict)
+    def __init__(self, config: GPTBigCodeConfig, dtype: dtype):
         super().__init__(config)
 
-        self.wte = nn.Embedding(
-            config.vocab_size, config.hidden_size, dtype=dtype, device=device
+        self.wte = Embedding(
+            config.vocab_size, config.hidden_size, dtype=dtype, device="meta"
         )
-        self.wpe = nn.Embedding(
+        self.wpe = Embedding(
             config.max_position_embeddings,
             config.hidden_size,
             dtype=dtype,
-            device=device,
+            device="meta",
         )
 
-        self.h = nn.ModuleList(
+        self.h = ModuleList(
             [
-                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device)
+                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype)
                 for i in range(config.num_hidden_layers)
             ]
         )
-        self.ln_f = nn.LayerNorm(
+        self.ln_f = LayerNorm(
             config.hidden_size,
             dtype=dtype,
-            device=device,
+            device="meta",
             eps=config.layer_norm_epsilon,
         )
 
-        self.inference_runner_type = (
-            InferenceRunnerType.NO_RUNNER
-        )  # InferenceRunnerType(config.inference_runner)
-
-        self.flash_attention = True  # config.flash_attention
-
-        if self.flash_attention:
-            if flash_attn_unpadded_func is None:
-                raise RuntimeError(
-                    "Flash Attention requires `flash_attn` and `einops`. "
-                    "To install, run `pip install flash-attn einops`."
-                )
-
-        if self.inference_runner_type == InferenceRunnerType.NO_RUNNER:
-            self.inference_runner = None
-        else:
-            from .inference_runner import GPTBigCodeInferenceRunner
-
-            self.inference_runner = GPTBigCodeInferenceRunner(config, self)
-
-        # Causal mask
-        self.register_buffer(
-            "bias",
-            torch.empty(
-                (config.max_position_embeddings, config.max_position_embeddings),
-                dtype=torch.bool,
-                device=device,
-            ),
-        )
-
-    # @torch.profiler.record_function("GPTBigCodeModel._get_causal_mask")
-    def _get_causal_mask(self, padding_mask, query_length, key_length):
-        # Self-attention mask.
-        attention_mask = self.bias[
-            None, key_length - query_length : key_length, :key_length
-        ]
-
-        if padding_mask is not None:
-            attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
-                dtype=torch.bool, device=attention_mask.device
-            )
-        pad = -key_length % 8
-        if pad > 0:
-            attention_mask = torch.nn.functional.pad(
-                attention_mask, (0, pad), mode="constant", value=False
-            )
-
-        # (batch_size, query_length, n_heads, key_length)
-        return attention_mask.unsqueeze(2)
-
-    # @torch.profiler.record_function("GPTBigCodeModel.forward")
-    def forward(
-        self,
-        *,
-        input_ids: torch.Tensor,
-        past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: torch.Tensor,
-    ) -> Tuple:
-        if self.inference_runner is not None and past_key_values is not None:
-            if self.config.validate_runner_input:
-                assert past_key_values is not None
-            return self.inference_runner.forward(
-                input_ids, attention_mask, position_ids, past_key_values
-            )
-
-        batch_size, query_length = input_ids.shape
-
-        flash_attention = self.flash_attention and past_key_values is None
-        if past_key_values is None:
-            past_length = 0
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_length = past_key_values[0][1]
-
-        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
-
-        # TODO: Unpad earlier (input ids), support unpadded input?
-        if flash_attention:
-            (
-                hidden_states,
-                padding_index,
-                sequence_lengths,
-                max_sequence_length,
-            ) = unpad_input(hidden_states, attention_mask)
-            flash_params = (
-                sequence_lengths,
-                padding_index,
-                batch_size,
-                max_sequence_length,
-            )
-            attention_mask = None
-        else:
-            key_length = past_length + query_length
-            # Self-attention mask (padding + causal).
-            attention_mask = self._get_causal_mask(
-                attention_mask, query_length, key_length
-            )
-            flash_params = None
-
-        presents = []
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=attention_mask,
-                flash_params=flash_params,
-            )
-
-            hidden_states = outputs[0]
-            presents.append(outputs[1])
-
-        hidden_states = self.ln_f(hidden_states)
-
-        if flash_attention:
-            hidden_states = pad_input(
-                hidden_states, padding_index, batch_size, query_length
-            )
-
-        return hidden_states, presents
-
 
 class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
     def __init__(
-        self,
-        config,
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device("cpu"),
+        self, config, dtype: dtype, device: device = device("cuda")
     ):
         super().__init__(config)
-        meta = torch.device("meta")
-        self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
-        self.lm_head = nn.Linear(
-            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta
+        if device.type != "cuda":
+            raise NotImplementedError(f"Device {device} not supported")
+
+        self.transformer = GPTBigCodeModel(config, dtype=dtype)
+        self.lm_head = Linear(
+            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
+        )
+        self.mask_value = full(
+            (), finfo(float32).min, dtype=float32, device=device
         )
 
         self.to_empty(device=device)
 
         # Initialize weights and apply final processing
+        # TODO: Skip?
         self.post_init()
 
-    # @torch.profiler.record_function("GPTBigCodeForCausalLM.forward")
-    def forward(
+    def post_load_weights(self):
+        layer: GPTBigCodeBlock
+        for layer in self.transformer.h:
+            layer.post_load_weights(self.mask_value)
+        self.tw=self.transformer.wte.weight
+        self.pw=self.transformer.wpe.weight
+        self.hw=self.lm_head.weight.t_()
+        self.lw=self.transformer.ln_f.weight
+        self.lb=self.transformer.ln_f.bias
+        self.le=self.transformer.ln_f.eps
+
+    def prefill(
         self,
         *,
-        input_ids: torch.Tensor,
-        past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: torch.Tensor,
+        input_ids: Tensor,
+        attention_mask: Tensor = None,
+        position_ids: Tensor,
         predict_all_tokens: bool = True,
     ) -> Tuple:
-        hidden_states, presents = self.transformer(
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-        )
-        # with torch.profiler.record_function("GPTBigCodeForCausalLM.head"):
-        if not predict_all_tokens:
-            # We only care about the last token.
-            hidden_states = hidden_states[:, -1:]
-
-        lm_logits = self.lm_head(hidden_states)
-
-        return lm_logits, presents
+        batch_size, query_length = input_ids.shape
+
+        hidden_states = embedding(input_ids, self.tw).add_(embedding(position_ids, self.pw))
+
+        # Prefill (flash attn)
+        # TODO: Unpad earlier (input ids)?
+        hidden_states, padding_index, sequence_lengths, key_length = unpad_input(
+            hidden_states, attention_mask
+        )
+        assert key_length == query_length
+
+        residual = None
+        past_key_values = []
+        block: GPTBigCodeBlock
+        for block in self.transformer.h:
+            hidden_states, residual, key_value = block.prefill(
+                hidden_states,
+                residual,
+                sequence_lengths,
+                key_length,
+            )
+            past_key_values.append(
+                pad_input(key_value, padding_index, batch_size, query_length)
+            )
+
+        hidden_states, *_ = dropout_add_ln_fwd(
+            hidden_states, residual, self.lw, self.lb,
+            None, None, None, None, 0.0, self.le, 1.0, 0, None, False, False,
+        )
+
+        # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
+        del residual
+
+        if predict_all_tokens:
+            hidden_states = pad_input(
+                mm(hidden_states, self.hw), padding_index, batch_size, query_length
+            )
+        else:
+            # TODO: Index directly using cu_seqlens instead
+            hidden_states = mm(pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )[:, -1], self.hw).unsqueeze_(1)
+
+        return hidden_states, past_key_values
+
+    def decode(
+        self,
+        *,
+        input_ids: Tensor,
+        past_key_values: List[Tensor],
+        attention_mask: [Tensor],
+        position_ids: Tensor,
+        key_length: int,
+    ) -> Tuple:
+        hidden_states = embedding(input_ids, self.tw).add_(embedding(position_ids, self.pw))
+
+        residual = None
+        block: GPTBigCodeBlock
+        for i, (block, layer_past) in enumerate(
+            zip(self.transformer.h, past_key_values)
+        ):
+            hidden_states, residual, past_key_values[i] = block.decode(
+                hidden_states,
+                residual,
+                layer_past,
+                attention_mask,
+                key_length,
+            )
+        hidden_states, *_ = dropout_add_ln_fwd(
+            hidden_states, residual, self.lw, self.lb,
+            None, None, None, None, 0.0, self.le, 1.0, 0, None, False, False,
+        )
+        # TODO: Reuse residual?
+        return mm(hidden_states, self.hw).unsqueeze_(1), past_key_values
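`prefill` and `decode` above index the embedding tables through `torch.nn.functional.embedding` on cached weight tensors (`self.tw`, `self.pw`) instead of calling the `nn.Embedding` modules. A short sketch showing that the two are equivalent, on a toy table:

import torch
from torch.nn.functional import embedding

wte = torch.nn.Embedding(16, 8)
input_ids = torch.tensor([[1, 2, 3]])
assert torch.equal(embedding(input_ids, wte.weight), wte(input_ids))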
@ -0,0 +1,565 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch GPTBigCode model."""
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple, Any, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from dropout_layer_norm import dropout_add_ln_fwd
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
||||||
|
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.utils import logging
|
||||||
|
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
|
||||||
|
GPTBigCodeConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FastLayerNorm(nn.LayerNorm):
|
||||||
|
# TODO: Validate dimension
|
||||||
|
def forward(self, hidden_states, residual=None):
|
||||||
|
out, residual, *_ = dropout_add_ln_fwd(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
self.eps,
|
||||||
|
1.0,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
return out, residual
|
||||||
|
|
||||||
|
|
||||||
|
class FastLinear(nn.Linear):
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.addmm(self.bias, input, self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class FastLinearNoBias(nn.Linear):
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.mm(input, self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def upcast_masked_softmax(
|
||||||
|
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float
|
||||||
|
):
|
||||||
|
input_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32) * scale
|
||||||
|
x = torch.where(mask, x, mask_value)
|
||||||
|
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
|
||||||
|
x = torch.where(mask, x, mask_value)
|
||||||
|
x = torch.nn.functional.softmax(x, dim=-1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GPTBigCodeAttention(nn.Module):
|
||||||
|
mask_value: torch.Tensor
|
||||||
|
|
||||||
|
def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.embed_dim // self.num_heads
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
self.c_attn = FastLinear(
|
||||||
|
self.embed_dim,
|
||||||
|
self.embed_dim + 2 * self.head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device="meta",
|
||||||
|
)
|
||||||
|
self.c_proj = FastLinear(
|
||||||
|
self.embed_dim, self.embed_dim, dtype=dtype, device="meta"
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.profiler.record_function("GPTBigCodeAttention.prefill")
|
||||||
|
def prefill(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sequence_lengths,
|
||||||
|
key_length: int,
|
||||||
|
) -> Tuple[torch.Tensor, Any]:
|
||||||
|
hidden_shape = hidden_states.shape
|
||||||
|
query, key_value = self.c_attn.forward(hidden_states).split(
|
||||||
|
(self.embed_dim, 2 * self.head_dim), dim=-1
|
||||||
|
)
|
||||||
|
query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
|
||||||
|
key, value = (
|
||||||
|
key_value.unsqueeze(1)
|
||||||
|
.expand(hidden_shape[0], self.num_heads, 2 * self.head_dim)
|
||||||
|
.split((self.head_dim, self.head_dim), dim=-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# attn_output: (sum_seq_len, num_heads * head_dim)
|
||||||
|
hidden_states = flash_attn_unpadded_func(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
sequence_lengths,
|
||||||
|
sequence_lengths,
|
||||||
|
key_length,
|
||||||
|
key_length,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=self.head_dim**-0.5,
|
||||||
|
causal=True,
|
||||||
|
).view(hidden_shape)
|
||||||
|
|
||||||
|
hidden_states = self.c_proj.forward(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, key_value
|
||||||
|
|
||||||
|
@torch.profiler.record_function("GPTBigCodeAttention.decode")
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
layer_past: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
key_length: int,
|
||||||
|
) -> Tuple[torch.Tensor, Any]:
|
||||||
|
query, key_value = self.c_attn.forward(hidden_states).split(
|
||||||
|
(self.embed_dim, 2 * self.head_dim), dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate dimensions and recover layer_past
|
||||||
|
padded_key_length = attention_mask.size(-1)
|
||||||
|
allocated_key_length = layer_past.size(-2)
|
||||||
|
|
||||||
|
# TODO: Allow pre-allocation with size > padded_key_length
|
||||||
|
if padded_key_length > allocated_key_length:
|
||||||
|
# Re-allocate kv cache and copy last value
|
||||||
|
allocated_kv_cache = torch.empty(
|
||||||
|
[batch_size, padded_key_length, 2 * self.head_dim],
|
||||||
|
dtype=key_value.dtype,
|
||||||
|
device=key_value.device,
|
||||||
|
)
|
||||||
|
allocated_kv_cache[:, : key_length - 1].copy_(
|
||||||
|
layer_past[:, : key_length - 1]
|
||||||
|
)
|
||||||
|
# Nans in `value` can propagate through the matrix multiplication,
|
||||||
|
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
|
||||||
|
allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
|
||||||
|
layer_past = allocated_kv_cache
|
||||||
|
|
||||||
|
# Copy the new values.
|
||||||
|
layer_past[:, key_length - 1].copy_(key_value)
|
||||||
|
|
||||||
|
key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
|
||||||
|
|
||||||
|
# TODO: Upcasting needed for bf16?
|
||||||
|
upcast = query.dtype != torch.float32
|
||||||
|
unscale = self.layer_idx + 1 if upcast else 1
|
||||||
|
scale_factor = unscale**-1 / self.head_dim**0.5
|
||||||
|
|
||||||
|
# TODO: No need to unsqueeze?
|
||||||
|
hidden_states = torch.baddbmm(
|
||||||
|
torch.empty(
|
||||||
|
(batch_size, self.num_heads, padded_key_length),
|
||||||
|
device=query.device,
|
||||||
|
dtype=query.dtype,
|
||||||
|
),
|
||||||
|
query.view(batch_size, self.num_heads, self.head_dim),
|
||||||
|
key.transpose(-1, -2),
|
||||||
|
beta=0,
|
||||||
|
alpha=scale_factor,
|
||||||
|
).unsqueeze_(1)
|
||||||
|
|
||||||
|
if upcast:
|
||||||
|
hidden_states = upcast_masked_softmax(
|
||||||
|
hidden_states, attention_mask, self.mask_value, unscale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = masked_softmax(
|
||||||
|
hidden_states, attention_mask, self.mask_value
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)
|
||||||
|
|
||||||
|
hidden_states = self.c_proj.forward(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, layer_past
|
||||||
|
|
||||||
|
|
class GPTBigCodeMLP(nn.Module):
    # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
    def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
        super().__init__()
        embed_dim = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
        self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta")
        self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta")


class GPTBigCodeBlock(nn.Module):
    def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
        super().__init__()
        self.ln_1 = FastLayerNorm(
            config.hidden_size,
            eps=config.layer_norm_epsilon,
            dtype=dtype,
            device="meta",
        )
        self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
        self.ln_2 = FastLayerNorm(
            config.hidden_size,
            eps=config.layer_norm_epsilon,
            dtype=dtype,
            device="meta",
        )
        self.mlp = GPTBigCodeMLP(config, dtype=dtype)

    def post_load_weights(self, mask_value):
        self.attn.mask_value = mask_value
        self.attn.c_attn.weight.t_()
        self.attn.c_proj.weight.t_()

        self.l1w = self.ln_1.weight
        self.l1b = self.ln_1.bias
        self.e1 = self.ln_1.eps
        self.l2w = self.ln_2.weight
        self.l2b = self.ln_2.bias
        self.e2 = self.ln_2.eps
        self.mfw = self.mlp.c_fc.weight.t_()
        self.mfb = self.mlp.c_fc.bias
        self.mpw = self.mlp.c_proj.weight.t_()
        self.mpb = self.mlp.c_proj.bias

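Why post_load_weights transposes the linear weights in place: the decode path below calls torch.addmm(bias, x, w) directly, which computes bias + x @ w, whereas nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T + bias. A quick sanity check of that equivalence (illustrative only, names are not from the commit):

import torch

x = torch.randn(4, 16)
linear = torch.nn.Linear(16, 32)
w_t = linear.weight.detach().t().contiguous()           # (in, out), as cached by post_load_weights
out_addmm = torch.addmm(linear.bias.detach(), x, w_t)   # bias + x @ w_t
out_linear = linear(x)
print(torch.allclose(out_addmm, out_linear, atol=1e-6))  # True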
    @torch.profiler.record_function("GPTBigCodeBlock.prefill")
    def prefill(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        sequence_lengths,
        key_length: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, Any]:
        hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
        hidden_states, present = self.attn.prefill(
            hidden_states,
            sequence_lengths,
            key_length,
        )
        hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
        hidden_states = self.mlp.c_proj.forward(
            nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")
        )
        return hidden_states, residual, present

    @torch.profiler.record_function("GPTBigCodeBlock.decode")
    def decode(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        layer_past: torch.Tensor,
        attention_mask: torch.Tensor,
        batch_size: int,
        key_length: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, Any]:
        if residual is None:
            residual = hidden_states
            hidden_states, *_ = dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.l1w,
                self.l1b,
                None,
                None,
                None,
                None,
                0.0,
                self.e1,
                1.0,
                0,
                None,
                False,
                False,
            )
        else:
            hidden_states, residual, *_ = dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.l1w,
                self.l1b,
                None,
                None,
                None,
                None,
                0.0,
                self.e1,
                1.0,
                0,
                None,
                False,
                False,
            )

        hidden_states, present = self.attn.decode(
            hidden_states,
            layer_past,
            attention_mask,
            batch_size,
            key_length,
        )
        hidden_states, residual, *_ = dropout_add_ln_fwd(
            hidden_states,
            residual,
            self.l2w,
            self.l2b,
            None,
            None,
            None,
            None,
            0.0,
            self.e2,
            1.0,
            0,
            None,
            False,
            False,
        )
        hidden_states = torch.addmm(self.mpb, nn.functional.gelu(torch.addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states)
        return hidden_states, residual, present

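The decode path leans on two fused operations: dropout_add_ln_fwd (assumed here to be a fused residual-add plus LayerNorm forward, called with dropout probability 0.0) and the torch.addmm chain that runs the MLP as two GEMMs on the pre-transposed weights. A minimal un-fused reference under those assumptions (not the extension's actual signature or kernel):

import torch
import torch.nn.functional as F

def add_ln_ref(hidden_states, residual, weight, bias, eps):
    # Assumption: residual add followed by LayerNorm, returning both the
    # normalized output and the updated residual stream.
    residual = hidden_states + residual
    normed = F.layer_norm(residual, residual.shape[-1:], weight, bias, eps)
    return normed, residual

def mlp_ref(hidden_states, mfw, mfb, mpw, mpb):
    # Same math as the fused torch.addmm(...) line above, with mfw/mpw already transposed.
    return torch.addmm(mpb, F.gelu(torch.addmm(mfb, hidden_states, mfw), approximate="tanh"), mpw)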
class GPTBigCodePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTBigCodeConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = False
    _no_split_modules = ["GPTBigCodeBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            module.c_proj.weight.data.normal_(
                mean=0.0,
                std=(
                    self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
                ),
            )
            module.c_proj._is_hf_initialized = True
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

class GPTBigCodeModel(GPTBigCodePreTrainedModel):
    # TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict)
    def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
        super().__init__(config)

        self.wte = nn.Embedding(
            config.vocab_size, config.hidden_size, dtype=dtype, device="meta"
        )
        self.wpe = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            dtype=dtype,
            device="meta",
        )

        self.h = nn.ModuleList(
            [
                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = FastLayerNorm(
            config.hidden_size,
            dtype=dtype,
            device="meta",
            eps=config.layer_norm_epsilon,
        )

class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
    def __init__(
        self, config, dtype: torch.dtype, device: torch.device = torch.device("cuda")
    ):
        super().__init__(config)
        if device.type != "cuda":
            raise NotImplementedError(f"Device {device} not supported")

        self.transformer = GPTBigCodeModel(config, dtype=dtype)
        self.lm_head = FastLinearNoBias(
            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
        )

        # Causal mask
        self.causal_mask = torch.ones(
            (config.max_position_embeddings, config.max_position_embeddings),
            dtype=torch.bool,
            device=device,
        ).tril_()
        self.mask_value = torch.full(
            (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device
        )

        self.to_empty(device=device)

        # Initialize weights and apply final processing
        # TODO: Skip?
        self.post_init()

    def prefill(
        self,
        *,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.Tensor,
        predict_all_tokens: bool = True,
    ) -> Tuple:
        batch_size, query_length = input_ids.shape

        hidden_states = self.transformer.wte.forward(
            input_ids
        ) + self.transformer.wpe.forward(position_ids)

        # Prefill (flash attn)
        # TODO: Unpad earlier (input ids)?
        hidden_states, padding_index, sequence_lengths, key_length = unpad_input(
            hidden_states, attention_mask
        )
        assert key_length == query_length

        residual = None
        past_key_values = []
        block: GPTBigCodeBlock
        for block in self.transformer.h:
            hidden_states, residual, key_value = block.prefill(
                hidden_states,
                residual=residual,
                sequence_lengths=sequence_lengths,
                key_length=query_length,
            )
            past_key_values.append(
                pad_input(key_value, padding_index, batch_size, query_length)
            )

        hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)

        # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
        del residual

        if predict_all_tokens:
            hidden_states = self.lm_head.forward(hidden_states)
            hidden_states = pad_input(
                hidden_states, padding_index, batch_size, query_length
            )
        else:
            # TODO: Index directly instead
            hidden_states = pad_input(
                hidden_states, padding_index, batch_size, query_length
            )[:, -1]
            hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)

        return hidden_states, past_key_values

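Prefill flattens the padded batch into one packed token stream before running flash attention, and pads it back before storing the per-layer key/value tensors. A simplified illustration of that unpad/pad round trip (the real helpers come from the flash-attn padding utilities and also return cumulative sequence lengths; the tensors below are made up):

import torch

hidden = torch.arange(2 * 3 * 4, dtype=torch.float32).view(2, 3, 4)  # (batch, seq, hidden)
mask = torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.bool)        # left-padded batch

indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
unpadded = hidden.view(-1, 4)[indices]            # (total_tokens, hidden): only real tokens

repadded = torch.zeros_like(hidden).view(-1, 4)
repadded[indices] = unpadded                      # scatter back into the padded layout
repadded = repadded.view(2, 3, 4)
print(torch.equal(repadded[mask], hidden[mask]))  # True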
    def post_load_weights(self):
        layer: GPTBigCodeBlock
        for layer in self.transformer.h:
            layer.post_load_weights(self.mask_value)
        self.lm_head.weight.t_()

    def decode(
        self,
        *,
        input_ids: torch.Tensor,
        past_key_values: List[torch.Tensor],
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        key_length: int,
    ) -> Tuple:
        batch_size, query_length = input_ids.shape
        assert query_length == 1

        hidden_states = self.transformer.wte.forward(
            input_ids
        ) + self.transformer.wpe.forward(position_ids)

        # Standardize shape to (batch_size, hidden_size)
        hidden_states.squeeze_(1)

        # Self-attention mask (padding + causal).
        # TODO: Avoid unsqueeze
        attention_mask = self.causal_mask[
            None, key_length - 1 : key_length, : attention_mask.size(-1)
        ] * attention_mask.unsqueeze(1)
        attention_mask.unsqueeze_(2)

        residual = None
        block: GPTBigCodeBlock
        for i, (block, layer_past) in enumerate(
            zip(self.transformer.h, past_key_values)
        ):
            hidden_states, residual, past_key_values[i] = block.decode(
                hidden_states,
                residual=residual,
                layer_past=layer_past,
                attention_mask=attention_mask,
                batch_size=batch_size,
                key_length=key_length,
            )

        hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
        hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)

        return hidden_states, past_key_values

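The two entry points are meant to be combined into a generation loop: prefill once over the padded prompt batch, then call decode one token at a time while growing the attention mask and key length. A hypothetical driver sketch (the function, sampling, and position-id handling here are placeholders, not part of this commit):

import torch

def greedy_generate(model, input_ids, attention_mask, position_ids, max_new_tokens):
    logits, past_key_values = model.prefill(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        predict_all_tokens=False,
    )
    prompt_length = input_ids.size(1)
    key_length = prompt_length
    generated = []
    for _ in range(max_new_tokens):
        next_ids = logits[:, -1].argmax(-1, keepdim=True)  # (batch, 1)
        generated.append(next_ids)
        key_length += 1
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones(attention_mask.size(0), 1)], dim=1
        )
        logits, past_key_values = model.decode(
            input_ids=next_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids[:, -1:] + (key_length - prompt_length),
            key_length=key_length,
        )
    return torch.cat(generated, dim=1)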
@ -451,11 +451,17 @@ class FlashCausalLM(Model):

     def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, cache_dtype: Optional[torch.dtype]):
         diff = max_input_length - max(batch.input_lengths)
+        fill_value=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
         for i in range(len(batch)):
             batch.input_lengths[i] += diff
-            batch.prefix_offsets[i] = 0
-            batch.read_offsets[i] = 0
-            batch.all_input_ids[i] = batch.all_input_ids[i] + [self.tokenizer.pad_token_id] * diff if diff >= 0 else batch.all_input_ids[i][:diff]
+            for _ in range(diff):
+                batch.all_input_ids[i].append(fill_value)
+            # For some reason just resetting the offsets makes things way slower.
+            _, batch.prefix_offsets[i], batch.read_offsets[i] = self.decode_token(
+                batch.all_input_ids[i],
+                batch.prefix_offsets[i],
+                batch.read_offsets[i],
+            )
             # TODO: Bug!?!
             batch.stopping_criterias[i].current_tokens += diff

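fast_forward artificially advances a batch by `diff` tokens (useful for benchmarking): each sequence is extended with filler tokens and the incremental detokenizer is re-run so the prefix/read offsets stay consistent. A toy illustration of the filler-token bookkeeping only (all values below are made up):

# Every sequence is extended by `diff` filler tokens so the batch behaves as if
# `diff` extra decode steps had already happened.
pad_token_id, eos_token_id = None, 0
fill_value = pad_token_id or eos_token_id   # falls back to EOS when there is no PAD token

all_input_ids = [[5, 6, 7], [8, 9]]
input_lengths = [3, 2]
diff = 4
for i in range(len(all_input_ids)):
    input_lengths[i] += diff
    all_input_ids[i].extend([fill_value] * diff)

print(input_lengths)   # [7, 6]
print(all_input_ids)   # [[5, 6, 7, 0, 0, 0, 0], [8, 9, 0, 0, 0, 0]]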
@ -473,19 +479,19 @@ class FlashCausalLM(Model):
             batch.past_key_values=None
         else:
             assert len(batch.all_input_ids_tensor)>0, "Must run prefill first"
-            batch.input_ids.fill_(self.tokenizer.pad_token_id)
+            batch.input_ids.fill_(fill_value)
             batch.position_ids += diff
             batch.cu_seqlens += diff * batch.cu_seqlens_q
             # TODO: Bug!?!
-            batch.max_seqlen += batch.max_seqlen + diff*len(batch)
+            batch.max_seqlen += diff*len(batch)

             for i in range(len(batch)):
-                batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(self.tokenizer.pad_token_id)
+                batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(fill_value)

             batch.past_key_values = batch.past_key_values = torch.randn(
                 (
                     batch.past_key_values.shape[0],
-                    batch.past_key_values.shape[1] + len(batch.requests),
+                    batch.past_key_values.shape[1] + diff*len(batch.requests),
                     *batch.past_key_values.shape[2:],
                 ), device=batch.past_key_values.device, dtype= cache_dtype
             )
@ -13,8 +13,15 @@ from text_generation_server.models.vectorized_causal_lm import (
     VectorizedCausalLMBatch,
 )
 from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import (
-    GPTBigCodeForCausalLM,
+    GPTBigCodeForCausalLM as GPTBigCode2ForCausalLM,
 )
+from text_generation_server.models.custom_modeling.gpt_bigcode3_modeling import (
+    GPTBigCodeForCausalLM as GPTBigCode3ForCausalLM,
+)
+from text_generation_server.models.custom_modeling.gpt_bigcode4_modeling import (
+    GPTBigCodeForCausalLM as GPTBigCode4ForCausalLM,
+)
+from transformers.modeling_utils import PreTrainedModel

 tracer = trace.get_tracer(__name__)

@ -96,8 +103,9 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
         return len(self.requests)


-class Bigcode2CausalLM(VectorizedCausalLM):
-    model: GPTBigCodeForCausalLM
+class Bigcode2CausalLMBase(VectorizedCausalLM):
+    #model: GPTBigCode2ForCausalLM
+    _model_class:Type[PreTrainedModel]

     def __init__(
         self,
@ -118,7 +126,7 @@ class Bigcode2CausalLM(VectorizedCausalLM):
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
-        model = GPTBigCodeForCausalLM.from_pretrained(
+        model = self._model_class.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
@ -163,18 +171,18 @@ class Bigcode2CausalLM(VectorizedCausalLM):
         padded_key_length = (
             key_length + -key_length % batch.pad_key_length_to_multiple
         )
-        input_ids = batch.input_ids[:, key_length - 1 : key_length]
+        input_ids = batch.input_ids[:, key_length - 1]
         # Model Forward
         logits, batch.past_key_values = self.model.decode(
             input_ids=input_ids,
-            attention_mask=batch.attention_mask[:, :padded_key_length],
+            attention_mask=batch.attention_mask[:, None, :padded_key_length],
-            position_ids=batch.position_ids[:, key_length - 1 : key_length],
+            position_ids=batch.position_ids[:, key_length - 1],
             past_key_values=batch.past_key_values,
             key_length=key_length,
         )

         next_token_ids, logprobs = batch.next_token_chooser(
-            input_ids, logits, batch.details
+            input_ids.unsqueeze(1), logits.unsqueeze(1), batch.details
         )
         # Update batch
         # TODO: Why do we need all input ids?
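In the hunk above, key_length is rounded up to the next multiple of pad_key_length_to_multiple with the `key_length + -key_length % m` idiom, which keeps the attention shapes stable across several decode steps. A quick check of that arithmetic (illustrative helper, not part of the commit):

def round_up(n, multiple):
    return n + -n % multiple

print(round_up(37, 8))   # 40
print(round_up(40, 8))   # 40  (already a multiple, unchanged)
print(round_up(41, 8))   # 48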
@ -201,3 +209,14 @@ class Bigcode2CausalLM(VectorizedCausalLM):
             )
             for _ in range(self.model.config.n_layer)
         ]
+
+
+class Bigcode2CausalLM(Bigcode2CausalLMBase):
+    _model_class=GPTBigCode2ForCausalLM
+
+
+class Bigcode3CausalLM(Bigcode2CausalLMBase):
+    _model_class=GPTBigCode3ForCausalLM
+
+
+class Bigcode4CausalLM(Bigcode2CausalLMBase):
+    _model_class = GPTBigCode4ForCausalLM