Fixes and format

Repository: https://github.com/huggingface/text-generation-inference
Commit: a515fbde4c
Parent: 0921fe6a2a
@@ -13,14 +13,13 @@
 # limitations under the License.
 """PyTorch GPTBigCode model."""
 import math
-from typing import Optional, Tuple, Any, Union, List
-from enum import IntEnum
+from typing import Optional, Tuple, Any, List

 import torch
 import torch.utils.checkpoint
 from torch import nn

-from dropout_layer_norm import dropout_layer_norm
+from dropout_layer_norm import dropout_add_ln_fwd
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_func

@@ -30,10 +29,11 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
     GPTBigCodeConfig,
 )


 class FastLayerNorm(nn.LayerNorm):
     # TODO: Validate dimension
     def forward(self, hidden_states, residual=None):
-        return dropout_layer_norm.dropout_add_ln_fwd(
+        out, residual, *_ = dropout_add_ln_fwd(
             hidden_states,
             residual,
             self.weight,
@@ -50,18 +50,24 @@ class FastLayerNorm(nn.LayerNorm):
             False,
             False,
         )
+        if residual is None:
+            residual = hidden_states
+        return out, residual


 class FastLinear(nn.Linear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.addmm(self.bias, input, self.weight)


 class FastLinearNoBias(nn.Linear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.mm(input, self.weight)


 logger = logging.get_logger(__name__)


 @torch.jit.script
 def upcast_masked_softmax(
     x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float
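For readers without the `dropout_layer_norm` CUDA extension, here is a minimal pure-PyTorch sketch of the semantics the new `FastLayerNorm.forward` appears to rely on: add the incoming residual to the hidden states, layer-normalize the sum, and return both the normalized output and the pre-norm sum so it can serve as the next residual. The class and names below are illustrative, not the extension's API.

import torch
from typing import Optional
from torch import nn


class ReferenceAddLayerNorm(nn.LayerNorm):
    # Pure-PyTorch stand-in for the fused add + layer-norm path (assumed semantics).
    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            hidden_states = hidden_states + residual  # the "add" part of the fused kernel
        out = super().forward(hidden_states)          # the "layer norm" part
        # The un-normalized sum becomes the residual stream for the next block.
        return out, hidden_states


ln = ReferenceAddLayerNorm(8)
out, residual = ln(torch.randn(4, 8))            # first call: residual is just the input
out2, residual2 = ln(torch.randn(4, 8), residual)  # later calls thread the residual through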
@@ -72,24 +78,33 @@ def upcast_masked_softmax(
     x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
     return x


 @torch.jit.script
 def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
     x = torch.where(mask, x, mask_value)
     x = torch.nn.functional.softmax(x, dim=-1)
     return x


 class GPTBigCodeAttention(nn.Module):
+    mask_value: torch.Tensor
+
     def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
         super().__init__()
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         self.layer_idx = layer_idx
-        # Note: Does not support module dtype conversion.
-        self.register_buffer("mask_value", torch.empty((), dtype=torch.float32, device="meta"))

-        self.c_attn = FastLinear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device="meta")
-        self.c_proj = FastLinear(self.embed_dim, self.embed_dim, dtype=dtype, device="meta")
+        self.c_attn = FastLinear(
+            self.embed_dim,
+            self.embed_dim + 2 * self.head_dim,
+            dtype=dtype,
+            device="meta",
+        )
+        self.c_proj = FastLinear(
+            self.embed_dim, self.embed_dim, dtype=dtype, device="meta"
+        )

     def prefill(
         self,
@@ -98,12 +113,18 @@ class GPTBigCodeAttention(nn.Module):
         key_length: int,
     ) -> Tuple[torch.Tensor, Any]:
         hidden_shape = hidden_states.shape
-        query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn.forward(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )
         query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
-        key, value = key_value.unsqueeze(1).expand(hidden_shape[0], self.num_heads, 2*self.head_dim).split((self.head_dim, self.head_dim), dim=-1)
+        key, value = (
+            key_value.unsqueeze(1)
+            .expand(hidden_shape[0], self.num_heads, 2 * self.head_dim)
+            .split((self.head_dim, self.head_dim), dim=-1)
+        )

         # attn_output: (sum_seq_len, num_heads * head_dim)
-        attn_output = flash_attn_unpadded_func(
+        hidden_states = flash_attn_unpadded_func(
             query,
             key,
             value,
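The prefill path packs variable-length sequences into one token-major tensor before calling flash_attn_unpadded_func and re-pads afterwards. A rough sketch of that unpad/pad bookkeeping in plain PyTorch follows, under assumed shapes and with hypothetical helper names (the real helpers are flash_attn's unpad_input / pad_input):

import torch


def unpad(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) bool
    seq_lengths = attention_mask.sum(dim=-1)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lengths, 0), (1, 0))
    packed = hidden_states.flatten(0, 1)[indices]  # (total_tokens, hidden)
    return packed, indices, cu_seqlens.to(torch.int32), int(seq_lengths.max())


def pad(packed: torch.Tensor, indices: torch.Tensor, batch: int, seq_len: int):
    out = packed.new_zeros(batch * seq_len, *packed.shape[1:])
    out[indices] = packed
    return out.view(batch, seq_len, *packed.shape[1:])


mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool)
h = torch.randn(2, 3, 4)
packed, idx, cu, max_len = unpad(h, mask)  # packed: (5, 4); cu: [0, 2, 5]
restored = pad(packed, idx, 2, 3)          # zeros where the mask was False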
@@ -116,9 +137,9 @@ class GPTBigCodeAttention(nn.Module):
             causal=True,
         ).view(hidden_shape)

-        attn_output = self.c_proj.forward(attn_output)
+        hidden_states = self.c_proj.forward(hidden_states)

-        return attn_output, key_value
+        return hidden_states, key_value

     def decode(
         self,
@@ -128,7 +149,9 @@ class GPTBigCodeAttention(nn.Module):
         batch_size: int,
         key_length: int,
     ) -> Tuple[torch.Tensor, Any]:
-        query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn.forward(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )

         # Calculate dimensions and recover layer_past
         padded_key_length = attention_mask.size(-1)
@@ -142,14 +165,16 @@ class GPTBigCodeAttention(nn.Module):
                 dtype=key_value.dtype,
                 device=key_value.device,
             )
-            allocated_kv_cache[:, :key_length-1].copy_(layer_past[:, :key_length-1])
+            allocated_kv_cache[:, : key_length - 1].copy_(
+                layer_past[:, : key_length - 1]
+            )
             # Nans in `value` can propagate through the matrix multiplication,
             # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
             allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
             layer_past = allocated_kv_cache

         # Copy the new values.
-        layer_past[:, key_length-1:key_length].copy_(key_value)
+        layer_past[:, key_length - 1].copy_(key_value)

         key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1)

@@ -158,21 +183,27 @@ class GPTBigCodeAttention(nn.Module):
         unscale = self.layer_idx + 1 if upcast else 1
         scale_factor = unscale**-1 / self.head_dim**0.5

+        # TODO: No need to unsqueeze?
         hidden_states = torch.baddbmm(
-            torch.empty((batch_size, self.num_heads, padded_key_length), device=query.device, dtype=query.dtype),
+            torch.empty(
+                (batch_size, self.num_heads, padded_key_length),
+                device=query.device,
+                dtype=query.dtype,
+            ),
             query.view(batch_size, self.num_heads, self.head_dim),
             key.transpose(-1, -2),
             beta=0,
-            alpha=scale_factor
+            alpha=scale_factor,
         ).unsqueeze_(1)

-        if self.mask_value is None or self.mask_value.device != hidden_states.device:
-            self.mask_value = torch.full([], torch.finfo(torch.float32).min, dtype=torch.float32, device=hidden_states.device)
-
         if upcast:
-            hidden_states = upcast_masked_softmax(hidden_states, attention_mask, self.mask_value, unscale)
+            hidden_states = upcast_masked_softmax(
+                hidden_states, attention_mask, self.mask_value, unscale
+            )
         else:
-            hidden_states = masked_softmax(hidden_states, attention_mask, self.mask_value)
+            hidden_states = masked_softmax(
+                hidden_states, attention_mask, self.mask_value
+            )

         hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)

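For reference, the decode path above is single-token multi-query attention over a padded KV cache: one new query per sequence, a key/value head shared across all query heads, scores written with torch.baddbmm (beta=0, so the first argument is only a shape/dtype scratch buffer), then a masked softmax and torch.bmm. A minimal runnable sketch with made-up sizes:

import torch

batch, heads, head_dim, key_len = 2, 4, 8, 16
scale = head_dim ** -0.5

query = torch.randn(batch, heads, head_dim)   # one token per sequence
key = torch.randn(batch, key_len, head_dim)   # single shared KV head (multi-query)
value = torch.randn(batch, key_len, head_dim)
mask = torch.ones(batch, heads, key_len, dtype=torch.bool)
mask[:, :, -3:] = False                       # pretend the last 3 cache slots are padding

# (batch, heads, head_dim) @ (batch, head_dim, key_len) -> (batch, heads, key_len)
scores = torch.baddbmm(
    torch.empty(batch, heads, key_len),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
probs = torch.softmax(scores, dim=-1)
attn_output = torch.bmm(probs, value)         # (batch, heads, head_dim)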
@@ -180,6 +211,7 @@ class GPTBigCodeAttention(nn.Module):

         return hidden_states, layer_past

+
 class GPTBigCodeMLP(nn.Module):
     # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
     def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
@@ -189,12 +221,23 @@ class GPTBigCodeMLP(nn.Module):
         self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta")
         self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta")


 class GPTBigCodeBlock(nn.Module):
     def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
         super().__init__()
-        self.ln_1 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
+        self.ln_1 = FastLayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device="meta",
+        )
         self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
-        self.ln_2 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
+        self.ln_2 = FastLayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device="meta",
+        )
         self.mlp = GPTBigCodeMLP(config, dtype=dtype)

     def prefill(
@@ -211,7 +254,9 @@ class GPTBigCodeBlock(nn.Module):
             key_length=key_length,
         )
         hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
-        hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
+        hidden_states = self.mlp.c_proj.forward(
+            nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")
+        )
         return hidden_states, residual, present

     def decode(
@@ -232,7 +277,9 @@ class GPTBigCodeBlock(nn.Module):
             key_length=key_length,
         )
         hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
-        hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
+        hidden_states = self.mlp.c_proj.forward(
+            nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")
+        )
         return hidden_states, residual, present

@@ -252,11 +299,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, GPTBigCodeModel):
-            module.bias.fill_(True).tril_()
-        elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
-            if isinstance(module, GPTBigCodeAttention):
-                module.mask_value.fill_(torch.finfo(torch.float32).min)
+        if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
             # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
             # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
@@ -264,7 +307,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
             #
             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
             module.c_proj.weight.data.normal_(
-                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+                mean=0.0,
+                std=(
+                    self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
+                ),
             )
             module.c_proj._is_hf_initialized = True
         elif isinstance(module, nn.Linear):
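The hunk above only reflows the call; the value it computes is the GPT-2 scaled-init standard deviation, std = initializer_range / sqrt(2 * n_layer). With example values (0.02 and 24 layers, both hypothetical here):

import math

initializer_range, n_layer = 0.02, 24
std = initializer_range / math.sqrt(2 * n_layer)
print(round(std, 5))  # 0.00289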
@@ -287,39 +333,59 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
     def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
         super().__init__(config)

-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device="meta")
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device="meta")
-        self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) for i in range(config.num_hidden_layers)])
-        self.ln_f = FastLayerNorm(config.hidden_size, dtype=dtype, device="meta", eps=config.layer_norm_epsilon)
-
-        # Causal mask
-        self.register_buffer(
-            "causal_mask", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device="meta")
-        )
+        self.wte = nn.Embedding(
+            config.vocab_size, config.hidden_size, dtype=dtype, device="meta"
+        )
+        self.wpe = nn.Embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            dtype=dtype,
+            device="meta",
+        )
+        self.h = nn.ModuleList(
+            [
+                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = FastLayerNorm(
+            config.hidden_size,
+            dtype=dtype,
+            device="meta",
+            eps=config.layer_norm_epsilon,
+        )


 class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
-    pad_key_length_to_multiple=8
-
-    def __init__(self, config, dtype:torch.dtype, device:torch.device=torch.device("cuda")):
+    def __init__(
+        self, config, dtype: torch.dtype, device: torch.device = torch.device("cuda")
+    ):
         super().__init__(config)
         if device.type != "cuda":
             raise NotImplementedError(f"Device {device} not supported")

         self.transformer = GPTBigCodeModel(config, dtype=dtype)
-        self.lm_head = FastLinearNoBias(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta")
+        self.lm_head = FastLinearNoBias(
+            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
+        )
+
+        # Causal mask
+        self.causal_mask = torch.ones(
+            (config.max_position_embeddings, config.max_position_embeddings),
+            dtype=torch.bool,
+            device=device,
+        ).tril_()
+        self.mask_value = torch.full(
+            (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device
+        )

         self.to_empty(device=device)

-        self._apply=self._apply_not_allowed
-
         # Initialize weights and apply final processing
+        # TODO: Skip?
         self.post_init()

-    def _apply_not_allowed(self):
-        # Dtype or device conversion would break the model.
-        raise NotImplementedError("Device or dtype conversion not supported!")
-
     def prefill(
         self,
         *,
@@ -330,11 +396,15 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
     ) -> Tuple:
         batch_size, query_length = input_ids.shape

-        hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
+        hidden_states = self.transformer.wte.forward(
+            input_ids
+        ) + self.transformer.wpe.forward(position_ids)

         # Prefill (flash attn)
         # TODO: Unpad earlier (input ids)?
-        hidden_states, padding_index, sequence_lengths, key_length = unpad_input(hidden_states, attention_mask)
+        hidden_states, padding_index, sequence_lengths, key_length = unpad_input(
+            hidden_states, attention_mask
+        )
         assert key_length == query_length

         residual = None
@@ -347,23 +417,39 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
                 sequence_lengths=sequence_lengths,
                 key_length=query_length,
             )
-            past_key_values.append(pad_input(key_value, padding_index, batch_size, query_length))
+            past_key_values.append(
+                pad_input(key_value, padding_index, batch_size, query_length)
+            )

-        hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
+        hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)

         # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
         del residual

         if predict_all_tokens:
             hidden_states = self.lm_head.forward(hidden_states)
-            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
+            hidden_states = pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )
         else:
             # TODO: Index directly instead
-            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)[:, -1]
+            hidden_states = pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )[:, -1]
             hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)

         return hidden_states, past_key_values

+    def post_load_weights(self):
+        layer: GPTBigCodeBlock
+        for layer in self.transformer.h:
+            layer.attn.mask_value = self.mask_value
+            layer.attn.c_attn.weight.t_()
+            layer.attn.c_proj.weight.t_()
+            layer.mlp.c_fc.weight.t_()
+            layer.mlp.c_proj.weight.t_()
+        self.lm_head.weight.t_()
+
     def decode(
         self,
         *,
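The new post_load_weights transposes every linear weight once at load time. The likely reason: nn.Linear stores its weight as (out_features, in_features) and computes x @ W.T + b, while FastLinear/FastLinearNoBias above call torch.addmm / torch.mm with the input on the left, so the stored weight must already be (in_features, out_features). A quick self-contained check of that equivalence (illustrative, not part of the commit):

import torch
from torch import nn

linear = nn.Linear(8, 16)
x = torch.randn(4, 8)

reference = linear(x)
# Done once, like weight.t_() in post_load_weights.
transposed_weight = linear.weight.detach().t().contiguous()
fast = torch.addmm(linear.bias.detach(), x, transposed_weight)

print(torch.allclose(reference, fast, atol=1e-6))  # True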
@@ -373,24 +459,28 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         position_ids: torch.Tensor,
         key_length: int,
     ) -> Tuple:
-
         batch_size, query_length = input_ids.shape
         assert query_length == 1

-        hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
+        hidden_states = self.transformer.wte.forward(
+            input_ids
+        ) + self.transformer.wpe.forward(position_ids)

         # Standardize shape to (batch_size, hidden_size)
         hidden_states.squeeze_(1)

         # Self-attention mask (padding + causal).
         # TODO: Avoid unsqueeze
-        attention_mask = self.transformer.causal_mask[None, key_length - 1: key_length,
-                                                      :key_length] * attention_mask.unsqueeze(1)
+        attention_mask = self.causal_mask[
+            None, key_length - 1 : key_length, : attention_mask.size(-1)
+        ] * attention_mask.unsqueeze(1)
         attention_mask.unsqueeze_(2)

         residual = None
         block: GPTBigCodeBlock
-        for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
+        for i, (block, layer_past) in enumerate(
+            zip(self.transformer.h, past_key_values)
+        ):
             hidden_states, residual, past_key_values[i] = block.decode(
                 hidden_states,
                 residual=residual,
@@ -400,7 +490,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
                 key_length=key_length,
             )

-        hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
+        hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
         hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)

         return hidden_states, past_key_values
@@ -26,6 +26,7 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
     GPTBigCodeConfig,
 )

+
 class InferenceRunnerType(IntEnum):
     NO_RUNNER = 0
     # Use the inference runner without cuda graphs.
@@ -38,6 +39,7 @@ class InferenceRunnerType(IntEnum):
     # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
     FULL_GRAPH = 3

+
 try:
     from flash_attn.bert_padding import pad_input, unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
@@ -52,7 +54,11 @@ logger = logging.get_logger(__name__)

 @torch.jit.script
 def upcast_masked_softmax(
-    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+    x: torch.Tensor,
+    mask: torch.Tensor,
+    mask_value: torch.Tensor,
+    scale: float,
+    softmax_dtype: torch.dtype,
 ):
     input_dtype = x.dtype
     x = x.to(softmax_dtype) * scale
@@ -105,7 +111,13 @@ def softmax_function(


 class GPTBigCodeAttention(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
         self.mask_value = None
         self.embed_dim = config.hidden_size
@@ -115,14 +127,27 @@ class GPTBigCodeAttention(nn.Module):

         # KV caching and padding

-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
-        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
+        self.c_attn = nn.Linear(
+            self.embed_dim,
+            self.embed_dim + 2 * self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
+        self.c_proj = nn.Linear(
+            self.embed_dim, self.embed_dim, dtype=dtype, device=device
+        )

     @torch.profiler.record_function("GPTBigCodeAttention._get_mask_value")
     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
-        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
-            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
+        if (
+            self.mask_value is None
+            or self.mask_value.dtype != dtype
+            or self.mask_value.device != device
+        ):
+            self.mask_value = torch.full(
+                [], torch.finfo(dtype).min, dtype=dtype, device=device
+            )
         return self.mask_value

     @torch.profiler.record_function("GPTBigCodeAttention._attn")
@@ -156,12 +181,16 @@ class GPTBigCodeAttention(nn.Module):
             beta = 1
         else:
             beta = 0
-        attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
+        attn_weights = torch.baddbmm(
+            attn_weights, query, key, beta=beta, alpha=scale_factor
+        ).view(attn_shape)

         attn_weights = softmax_function(
             attn_weights,
             attention_mask,
-            None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype),
+            None
+            if attention_mask is None
+            else self._get_mask_value(attn_weights.device, softmax_dtype),
             unscale,
             softmax_dtype,
             upcast,
@@ -172,7 +201,6 @@ class GPTBigCodeAttention(nn.Module):

     @torch.profiler.record_function("GPTBigCodeAttention._attn_flash")
     def _attn_flash(self, query, key, value, flash_params):
-
         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
@@ -199,11 +227,12 @@ class GPTBigCodeAttention(nn.Module):

     @torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches")
     def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
-
         # Convert to standard KV cache format.
         if flash_params is not None:
             _, padding_index, batch_size, max_sequence_length = flash_params
-            current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
+            current_kv_cache = pad_input(
+                key_value, padding_index, batch_size, max_sequence_length
+            )
             return key_value, (current_kv_cache, max_sequence_length)

         current_kv_cache = key_value
@@ -257,12 +286,16 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
-        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )

         # present = (allocated_kv_cache, key_length)
-        key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
+        key_value, present = self._merge_kv_caches(
+            key_value, layer_past, attention_mask, flash_params
+        )

         key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

@@ -287,15 +320,35 @@ class GPTBigCodeMLP(nn.Module):
     @torch.profiler.record_function("GPTBigCodeMLP.forward")
     # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
     def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
-        return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
+        return self.c_proj(
+            nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")
+        )


 class GPTBigCodeBlock(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
-        self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
-        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
+        self.ln_1 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
+        self.attn = GPTBigCodeAttention(
+            config, layer_idx=layer_idx, dtype=dtype, device=device
+        )
+        self.ln_2 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
         self.mlp = GPTBigCodeMLP(config)

     @torch.profiler.record_function("GPTBigCodeBlock.forward")
@@ -304,7 +357,7 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
         with torch.profiler.record_function("GPTBigCodeAttention.ln"):
             ai = self.ln_1(hidden_states)
@@ -327,7 +380,6 @@ class GPTBigCodeBlock(nn.Module):
         return hidden_states, present

-

 class GPTBigCodePreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -354,7 +406,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
             #
             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
             module.c_proj.weight.data.normal_(
-                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+                mean=0.0,
+                std=(
+                    self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
+                ),
             )
             module.c_proj._is_hf_initialized = True
         elif isinstance(module, nn.Linear):
@@ -373,16 +428,40 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):


 class GPTBigCodeModel(GPTBigCodePreTrainedModel):
-    def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__(config)

-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
-        self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
-        self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)
+        self.wte = nn.Embedding(
+            config.vocab_size, config.hidden_size, dtype=dtype, device=device
+        )
+        self.wpe = nn.Embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            dtype=dtype,
+            device=device,
+        )
+
+        self.h = nn.ModuleList(
+            [
+                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = nn.LayerNorm(
+            config.hidden_size,
+            dtype=dtype,
+            device=device,
+            eps=config.layer_norm_epsilon,
+        )

-        self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner)
+        self.inference_runner_type = (
+            InferenceRunnerType.NO_RUNNER
+        )  # InferenceRunnerType(config.inference_runner)

         self.flash_attention = True  # config.flash_attention

@@ -402,13 +481,20 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):

         # Causal mask
         self.register_buffer(
-            "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
+            "bias",
+            torch.empty(
+                (config.max_position_embeddings, config.max_position_embeddings),
+                dtype=torch.bool,
+                device=device,
+            ),
         )

     # @torch.profiler.record_function("GPTBigCodeModel._get_causal_mask")
     def _get_causal_mask(self, padding_mask, query_length, key_length):
         # Self-attention mask.
-        attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
+        attention_mask = self.bias[
+            None, key_length - query_length : key_length, :key_length
+        ]

         if padding_mask is not None:
             attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
@@ -416,7 +502,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             )
         pad = -key_length % 8
         if pad > 0:
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
+            attention_mask = torch.nn.functional.pad(
+                attention_mask, (0, pad), mode="constant", value=False
+            )

         # (batch_size, query_length, n_heads, key_length)
         return attention_mask.unsqueeze(2)
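_get_causal_mask slices the precomputed lower-triangular "bias" buffer, combines it with the padding mask, and right-pads the key dimension to a multiple of 8 with False so downstream kernels see an aligned key length. A small standalone sketch with example sizes (shapes and values assumed for illustration):

import torch

max_positions = 16
bias = torch.ones(max_positions, max_positions, dtype=torch.bool).tril_()

query_length, past_length = 3, 4
key_length = past_length + query_length
padding_mask = torch.tensor([[1, 1, 1, 1, 0, 1, 1]], dtype=torch.bool)  # (batch, key_length)

attention_mask = bias[None, key_length - query_length : key_length, :key_length]
attention_mask = attention_mask * padding_mask.unsqueeze(1)

pad = -key_length % 8  # one extra column here, so the padded key length is 8
if pad > 0:
    attention_mask = torch.nn.functional.pad(
        attention_mask, (0, pad), mode="constant", value=False
    )

# (batch, query_length, 1, padded_key_length), matching the unsqueeze(2) in the model
attention_mask = attention_mask.unsqueeze(2)
print(attention_mask.shape)  # torch.Size([1, 3, 1, 8])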
@@ -433,7 +521,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if self.inference_runner is not None and past_key_values is not None:
             if self.config.validate_runner_input:
                 assert past_key_values is not None
-            return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)
+            return self.inference_runner.forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )

         batch_size, query_length = input_ids.shape

@@ -444,30 +534,38 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         else:
             past_length = past_key_values[0][1]

         hidden_states = self.wte(input_ids) + self.wpe(position_ids)

         # TODO: Unpad earlier (input ids), support unpadded input?
         if flash_attention:
-            hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
-                hidden_states, attention_mask
+            (
+                hidden_states,
+                padding_index,
+                sequence_lengths,
+                max_sequence_length,
+            ) = unpad_input(hidden_states, attention_mask)
+            flash_params = (
+                sequence_lengths,
+                padding_index,
+                batch_size,
+                max_sequence_length,
             )
-            flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
             attention_mask = None
         else:
             key_length = past_length + query_length
             # Self-attention mask (padding + causal).
-            attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
+            attention_mask = self._get_causal_mask(
+                attention_mask, query_length, key_length
+            )
             flash_params = None

         presents = []
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             outputs = block(
                 hidden_states,
                 layer_past=layer_past,
                 attention_mask=attention_mask,
-                flash_params=flash_params
+                flash_params=flash_params,
             )

             hidden_states = outputs[0]
@@ -476,17 +574,26 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         hidden_states = self.ln_f(hidden_states)

         if flash_attention:
-            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
+            hidden_states = pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )

         return hidden_states, presents


 class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
-    def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__(config)
         meta = torch.device("meta")
         self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
+        self.lm_head = nn.Linear(
+            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta
+        )

         self.to_empty(device=device)

@@ -503,12 +610,11 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         position_ids: torch.Tensor,
         predict_all_tokens: bool = True,
     ) -> Tuple:
-
         hidden_states, presents = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids
+            position_ids=position_ids,
         )
         # with torch.profiler.record_function("GPTBigCodeForCausalLM.head"):
         if not predict_all_tokens:
@@ -20,14 +20,13 @@ import torch
 import torch.utils.checkpoint
 from torch import nn

-import dropout_layer_norm
-
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
     GPTBigCodeConfig,
 )


 class InferenceRunnerType(IntEnum):
     NO_RUNNER = 0
     # Use the inference runner without cuda graphs.
@@ -41,7 +40,6 @@ class InferenceRunnerType(IntEnum):
     FULL_GRAPH = 3

-

 try:
     from flash_attn.bert_padding import pad_input, unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
@@ -56,7 +54,11 @@ logger = logging.get_logger(__name__)

 @torch.jit.script
 def upcast_masked_softmax(
-    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+    x: torch.Tensor,
+    mask: torch.Tensor,
+    mask_value: torch.Tensor,
+    scale: float,
+    softmax_dtype: torch.dtype,
 ):
     input_dtype = x.dtype
     x = x.to(softmax_dtype) * scale
@@ -108,7 +110,13 @@ def softmax_function(


 class GPTBigCodeAttention(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
         self.mask_value = None
         self.embed_dim = config.hidden_size
@@ -118,13 +126,26 @@ class GPTBigCodeAttention(nn.Module):

         # KV caching and padding

-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
-        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
+        self.c_attn = nn.Linear(
+            self.embed_dim,
+            self.embed_dim + 2 * self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
+        self.c_proj = nn.Linear(
+            self.embed_dim, self.embed_dim, dtype=dtype, device=device
+        )

     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
-        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
-            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
+        if (
+            self.mask_value is None
+            or self.mask_value.dtype != dtype
+            or self.mask_value.device != device
+        ):
+            self.mask_value = torch.full(
+                [], torch.finfo(dtype).min, dtype=dtype, device=device
+            )
         return self.mask_value

     def _attn(self, query, key, value, attention_mask):
@@ -157,12 +178,16 @@ class GPTBigCodeAttention(nn.Module):
             beta = 1
         else:
             beta = 0
-        attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
+        attn_weights = torch.baddbmm(
+            attn_weights, query, key, beta=beta, alpha=scale_factor
+        ).view(attn_shape)

         attn_weights = softmax_function(
             attn_weights,
             attention_mask,
-            None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype),
+            None
+            if attention_mask is None
+            else self._get_mask_value(attn_weights.device, softmax_dtype),
             unscale,
             softmax_dtype,
             upcast,
@@ -172,7 +197,6 @@ class GPTBigCodeAttention(nn.Module):
         return attn_output

     def _attn_flash(self, query, key, value, flash_params):
-
         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
@@ -198,11 +222,12 @@ class GPTBigCodeAttention(nn.Module):
         return attn_output

     def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
-
         # Convert to standard KV cache format.
         if flash_params is not None:
             _, padding_index, batch_size, max_sequence_length = flash_params
-            current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
+            current_kv_cache = pad_input(
+                key_value, padding_index, batch_size, max_sequence_length
+            )
             return key_value, (current_kv_cache, max_sequence_length)

         current_kv_cache = key_value
@@ -255,12 +280,16 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
-        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )

         # present = (allocated_kv_cache, key_length)
-        key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
+        key_value, present = self._merge_kv_caches(
+            key_value, layer_past, attention_mask, flash_params
+        )

         key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

@@ -284,15 +313,35 @@ class GPTBigCodeMLP(nn.Module):

     # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
     def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
-        return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
+        return self.c_proj(
+            nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")
+        )


 class GPTBigCodeBlock(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
-        self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
-        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
+        self.ln_1 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
+        self.attn = GPTBigCodeAttention(
+            config, layer_idx=layer_idx, dtype=dtype, device=device
+        )
+        self.ln_2 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
         self.mlp = GPTBigCodeMLP(config)

     def forward(
@@ -300,7 +349,7 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
         attn_output, present = self.attn(
             self.ln_1(hidden_states),
@ -313,7 +362,6 @@ class GPTBigCodeBlock(nn.Module):
|
|||||||
return hidden_states, present
|
return hidden_states, present
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GPTBigCodePreTrainedModel(PreTrainedModel):
|
class GPTBigCodePreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
@ -340,7 +388,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
|
|||||||
#
|
#
|
||||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||||
module.c_proj.weight.data.normal_(
|
module.c_proj.weight.data.normal_(
|
||||||
mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
|
mean=0.0,
|
||||||
|
std=(
|
||||||
|
self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
module.c_proj._is_hf_initialized = True
|
module.c_proj._is_hf_initialized = True
|
||||||
elif isinstance(module, nn.Linear):
|
elif isinstance(module, nn.Linear):
|
||||||
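The init hunk above scales the output projection's standard deviation by 1 / sqrt(2 * n_layer), following the Megatron-LM reference. A minimal sketch of that scaling with hypothetical config values:

```python
import math
import torch
from torch import nn

# Hypothetical stand-ins for config.initializer_range and config.n_layer.
initializer_range, n_layer = 0.02, 24

c_proj = nn.Linear(64, 64)
# Residual-branch projections get a smaller std so that summing 2 * n_layer
# residual contributions keeps activation variance roughly constant with depth.
c_proj.weight.data.normal_(mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
```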
@ -359,16 +410,40 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
||||||
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
|
self.wte = nn.Embedding(
|
||||||
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
|
config.vocab_size, config.hidden_size, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.wpe = nn.Embedding(
|
||||||
|
config.max_position_embeddings,
|
||||||
|
config.hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
|
self.h = nn.ModuleList(
|
||||||
self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)
|
[
|
||||||
|
GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.ln_f = nn.LayerNorm(
|
||||||
|
config.hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
eps=config.layer_norm_epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner)
|
self.inference_runner_type = (
|
||||||
|
InferenceRunnerType.NO_RUNNER
|
||||||
|
) # InferenceRunnerType(config.inference_runner)
|
||||||
|
|
||||||
self.flash_attention = True # config.flash_attention
|
self.flash_attention = True # config.flash_attention
|
||||||
|
|
||||||
@ -388,12 +463,19 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
|||||||
|
|
||||||
# Causal mask
|
# Causal mask
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
|
"bias",
|
||||||
|
torch.empty(
|
||||||
|
(config.max_position_embeddings, config.max_position_embeddings),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_causal_mask(self, padding_mask, query_length, key_length):
|
def _get_causal_mask(self, padding_mask, query_length, key_length):
|
||||||
# Self-attention mask.
|
# Self-attention mask.
|
||||||
attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
|
attention_mask = self.bias[
|
||||||
|
None, key_length - query_length : key_length, :key_length
|
||||||
|
]
|
||||||
|
|
||||||
if padding_mask is not None:
|
if padding_mask is not None:
|
||||||
attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
|
attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
|
||||||
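`_get_causal_mask` above slices a persistent boolean `bias` buffer by query and key length. A sketch of that slicing, assuming the buffer holds the usual lower-triangular causal pattern (the diff registers it with `torch.empty`, so the actual fill happens elsewhere):

```python
import torch

max_positions = 16
# Assumed contents of the "bias" buffer: True where a query may attend to a key.
bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool))

query_length, key_length = 3, 10  # e.g. 7 cached tokens plus 3 new ones
# Rows for the new queries, columns for every key seen so far.
causal = bias[None, key_length - query_length : key_length, :key_length]

assert causal.shape == (1, query_length, key_length)
assert causal[0, -1].all()         # the last query sees all 10 keys
assert not causal[0, 0, 8:].any()  # the first query cannot see the 2 later keys
```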
@ -401,7 +483,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
|||||||
)
|
)
|
||||||
pad = -key_length % 8
|
pad = -key_length % 8
|
||||||
if pad > 0:
|
if pad > 0:
|
||||||
attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
|
attention_mask = torch.nn.functional.pad(
|
||||||
|
attention_mask, (0, pad), mode="constant", value=False
|
||||||
|
)
|
||||||
|
|
||||||
# (batch_size, query_length, n_heads, key_length)
|
# (batch_size, query_length, n_heads, key_length)
|
||||||
return attention_mask.unsqueeze(2)
|
return attention_mask.unsqueeze(2)
|
||||||
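The hunk above pads the key dimension of the mask so the key length becomes a multiple of 8; `-key_length % 8` is exactly the amount needed to round up. A small illustration:

```python
import torch

key_length = 21
pad = -key_length % 8              # 3: rounds 21 up to 24
assert (key_length + pad) % 8 == 0

mask = torch.ones(1, 4, key_length, dtype=torch.bool)
if pad > 0:
    # Extend the key dimension with False so padded keys are never attended to.
    mask = torch.nn.functional.pad(mask, (0, pad), mode="constant", value=False)

assert mask.shape[-1] == 24 and not mask[..., key_length:].any()
```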
@ -417,7 +501,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
|||||||
if self.inference_runner is not None and past_key_values is not None:
|
if self.inference_runner is not None and past_key_values is not None:
|
||||||
if self.config.validate_runner_input:
|
if self.config.validate_runner_input:
|
||||||
assert past_key_values is not None
|
assert past_key_values is not None
|
||||||
return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)
|
return self.inference_runner.forward(
|
||||||
|
input_ids, attention_mask, position_ids, past_key_values
|
||||||
|
)
|
||||||
|
|
||||||
batch_size, query_length = input_ids.shape
|
batch_size, query_length = input_ids.shape
|
||||||
|
|
||||||
@ -428,30 +514,38 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
past_length = past_key_values[0][1]
|
past_length = past_key_values[0][1]
|
||||||
|
|
||||||
|
|
||||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||||
|
|
||||||
# TODO: Unpad earlier (input ids), support unpadded input?
|
# TODO: Unpad earlier (input ids), support unpadded input?
|
||||||
if flash_attention:
|
if flash_attention:
|
||||||
hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
|
(
|
||||||
hidden_states, attention_mask
|
hidden_states,
|
||||||
|
padding_index,
|
||||||
|
sequence_lengths,
|
||||||
|
max_sequence_length,
|
||||||
|
) = unpad_input(hidden_states, attention_mask)
|
||||||
|
flash_params = (
|
||||||
|
sequence_lengths,
|
||||||
|
padding_index,
|
||||||
|
batch_size,
|
||||||
|
max_sequence_length,
|
||||||
)
|
)
|
||||||
flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
|
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
else:
|
else:
|
||||||
key_length = past_length + query_length
|
key_length = past_length + query_length
|
||||||
# Self-attention mask (padding + causal).
|
# Self-attention mask (padding + causal).
|
||||||
attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
|
attention_mask = self._get_causal_mask(
|
||||||
|
attention_mask, query_length, key_length
|
||||||
|
)
|
||||||
flash_params = None
|
flash_params = None
|
||||||
|
|
||||||
presents = []
|
presents = []
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
|
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
flash_params=flash_params
|
flash_params=flash_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
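The forward above strips padding before running the blocks; the matching `pad_input` call appears in the next hunk. A rough sketch of that round trip with `flash_attn.bert_padding`, assuming the four-value `unpad_input` return shown in the diff (newer flash-attn releases return an extra value):

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seq_len, hidden = 2, 6, 8
hidden_states = torch.randn(batch_size, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]], dtype=torch.bool)

# Drop padded positions: (batch, seq, hidden) -> (total_tokens, hidden).
hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
    hidden_states, attention_mask
)
# ... the transformer blocks would run on the unpadded tokens here ...

# Restore the original (batch, seq, hidden) layout for the caller.
hidden_states = pad_input(hidden_states, padding_index, batch_size, seq_len)
assert hidden_states.shape == (batch_size, seq_len, hidden)
```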
@ -460,17 +554,26 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
|||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
if flash_attention:
|
if flash_attention:
|
||||||
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
|
hidden_states = pad_input(
|
||||||
|
hidden_states, padding_index, batch_size, query_length
|
||||||
|
)
|
||||||
|
|
||||||
return hidden_states, presents
|
return hidden_states, presents
|
||||||
|
|
||||||
|
|
||||||
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
||||||
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
meta = torch.device("meta")
|
meta = torch.device("meta")
|
||||||
self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
|
self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
|
||||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
|
self.lm_head = nn.Linear(
|
||||||
|
config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta
|
||||||
|
)
|
||||||
|
|
||||||
self.to_empty(device=device)
|
self.to_empty(device=device)
|
||||||
|
|
||||||
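The causal-LM constructor above builds the transformer on the `meta` device and then calls `to_empty`, so no weight memory is allocated until the checkpoint is actually loaded. A minimal sketch of that pattern with a stand-in layer:

```python
import torch
from torch import nn

meta = torch.device("meta")
# Parameters carry only shape and dtype, no storage.
lm_head = nn.Linear(256, 1024, bias=False, dtype=torch.float16, device=meta)
assert lm_head.weight.is_meta

# Allocate uninitialised storage on the target device; real values are
# expected to come from a checkpoint load afterwards.
lm_head = lm_head.to_empty(device="cpu")
assert not lm_head.weight.is_meta
```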
@ -486,12 +589,11 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
predict_all_tokens: bool = True,
|
predict_all_tokens: bool = True,
|
||||||
) -> Tuple:
|
) -> Tuple:
|
||||||
|
|
||||||
hidden_states, presents = self.transformer(
|
hidden_states, presents = self.transformer(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids
|
position_ids=position_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not predict_all_tokens:
|
if not predict_all_tokens:
|
||||||
|
@ -6,8 +6,13 @@ from opentelemetry import trace
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch
|
from text_generation_server.models.vectorized_causal_lm import (
|
||||||
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeForCausalLM
|
VectorizedCausalLM,
|
||||||
|
VectorizedCausalLMBatch,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import (
|
||||||
|
GPTBigCodeForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
@ -24,7 +29,9 @@ class BigcodeBatch(VectorizedCausalLMBatch):
|
|||||||
layer_kv.data = layer_kv[keep_indices, sequence_slice]
|
layer_kv.data = layer_kv[keep_indices, sequence_slice]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
|
def _concatenate_key_values(
|
||||||
|
cls, batches, start_indices, end_indices, left_indices, max_input_length
|
||||||
|
):
|
||||||
device = batches[0].input_ids.device
|
device = batches[0].input_ids.device
|
||||||
batch_size = sum([len(batch.requests) for batch in batches])
|
batch_size = sum([len(batch.requests) for batch in batches])
|
||||||
|
|
||||||
@ -35,13 +42,21 @@ class BigcodeBatch(VectorizedCausalLMBatch):
|
|||||||
past_key_values = []
|
past_key_values = []
|
||||||
for kv_caches in zip(*(batch.past_key_values for batch in batches)):
|
for kv_caches in zip(*(batch.past_key_values for batch in batches)):
|
||||||
key_values, seq_lengths = zip(*kv_caches)
|
key_values, seq_lengths = zip(*kv_caches)
|
||||||
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths))
|
assert all(
|
||||||
|
left_index + seq_length == max_input_length
|
||||||
|
for left_index, seq_length in zip(left_indices, seq_lengths)
|
||||||
|
)
|
||||||
|
|
||||||
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values))
|
allocate_seq_len = max(
|
||||||
|
left_index + key_value.size(1)
|
||||||
|
for left_index, key_value in zip(left_indices, key_values)
|
||||||
|
)
|
||||||
allocate_seq_len += -allocate_seq_len % 8
|
allocate_seq_len += -allocate_seq_len % 8
|
||||||
|
|
||||||
kv_cache = torch.empty(
|
kv_cache = torch.empty(
|
||||||
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device
|
(batch_size, allocate_seq_len, *key_values[0].shape[2:]),
|
||||||
|
dtype=key_values[0].dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
for key_value, start_index, end_index, left_index in zip(
|
for key_value, start_index, end_index, left_index in zip(
|
||||||
key_values,
|
key_values,
|
||||||
@ -49,7 +64,9 @@ class BigcodeBatch(VectorizedCausalLMBatch):
|
|||||||
end_indices,
|
end_indices,
|
||||||
left_indices,
|
left_indices,
|
||||||
):
|
):
|
||||||
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value)
|
kv_cache[start_index:end_index, left_index:max_input_length].copy_(
|
||||||
|
key_value
|
||||||
|
)
|
||||||
# Set padding to zero to avoid propagating nans.
|
# Set padding to zero to avoid propagating nans.
|
||||||
kv_cache[start_index:end_index, :left_index].fill_(0)
|
kv_cache[start_index:end_index, :left_index].fill_(0)
|
||||||
kv_cache[start_index:end_index, max_input_length:].fill_(0)
|
kv_cache[start_index:end_index, max_input_length:].fill_(0)
|
||||||
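`_concatenate_key_values` above copies each batch's cache into a freshly allocated, zero-padded tensor using the left-padding convention. A simplified sketch, assuming a per-layer cache of shape `(batch, seq, kv_features)`:

```python
import torch

max_input_length, kv_features = 6, 4
# Two hypothetical per-layer caches with different sequence lengths.
caches = [torch.randn(2, 4, kv_features), torch.randn(1, 6, kv_features)]
left_indices = [max_input_length - c.size(1) for c in caches]   # [2, 0]
start_indices, end_indices = [0, 2], [2, 3]

allocate_seq_len = max(l + c.size(1) for l, c in zip(left_indices, caches))
allocate_seq_len += -allocate_seq_len % 8        # round up to a multiple of 8

merged = torch.empty(3, allocate_seq_len, kv_features)
for cache, start, end, left in zip(caches, start_indices, end_indices, left_indices):
    merged[start:end, left:max_input_length].copy_(cache)
    # Zero the padding so NaNs from uninitialised memory cannot propagate.
    merged[start:end, :left].fill_(0)
    merged[start:end, max_input_length:].fill_(0)
```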
@ -58,6 +75,7 @@ class BigcodeBatch(VectorizedCausalLMBatch):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
|
||||||
|
|
||||||
class BigcodeCausalLM(VectorizedCausalLM):
|
class BigcodeCausalLM(VectorizedCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -113,7 +131,7 @@ class BigcodeCausalLM(VectorizedCausalLM):
|
|||||||
attention_mask=batch.attention_mask[:, :key_length],
|
attention_mask=batch.attention_mask[:, :key_length],
|
||||||
position_ids=batch.position_ids[:, key_length - query_length : key_length],
|
position_ids=batch.position_ids[:, key_length - query_length : key_length],
|
||||||
past_key_values=batch.past_key_values,
|
past_key_values=batch.past_key_values,
|
||||||
use_cache=True,
|
predict_all_tokens=batch.details,
|
||||||
)
|
)
|
||||||
next_token_ids, logprobs = batch.next_token_chooser(
|
next_token_ids, logprobs = batch.next_token_chooser(
|
||||||
input_ids, logits, batch.details
|
input_ids, logits, batch.details
|
||||||
@ -128,8 +146,18 @@ class BigcodeCausalLM(VectorizedCausalLM):
|
|||||||
|
|
||||||
def mock_kv_cache(self, batch: BigcodeBatch, dtype: Optional[torch.dtype]):
|
def mock_kv_cache(self, batch: BigcodeBatch, dtype: Optional[torch.dtype]):
|
||||||
allocate_length = batch.max_input_length + -batch.max_input_length % 8
|
allocate_length = batch.max_input_length + -batch.max_input_length % 8
|
||||||
return [(torch.empty(
|
return [
|
||||||
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
|
(
|
||||||
|
torch.empty(
|
||||||
|
[
|
||||||
|
len(batch),
|
||||||
|
allocate_length - 1,
|
||||||
|
2 * self.model.config.n_embd // self.model.config.n_head,
|
||||||
|
],
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=batch.input_ids.device,
|
device=batch.input_ids.device,
|
||||||
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)]
|
),
|
||||||
|
batch.max_input_length - 1,
|
||||||
|
)
|
||||||
|
for _ in range(self.model.config.n_layer)
|
||||||
|
]
|
||||||
|
@ -4,10 +4,17 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type, List
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||||
|
|
||||||
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import GPTBigCodeForCausalLM
|
from text_generation_server.models.vectorized_causal_lm import (
|
||||||
|
VectorizedCausalLM,
|
||||||
|
VectorizedCausalLMBatch,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import (
|
||||||
|
GPTBigCodeForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
@ -20,14 +27,35 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
|
|||||||
# Prefill the attention mask for padded key length.
|
# Prefill the attention mask for padded key length.
|
||||||
attention_mask_fill_value = False
|
attention_mask_fill_value = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pb(
|
||||||
|
cls,
|
||||||
|
pb: generate_pb2.Batch,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
device: torch.device,
|
||||||
|
) -> "Bigcode2Batch":
|
||||||
|
batch = super().from_pb(pb, tokenizer, device)
|
||||||
|
batch.attention_mask[:, batch.max_input_length :].fill_(False)
|
||||||
|
return batch
|
||||||
|
|
||||||
def _filter_kv_caches(self, keep_indices, sequence_slice):
|
def _filter_kv_caches(self, keep_indices, sequence_slice):
|
||||||
if self.past_key_values is not None:
|
if self.past_key_values is not None:
|
||||||
for layer_kv, _ in self.past_key_values:
|
for layer_kv in self.past_key_values:
|
||||||
# Update tensors in-place to allow incremental garbage collection
|
# Update tensors in-place to allow incremental garbage collection
|
||||||
layer_kv.data = layer_kv[keep_indices, sequence_slice]
|
layer_kv.data = layer_kv[keep_indices, sequence_slice]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
|
def concatenate(cls, batches: List["Bigcode2Batch"]) -> "Bigcode2Batch":
|
||||||
|
batch = super().concatenate(batches)
|
||||||
|
# Replace the attention mask with zeros to support padded key length.
|
||||||
|
# They are already filled with ones in super, but duplication is needed to generate the position ids.
|
||||||
|
batch.attention_mask[:, batch.max_input_length :].fill_(False)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _concatenate_key_values(
|
||||||
|
cls, batches, start_indices, end_indices, left_indices, max_input_length
|
||||||
|
):
|
||||||
device = batches[0].input_ids.device
|
device = batches[0].input_ids.device
|
||||||
batch_size = sum([len(batch.requests) for batch in batches])
|
batch_size = sum([len(batch.requests) for batch in batches])
|
||||||
|
|
||||||
@ -36,15 +64,19 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
|
|||||||
raise ValueError("Only concatenate prefilled batches")
|
raise ValueError("Only concatenate prefilled batches")
|
||||||
|
|
||||||
past_key_values = []
|
past_key_values = []
|
||||||
for kv_caches in zip(*(batch.past_key_values for batch in batches)):
|
for key_values in zip(*(batch.past_key_values for batch in batches)):
|
||||||
key_values, seq_lengths = zip(*kv_caches)
|
allocate_seq_len = max(
|
||||||
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths))
|
left_index + key_value.size(1)
|
||||||
|
for left_index, key_value in zip(left_indices, key_values)
|
||||||
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values))
|
)
|
||||||
allocate_seq_len += - allocate_seq_len % batches[0].pad_key_length_to_multiple
|
allocate_seq_len += (
|
||||||
|
-allocate_seq_len % batches[0].pad_key_length_to_multiple
|
||||||
|
)
|
||||||
|
|
||||||
kv_cache = torch.empty(
|
kv_cache = torch.empty(
|
||||||
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device
|
(batch_size, allocate_seq_len, *key_values[0].shape[2:]),
|
||||||
|
dtype=key_values[0].dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
for key_value, start_index, end_index, left_index in zip(
|
for key_value, start_index, end_index, left_index in zip(
|
||||||
key_values,
|
key_values,
|
||||||
@ -52,16 +84,21 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
|
|||||||
end_indices,
|
end_indices,
|
||||||
left_indices,
|
left_indices,
|
||||||
):
|
):
|
||||||
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value)
|
kv_cache[start_index:end_index, left_index:max_input_length].copy_(
|
||||||
|
key_value
|
||||||
|
)
|
||||||
# Set padding to zero to avoid propagating nans.
|
# Set padding to zero to avoid propagating nans.
|
||||||
kv_cache[start_index:end_index, :left_index].fill_(0)
|
kv_cache[start_index:end_index, :left_index].fill_(0)
|
||||||
kv_cache[start_index:end_index, max_input_length:].fill_(0)
|
kv_cache[start_index:end_index, max_input_length:].fill_(0)
|
||||||
past_key_values.append((kv_cache, max_input_length))
|
past_key_values.append(kv_cache)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
|
||||||
|
|
||||||
class Bigcode2CausalLM(VectorizedCausalLM):
|
class Bigcode2CausalLM(VectorizedCausalLM):
|
||||||
|
model: GPTBigCodeForCausalLM
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@ -85,9 +122,12 @@ class Bigcode2CausalLM(VectorizedCausalLM):
|
|||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
|
dtype=dtype,
|
||||||
device_map="auto" if torch.cuda.is_available() else None,
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
)
|
)
|
||||||
|
model.post_load_weights()
|
||||||
|
|
||||||
tokenizer.pad_token_id = (
|
tokenizer.pad_token_id = (
|
||||||
model.config.pad_token_id
|
model.config.pad_token_id
|
||||||
if model.config.pad_token_id is not None
|
if model.config.pad_token_id is not None
|
||||||
@ -110,25 +150,29 @@ class Bigcode2CausalLM(VectorizedCausalLM):
|
|||||||
key_length = batch.max_input_length
|
key_length = batch.max_input_length
|
||||||
if batch.past_key_values is None:
|
if batch.past_key_values is None:
|
||||||
# Prefill (flash attn, unpadded key length)
|
# Prefill (flash attn, unpadded key length)
|
||||||
batch.pad_key_length_to_multiple=self.model.pad_key_length_to_multiple
|
input_ids = batch.input_ids[:, :key_length]
|
||||||
padded_key_length=key_length
|
logits, batch.past_key_values = self.model.prefill(
|
||||||
query_length=key_length
|
input_ids=input_ids,
|
||||||
|
attention_mask=batch.attention_mask[:, :key_length],
|
||||||
|
position_ids=batch.position_ids[:, :key_length],
|
||||||
|
predict_all_tokens=batch.details,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Decode (fused attn, padded key length)
|
# Decode (fused attn, padded key length)
|
||||||
batch.attention_mask[:, key_length - 1].fill_(True)
|
batch.attention_mask[:, key_length - 1].fill_(True)
|
||||||
padded_key_length=key_length+-key_length%batch.pad_key_length_to_multiple
|
padded_key_length = (
|
||||||
query_length=1
|
key_length + -key_length % batch.pad_key_length_to_multiple
|
||||||
|
)
|
||||||
input_ids = batch.input_ids[:, key_length - query_length : key_length]
|
input_ids = batch.input_ids[:, key_length - 1 : key_length]
|
||||||
# Model Forward
|
# Model Forward
|
||||||
logits, batch.past_key_values = self.model.forward(
|
logits, batch.past_key_values = self.model.decode(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=batch.attention_mask[:, :padded_key_length],
|
attention_mask=batch.attention_mask[:, :padded_key_length],
|
||||||
position_ids=batch.position_ids[:, key_length - query_length : key_length],
|
position_ids=batch.position_ids[:, key_length - 1 : key_length],
|
||||||
past_key_values=batch.past_key_values,
|
past_key_values=batch.past_key_values,
|
||||||
key_length=key_length,
|
key_length=key_length,
|
||||||
predict_all_tokens=batch.details
|
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_ids, logprobs = batch.next_token_chooser(
|
next_token_ids, logprobs = batch.next_token_chooser(
|
||||||
input_ids, logits, batch.details
|
input_ids, logits, batch.details
|
||||||
)
|
)
|
||||||
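The hunk above splits generation into a prefill call (whole prompt, unpadded flash attention) and a decode call (one new token, key length rounded up for the fused kernel). A schematic sketch of that dispatch; `model.prefill` and `model.decode` mirror the methods used in the diff and are not part of the upstream transformers API:

```python
def forward_step(model, batch, pad_to: int = 8):
    key_length = batch.max_input_length
    if batch.past_key_values is None:
        # Prefill: process the full prompt with unpadded (flash) attention.
        logits, batch.past_key_values = model.prefill(
            input_ids=batch.input_ids[:, :key_length],
            attention_mask=batch.attention_mask[:, :key_length],
            position_ids=batch.position_ids[:, :key_length],
        )
    else:
        # Decode: a single new token, key length padded for the fused kernel.
        batch.attention_mask[:, key_length - 1].fill_(True)
        padded_key_length = key_length + -key_length % pad_to
        logits, batch.past_key_values = model.decode(
            input_ids=batch.input_ids[:, key_length - 1 : key_length],
            attention_mask=batch.attention_mask[:, :padded_key_length],
            position_ids=batch.position_ids[:, key_length - 1 : key_length],
            past_key_values=batch.past_key_values,
            key_length=key_length,
        )
    return logits
```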
@ -141,9 +185,19 @@ class Bigcode2CausalLM(VectorizedCausalLM):
|
|||||||
return next_token_ids, logprobs
|
return next_token_ids, logprobs
|
||||||
|
|
||||||
def mock_kv_cache(self, batch: Bigcode2Batch, dtype: Optional[torch.dtype]):
|
def mock_kv_cache(self, batch: Bigcode2Batch, dtype: Optional[torch.dtype]):
|
||||||
allocate_length=batch.max_input_length+-batch.max_input_length%batch.pad_key_length_to_multiple
|
allocate_length = (
|
||||||
return [(torch.empty(
|
batch.max_input_length
|
||||||
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
|
+ -batch.max_input_length % batch.pad_key_length_to_multiple
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
torch.randn(
|
||||||
|
[
|
||||||
|
len(batch),
|
||||||
|
allocate_length - 1,
|
||||||
|
2 * self.model.config.n_embd // self.model.config.n_head,
|
||||||
|
],
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=batch.input_ids.device,
|
device=batch.input_ids.device,
|
||||||
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)]
|
)
|
||||||
|
for _ in range(self.model.config.n_layer)
|
||||||
|
]
|
||||||
|
@ -58,9 +58,6 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
|
|
||||||
kv_cache_seq_dim: int = 2
|
kv_cache_seq_dim: int = 2
|
||||||
|
|
||||||
# Prefill the attention mask for the generated tokens
|
|
||||||
attention_mask_fill_value=True
|
|
||||||
|
|
||||||
# TODO: Get from requests (should these be lists?)
|
# TODO: Get from requests (should these be lists?)
|
||||||
details: bool = os.environ.get("RETURN_DETAILS") is not None
|
details: bool = os.environ.get("RETURN_DETAILS") is not None
|
||||||
generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None
|
generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None
|
||||||
@ -116,7 +113,7 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
|
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
|
||||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||||
attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"])
|
attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"])
|
||||||
attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value)
|
attention_mask[:, max_input_length:].fill_(True)
|
||||||
|
|
||||||
position_ids = attention_mask.cumsum(-1).sub_(1)
|
position_ids = attention_mask.cumsum(-1).sub_(1)
|
||||||
position_ids[:, :max_input_length].relu_()
|
position_ids[:, :max_input_length].relu_()
|
||||||
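The batch setup above derives position ids directly from the boolean attention mask. A small check of the `cumsum(-1).sub_(1)` plus `relu_()` trick for left-padded prompts:

```python
import torch

# Two prompts, left-padded to length 5 (0 = padding, 1 = real token).
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]], dtype=torch.bool)

position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids.relu_()   # clamp the padded prefix to position 0

assert position_ids.tolist() == [[0, 0, 0, 1, 2],
                                 [0, 1, 2, 3, 4]]
```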
@ -271,7 +268,10 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
# Allocate maximum attention_mask
|
# Allocate maximum attention_mask
|
||||||
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
|
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
|
||||||
attention_mask[:, :max_input_length].fill_(0)
|
attention_mask[:, :max_input_length].fill_(0)
|
||||||
attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value)
|
attention_mask[:, max_input_length:].fill_(True)
|
||||||
|
|
||||||
|
position_ids = attention_mask.cumsum(-1).sub_(1)
|
||||||
|
position_ids[:, :max_input_length].relu_()
|
||||||
|
|
||||||
input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
|
input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
|
||||||
# TODO : only needed for prefill
|
# TODO : only needed for prefill
|
||||||
@ -287,16 +287,15 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
batch.input_ids[:, : batch.max_input_length]
|
batch.input_ids[:, : batch.max_input_length]
|
||||||
)
|
)
|
||||||
|
|
||||||
position_ids = attention_mask.cumsum(-1).sub_(1)
|
|
||||||
position_ids[:, :max_input_length].relu_()
|
|
||||||
|
|
||||||
max_tokens = sum(
|
max_tokens = sum(
|
||||||
batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch)
|
batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch)
|
||||||
for batch in batches
|
for batch in batches
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_cache_seq_dim = batches[0].kv_cache_seq_dim
|
kv_cache_seq_dim = batches[0].kv_cache_seq_dim
|
||||||
past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices, max_input_length)
|
past_key_values = cls._concatenate_key_values(
|
||||||
|
batches, start_indices, end_indices, left_indices, max_input_length
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=batches[0].batch_id,
|
batch_id=batches[0].batch_id,
|
||||||
@ -317,7 +316,9 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
|
def _concatenate_key_values(
|
||||||
|
cls, batches, start_indices, end_indices, left_indices, max_input_length
|
||||||
|
):
|
||||||
device = batches[0].input_ids.device
|
device = batches[0].input_ids.device
|
||||||
batch_size = sum([len(batch.requests) for batch in batches])
|
batch_size = sum([len(batch.requests) for batch in batches])
|
||||||
|
|
||||||
@ -386,10 +387,10 @@ class VectorizedCausalLMBatch(Batch):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
|
||||||
|
|
||||||
class VectorizedCausalLM(Model):
|
class VectorizedCausalLM(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -558,23 +559,42 @@ class VectorizedCausalLM(Model):
|
|||||||
|
|
||||||
return generations, next_batch
|
return generations, next_batch
|
||||||
|
|
||||||
def mock_kv_cache(self, batch: VectorizedCausalLMBatch, dtype:Optional[torch.dtype]):
|
def mock_kv_cache(
|
||||||
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
|
self, batch: VectorizedCausalLMBatch, dtype: Optional[torch.dtype]
|
||||||
|
):
|
||||||
|
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
|
||||||
|
GPTBigCodeForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(self.model, GPTBigCodeForCausalLM):
|
if not isinstance(self.model, GPTBigCodeForCausalLM):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
return [torch.empty(
|
return [
|
||||||
[len(batch), batch.max_input_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
|
torch.empty(
|
||||||
|
[
|
||||||
|
len(batch),
|
||||||
|
batch.max_input_length - 1,
|
||||||
|
2 * self.model.config.n_embd // self.model.config.n_head,
|
||||||
|
],
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=batch.input_ids.device,
|
device=batch.input_ids.device,
|
||||||
) for _ in range(self.model.config.n_layer)]
|
)
|
||||||
|
for _ in range(self.model.config.n_layer)
|
||||||
|
]
|
||||||
|
|
||||||
def fast_forward(self, batch: VectorizedCausalLMBatch, max_input_length: int, cache_dtype:Optional[torch.dtype]):
|
def fast_forward(
|
||||||
|
self,
|
||||||
|
batch: VectorizedCausalLMBatch,
|
||||||
|
max_input_length: int,
|
||||||
|
cache_dtype: Optional[torch.dtype],
|
||||||
|
):
|
||||||
diff = max_input_length - batch.max_input_length
|
diff = max_input_length - batch.max_input_length
|
||||||
batch.input_ids[:, batch.max_input_length:max_input_length].fill_(self.tokenizer.pad_token_id)
|
batch.input_ids[:, batch.max_input_length : max_input_length].fill_(
|
||||||
|
self.tokenizer.pad_token_id
|
||||||
|
)
|
||||||
batch.input_lengths = [length + diff for length in batch.input_lengths]
|
batch.input_lengths = [length + diff for length in batch.input_lengths]
|
||||||
batch.max_input_length += diff
|
batch.max_input_length += diff
|
||||||
for stopping_criteria in batch.stopping_criterias:
|
for stopping_criteria in batch.stopping_criterias:
|
||||||
stopping_criteria.current_tokens += diff
|
stopping_criteria.current_tokens += diff
|
||||||
batch.past_key_values = None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype)
|
batch.past_key_values = (
|
||||||
|
None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype)
|
||||||
|
)
|
||||||
|
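The last hunk reformats the base `mock_kv_cache`, which fabricates one cache tensor per layer with shape `(batch, max_input_length - 1, 2 * n_embd // n_head)`, i.e. one key and one value head packed together (multi-query attention). A small shape check with made-up config values:

```python
import torch

# Hypothetical stand-ins for config.n_embd, config.n_head, config.n_layer.
n_embd, n_head, n_layer = 64, 8, 2
batch_size, max_input_length = 3, 9

mock_cache = [
    torch.empty(
        (batch_size, max_input_length - 1, 2 * n_embd // n_head),  # K and V packed
        dtype=torch.float16,
    )
    for _ in range(n_layer)
]
assert mock_cache[0].shape == (3, 8, 16)
```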