Fixes and format

Joel Lamy-Poirier 2023-05-25 15:08:52 -04:00
parent 0921fe6a2a
commit a515fbde4c
6 changed files with 683 additions and 283 deletions


@@ -13,14 +13,13 @@
 # limitations under the License.
 """PyTorch GPTBigCode model."""
 import math
-from typing import Optional, Tuple, Any, Union, List
-from enum import IntEnum
+from typing import Optional, Tuple, Any, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from dropout_layer_norm import dropout_layer_norm
+from dropout_layer_norm import dropout_add_ln_fwd
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_func
@@ -30,10 +29,11 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
     GPTBigCodeConfig,
 )
 class FastLayerNorm(nn.LayerNorm):
     # TODO: Validate dimension
     def forward(self, hidden_states, residual=None):
-        return dropout_layer_norm.dropout_add_ln_fwd(
+        out, residual, *_ = dropout_add_ln_fwd(
             hidden_states,
             residual,
             self.weight,
@@ -50,18 +50,24 @@ class FastLayerNorm(nn.LayerNorm):
             False,
             False,
         )
+        if residual is None:
+            residual = hidden_states
+        return out, residual
 class FastLinear(nn.Linear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.addmm(self.bias, input, self.weight)
 class FastLinearNoBias(nn.Linear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.mm(input, self.weight)
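FastLinear and FastLinearNoBias above call torch.addmm / torch.mm on the stored weight without a transpose, so they compute input @ weight directly. That only matches a regular nn.Linear if the checkpoint weight has been transposed once up front, which is what the post_load_weights hook later in this file does with in-place .t_() calls. A minimal sketch of the equivalence, with illustrative shapes that are not taken from the diff:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)                 # (batch, in_features), made-up sizes
linear = torch.nn.Linear(8, 16)

# nn.Linear stores weight as (out_features, in_features) and computes x @ weight.T + bias.
ref = torch.nn.functional.linear(x, linear.weight, linear.bias)

# With the weight transposed once up front, the same result comes from a plain addmm,
# which is the shortcut the FastLinear forward in the diff relies on.
w_t = linear.weight.t().contiguous()  # (in_features, out_features)
fast = torch.addmm(linear.bias, x, w_t)

assert torch.allclose(ref, fast, atol=1e-6)

Transposing once at load time trades a little extra work when weights are loaded for skipping the transpose on every forward call.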
 logger = logging.get_logger(__name__)
 @torch.jit.script
 def upcast_masked_softmax(
     x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float
@@ -72,24 +78,33 @@ def upcast_masked_softmax(
     x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
     return x
 @torch.jit.script
 def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
     x = torch.where(mask, x, mask_value)
     x = torch.nn.functional.softmax(x, dim=-1)
     return x
 class GPTBigCodeAttention(nn.Module):
+    mask_value: torch.Tensor
     def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
         super().__init__()
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         self.layer_idx = layer_idx
-        # Note: Does not support module dtype conversion.
-        self.register_buffer("mask_value", torch.empty((), dtype=torch.float32, device="meta"))
-        self.c_attn = FastLinear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device="meta")
-        self.c_proj = FastLinear(self.embed_dim, self.embed_dim, dtype=dtype, device="meta")
+        self.c_attn = FastLinear(
+            self.embed_dim,
+            self.embed_dim + 2 * self.head_dim,
+            dtype=dtype,
+            device="meta",
+        )
+        self.c_proj = FastLinear(
+            self.embed_dim, self.embed_dim, dtype=dtype, device="meta"
+        )
     def prefill(
         self,
@@ -98,12 +113,18 @@ class GPTBigCodeAttention(nn.Module):
         key_length: int,
     ) -> Tuple[torch.Tensor, Any]:
         hidden_shape = hidden_states.shape
-        query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn.forward(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )
         query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
-        key, value = key_value.unsqueeze(1).expand(hidden_shape[0], self.num_heads, 2*self.head_dim).split((self.head_dim, self.head_dim), dim=-1)
+        key, value = (
+            key_value.unsqueeze(1)
+            .expand(hidden_shape[0], self.num_heads, 2 * self.head_dim)
+            .split((self.head_dim, self.head_dim), dim=-1)
+        )
         # attn_output: (sum_seq_len, num_heads * head_dim)
-        attn_output = flash_attn_unpadded_func(
+        hidden_states = flash_attn_unpadded_func(
             query,
             key,
             value,
@@ -116,9 +137,9 @@ class GPTBigCodeAttention(nn.Module):
             causal=True,
         ).view(hidden_shape)
-        attn_output = self.c_proj.forward(attn_output)
-        return attn_output, key_value
+        hidden_states = self.c_proj.forward(hidden_states)
+        return hidden_states, key_value
     def decode(
         self,
@@ -128,7 +149,9 @@ class GPTBigCodeAttention(nn.Module):
         batch_size: int,
         key_length: int,
     ) -> Tuple[torch.Tensor, Any]:
-        query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn.forward(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )
         # Calculate dimensions and recover layer_past
         padded_key_length = attention_mask.size(-1)
@@ -142,14 +165,16 @@ class GPTBigCodeAttention(nn.Module):
                 dtype=key_value.dtype,
                 device=key_value.device,
             )
-            allocated_kv_cache[:, :key_length-1].copy_(layer_past[:, :key_length-1])
+            allocated_kv_cache[:, : key_length - 1].copy_(
+                layer_past[:, : key_length - 1]
+            )
             # Nans in `value` can propagate through the matrix multiplication,
             # so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
             allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
             layer_past = allocated_kv_cache
         # Copy the new values.
-        layer_past[:, key_length-1:key_length].copy_(key_value)
+        layer_past[:, key_length - 1].copy_(key_value)
         key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
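The decode path above grows a padded key/value cache: it allocates a buffer whose length is rounded up (the surrounding code uses a multiple of 8), copies the previous entries, zeroes what has not been written yet so NaNs in padding cannot leak into the matmul, and writes the new token at position key_length - 1. A standalone sketch of that bookkeeping; the names, the sizes, and the simplification of zeroing the whole tail (the diff only zeroes the value half) are mine:

import torch

batch_size, head_dim, key_length = 2, 4, 5
pad_to = 8  # the surrounding code rounds the cache length up to a multiple of 8

# Previous cache holds key_length - 1 tokens of concatenated (key, value) pairs.
layer_past = torch.randn(batch_size, key_length - 1, 2 * head_dim)
key_value = torch.randn(batch_size, 2 * head_dim)  # the new token's key/value

allocated_key_length = key_length + (-key_length % pad_to)
allocated = torch.empty(batch_size, allocated_key_length, 2 * head_dim)

# Copy the existing entries, zero the tail so padded positions stay finite,
# then write the new token in slot key_length - 1.
allocated[:, : key_length - 1].copy_(layer_past)
allocated[:, key_length - 1 :].zero_()
allocated[:, key_length - 1].copy_(key_value)

key, value = allocated.split((head_dim, head_dim), dim=-1)
print(key.shape, value.shape)  # torch.Size([2, 8, 4]) each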
@@ -158,21 +183,27 @@ class GPTBigCodeAttention(nn.Module):
         unscale = self.layer_idx + 1 if upcast else 1
         scale_factor = unscale**-1 / self.head_dim**0.5
+        # TODO: No need to unsqueeze?
         hidden_states = torch.baddbmm(
-            torch.empty((batch_size, self.num_heads, padded_key_length), device=query.device, dtype=query.dtype),
+            torch.empty(
+                (batch_size, self.num_heads, padded_key_length),
+                device=query.device,
+                dtype=query.dtype,
+            ),
             query.view(batch_size, self.num_heads, self.head_dim),
             key.transpose(-1, -2),
             beta=0,
-            alpha=scale_factor
+            alpha=scale_factor,
         ).unsqueeze_(1)
-        if self.mask_value is None or self.mask_value.device != hidden_states.device:
-            self.mask_value = torch.full([], torch.finfo(torch.float32).min, dtype=torch.float32, device=hidden_states.device)
         if upcast:
-            hidden_states = upcast_masked_softmax(hidden_states, attention_mask, self.mask_value, unscale)
+            hidden_states = upcast_masked_softmax(
+                hidden_states, attention_mask, self.mask_value, unscale
+            )
         else:
-            hidden_states = masked_softmax(hidden_states, attention_mask, self.mask_value)
+            hidden_states = masked_softmax(
+                hidden_states, attention_mask, self.mask_value
+            )
         hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)
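The single-token attention above scores the query against the whole padded cache with torch.baddbmm, replaces masked positions with a very negative mask_value, and softmaxes; on the upcast path the scores are pre-divided by unscale = layer_idx + 1 inside the matmul and re-multiplied in float32 inside the softmax to keep the half-precision intermediate in range. A plain-PyTorch sketch of the same math under those assumptions (shapes and names are invented):

import torch

batch, heads, head_dim, key_len = 2, 3, 4, 8
query = torch.randn(batch, heads, head_dim)
key = torch.randn(batch, key_len, head_dim)
value = torch.randn(batch, key_len, head_dim)
mask = torch.rand(batch, 1, key_len) > 0.2          # True where attention is allowed
mask_value = torch.full((), torch.finfo(torch.float32).min)

layer_idx = 5
unscale = layer_idx + 1                              # only used on the upcast path
scale_factor = unscale**-1 / head_dim**0.5           # fold 1/unscale into the matmul

# (batch, heads, key_len) attention scores for the single new token.
# With beta=0 the uninitialized first argument is ignored.
scores = torch.baddbmm(
    torch.empty(batch, heads, key_len),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale_factor,
)

# Upcast variant: rescale by `unscale` in float32, mask, softmax, cast back.
probs = torch.nn.functional.softmax(
    torch.where(mask, scores.float() * unscale, mask_value), dim=-1
).to(scores.dtype)

context = torch.bmm(probs, value)                    # (batch, heads, head_dim)
print(context.shape)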
@@ -180,6 +211,7 @@ class GPTBigCodeAttention(nn.Module):
         return hidden_states, layer_past
 class GPTBigCodeMLP(nn.Module):
     # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
     def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
@@ -189,12 +221,23 @@ class GPTBigCodeMLP(nn.Module):
         self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta")
         self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta")
 class GPTBigCodeBlock(nn.Module):
     def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype):
         super().__init__()
-        self.ln_1 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
+        self.ln_1 = FastLayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device="meta",
+        )
         self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
-        self.ln_2 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
+        self.ln_2 = FastLayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device="meta",
+        )
         self.mlp = GPTBigCodeMLP(config, dtype=dtype)
     def prefill(
@@ -211,7 +254,9 @@ class GPTBigCodeBlock(nn.Module):
             key_length=key_length,
         )
         hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
-        hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
+        hidden_states = self.mlp.c_proj.forward(
+            nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")
+        )
         return hidden_states, residual, present
     def decode(
@@ -232,7 +277,9 @@ class GPTBigCodeBlock(nn.Module):
             key_length=key_length,
         )
         hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
-        hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
+        hidden_states = self.mlp.c_proj.forward(
+            nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")
+        )
         return hidden_states, residual, present
@@ -252,11 +299,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, GPTBigCodeModel):
-            module.bias.fill_(True).tril_()
-        elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
-            if isinstance(module, GPTBigCodeAttention):
-                module.mask_value.fill_(torch.finfo(torch.float32).min)
+        if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
             # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
             # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
@@ -264,7 +307,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
             #
             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
             module.c_proj.weight.data.normal_(
-                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+                mean=0.0,
+                std=(
+                    self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
+                ),
             )
             module.c_proj._is_hf_initialized = True
         elif isinstance(module, nn.Linear):
@@ -287,39 +333,59 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
     def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype):
         super().__init__(config)
-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device="meta")
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device="meta")
-        self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) for i in range(config.num_hidden_layers)])
-        self.ln_f = FastLayerNorm(config.hidden_size, dtype=dtype, device="meta", eps=config.layer_norm_epsilon)
-        # Causal mask
-        self.register_buffer(
-            "causal_mask", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device="meta")
-        )
+        self.wte = nn.Embedding(
+            config.vocab_size, config.hidden_size, dtype=dtype, device="meta"
+        )
+        self.wpe = nn.Embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            dtype=dtype,
+            device="meta",
+        )
+        self.h = nn.ModuleList(
+            [
+                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = FastLayerNorm(
+            config.hidden_size,
+            dtype=dtype,
+            device="meta",
+            eps=config.layer_norm_epsilon,
+        )
 class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
-    pad_key_length_to_multiple=8
-    def __init__(self, config, dtype:torch.dtype, device:torch.device=torch.device("cuda")):
+    def __init__(
+        self, config, dtype: torch.dtype, device: torch.device = torch.device("cuda")
+    ):
         super().__init__(config)
         if device.type != "cuda":
             raise NotImplementedError(f"Device {device} not supported")
         self.transformer = GPTBigCodeModel(config, dtype=dtype)
-        self.lm_head = FastLinearNoBias(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta")
+        self.lm_head = FastLinearNoBias(
+            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta"
+        )
+        # Causal mask
+        self.causal_mask = torch.ones(
+            (config.max_position_embeddings, config.max_position_embeddings),
+            dtype=torch.bool,
+            device=device,
+        ).tril_()
+        self.mask_value = torch.full(
+            (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device
+        )
         self.to_empty(device=device)
+        self._apply = self._apply_not_allowed
         # Initialize weights and apply final processing
+        # TODO: Skip?
         self.post_init()
+    def _apply_not_allowed(self):
+        # Dtype or device conversion would break the model.
+        raise NotImplementedError("Device or dtype conversion not supported!")
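The constructor above builds every submodule on the meta device (shapes and dtypes only, no storage), creates the causal mask and mask_value directly on the target GPU, and then calls to_empty(device=...) so real memory is allocated only once and filled when the checkpoint is loaded. A small self-contained sketch of that meta + to_empty pattern with a made-up module:

import torch
from torch import nn

# Build the module graph on the meta device: no tensor storage is allocated yet.
tiny = nn.Sequential(
    nn.Linear(16, 32, device="meta"),
    nn.GELU(),
    nn.Linear(32, 16, device="meta"),
)
print(tiny[0].weight.device)  # meta

# Materialize uninitialized storage directly on the target device.
device = "cuda" if torch.cuda.is_available() else "cpu"
tiny = tiny.to_empty(device=device)

# The tensors now live on `device` but hold garbage until real weights are loaded.
state_dict = {k: torch.zeros_like(v) for k, v in tiny.state_dict().items()}
tiny.load_state_dict(state_dict)

Skipping the intermediate CPU copy is what makes this attractive for large checkpoints; the price is that the module cannot be safely moved or cast afterwards, which is why the code above overrides _apply to raise.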
     def prefill(
         self,
         *,
@@ -330,11 +396,15 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
     ) -> Tuple:
         batch_size, query_length = input_ids.shape
-        hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
+        hidden_states = self.transformer.wte.forward(
+            input_ids
+        ) + self.transformer.wpe.forward(position_ids)
         # Prefill (flash attn)
         # TODO: Unpad earlier (input ids)?
-        hidden_states, padding_index, sequence_lengths, key_length = unpad_input(hidden_states, attention_mask)
+        hidden_states, padding_index, sequence_lengths, key_length = unpad_input(
+            hidden_states, attention_mask
+        )
         assert key_length == query_length
         residual = None
@@ -347,23 +417,39 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
                 sequence_lengths=sequence_lengths,
                 key_length=query_length,
             )
-            past_key_values.append(pad_input(key_value, padding_index, batch_size, query_length))
+            past_key_values.append(
+                pad_input(key_value, padding_index, batch_size, query_length)
+            )
-        hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
+        hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
         # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
         del residual
         if predict_all_tokens:
             hidden_states = self.lm_head.forward(hidden_states)
-            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
+            hidden_states = pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )
         else:
             # TODO: Index directly instead
-            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)[:, -1]
+            hidden_states = pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )[:, -1]
             hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
         return hidden_states, past_key_values
+    def post_load_weights(self):
+        layer: GPTBigCodeBlock
+        for layer in self.transformer.h:
+            layer.attn.mask_value = self.mask_value
+            layer.attn.c_attn.weight.t_()
+            layer.attn.c_proj.weight.t_()
+            layer.mlp.c_fc.weight.t_()
+            layer.mlp.c_proj.weight.t_()
+        self.lm_head.weight.t_()
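In the prefill method above, the memory-heavy step is the lm_head: with predict_all_tokens it projects every token before re-padding, otherwise the code keeps only the last position of each sequence and projects just that. A sketch of the cheaper path's indexing on already padded states (shapes are illustrative; the real code selects the last position after pad_input):

import torch

batch, seq_len, hidden, vocab = 2, 6, 8, 11
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

# Full path: logits for every position, (batch, seq_len, vocab).
all_logits = lm_head(hidden_states)

# Cheap path: only the last position feeds the head, (batch, 1, vocab).
last_logits = lm_head(hidden_states[:, -1]).unsqueeze(1)

assert torch.allclose(all_logits[:, -1:], last_logits, atol=1e-6)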
     def decode(
         self,
         *,
@@ -373,24 +459,28 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         position_ids: torch.Tensor,
         key_length: int,
     ) -> Tuple:
         batch_size, query_length = input_ids.shape
         assert query_length == 1
-        hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
+        hidden_states = self.transformer.wte.forward(
+            input_ids
+        ) + self.transformer.wpe.forward(position_ids)
         # Standardize shape to (batch_size, hidden_size)
         hidden_states.squeeze_(1)
         # Self-attention mask (padding + causal).
         # TODO: Avoid unsqueeze
-        attention_mask = self.transformer.causal_mask[None, key_length - 1: key_length,
-            :key_length] * attention_mask.unsqueeze(1)
+        attention_mask = self.causal_mask[
+            None, key_length - 1 : key_length, : attention_mask.size(-1)
+        ] * attention_mask.unsqueeze(1)
         attention_mask.unsqueeze_(2)
         residual = None
         block: GPTBigCodeBlock
-        for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
+        for i, (block, layer_past) in enumerate(
+            zip(self.transformer.h, past_key_values)
+        ):
             hidden_states, residual, past_key_values[i] = block.decode(
                 hidden_states,
                 residual=residual,
@@ -400,7 +490,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
                 key_length=key_length,
             )
-        hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
+        hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual)
         hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
         return hidden_states, past_key_values


@@ -26,6 +26,7 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
     GPTBigCodeConfig,
 )
 class InferenceRunnerType(IntEnum):
     NO_RUNNER = 0
     # Use the inference runner without cuda graphs.
@@ -38,6 +39,7 @@ class InferenceRunnerType(IntEnum):
     # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
     FULL_GRAPH = 3
 try:
     from flash_attn.bert_padding import pad_input, unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
@@ -52,7 +54,11 @@ logger = logging.get_logger(__name__)
 @torch.jit.script
 def upcast_masked_softmax(
-    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+    x: torch.Tensor,
+    mask: torch.Tensor,
+    mask_value: torch.Tensor,
+    scale: float,
+    softmax_dtype: torch.dtype,
 ):
     input_dtype = x.dtype
     x = x.to(softmax_dtype) * scale
@@ -105,7 +111,13 @@ def softmax_function(
 class GPTBigCodeAttention(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
         self.mask_value = None
         self.embed_dim = config.hidden_size
@@ -115,14 +127,27 @@ class GPTBigCodeAttention(nn.Module):
         # KV caching and padding
-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
-        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
+        self.c_attn = nn.Linear(
+            self.embed_dim,
+            self.embed_dim + 2 * self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
+        self.c_proj = nn.Linear(
+            self.embed_dim, self.embed_dim, dtype=dtype, device=device
+        )
     @torch.profiler.record_function("GPTBigCodeAttention._get_mask_value")
     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
-        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
-            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
+        if (
+            self.mask_value is None
+            or self.mask_value.dtype != dtype
+            or self.mask_value.device != device
+        ):
+            self.mask_value = torch.full(
+                [], torch.finfo(dtype).min, dtype=dtype, device=device
+            )
         return self.mask_value
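_get_mask_value above caches the scalar tensor that torch.where needs as the "masked" operand, rebuilding it only when the requested dtype or device changes. A stripped-down sketch of the same caching idea outside the module (class and variable names are mine):

import torch

class MaskValueCache:
    """Keep one scalar finfo.min tensor per (dtype, device) the caller asks for."""

    def __init__(self):
        self._cached = None

    def get(self, device, dtype):
        if (
            self._cached is None
            or self._cached.dtype != dtype
            or self._cached.device != torch.device(device)
        ):
            # torch.where needs a tensor operand, so build it once and reuse it.
            self._cached = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self._cached

cache = MaskValueCache()
scores = torch.randn(2, 4)
mask = scores > 0
masked = torch.where(mask, scores, cache.get(scores.device, scores.dtype))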
 @torch.profiler.record_function("GPTBigCodeAttention._attn")
@@ -156,12 +181,16 @@ class GPTBigCodeAttention(nn.Module):
             beta = 1
         else:
             beta = 0
-        attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
+        attn_weights = torch.baddbmm(
+            attn_weights, query, key, beta=beta, alpha=scale_factor
+        ).view(attn_shape)
         attn_weights = softmax_function(
             attn_weights,
             attention_mask,
-            None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype),
+            None
+            if attention_mask is None
+            else self._get_mask_value(attn_weights.device, softmax_dtype),
             unscale,
             softmax_dtype,
             upcast,
@@ -172,7 +201,6 @@ class GPTBigCodeAttention(nn.Module):
     @torch.profiler.record_function("GPTBigCodeAttention._attn_flash")
     def _attn_flash(self, query, key, value, flash_params):
         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
@@ -199,11 +227,12 @@ class GPTBigCodeAttention(nn.Module):
     @torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches")
     def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
         # Convert to standard KV cache format.
         if flash_params is not None:
             _, padding_index, batch_size, max_sequence_length = flash_params
-            current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
+            current_kv_cache = pad_input(
+                key_value, padding_index, batch_size, max_sequence_length
+            )
             return key_value, (current_kv_cache, max_sequence_length)
         current_kv_cache = key_value
@@ -257,12 +286,16 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
-        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )
         # present = (allocated_kv_cache, key_length)
-        key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
+        key_value, present = self._merge_kv_caches(
+            key_value, layer_past, attention_mask, flash_params
+        )
         key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
@@ -287,15 +320,35 @@ class GPTBigCodeMLP(nn.Module):
     @torch.profiler.record_function("GPTBigCodeMLP.forward")
     # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
     def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
-        return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
+        return self.c_proj(
+            nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")
+        )
 class GPTBigCodeBlock(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
-        self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
-        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
+        self.ln_1 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
+        self.attn = GPTBigCodeAttention(
+            config, layer_idx=layer_idx, dtype=dtype, device=device
+        )
+        self.ln_2 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
         self.mlp = GPTBigCodeMLP(config)
 @torch.profiler.record_function("GPTBigCodeBlock.forward")
@@ -304,7 +357,7 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
         with torch.profiler.record_function("GPTBigCodeAttention.ln"):
             ai = self.ln_1(hidden_states)
@@ -327,7 +380,6 @@ class GPTBigCodeBlock(nn.Module):
         return hidden_states, present
 class GPTBigCodePreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -354,7 +406,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
             #
             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
             module.c_proj.weight.data.normal_(
-                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+                mean=0.0,
+                std=(
+                    self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
+                ),
             )
             module.c_proj._is_hf_initialized = True
         elif isinstance(module, nn.Linear):
@@ -373,16 +428,40 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
 class GPTBigCodeModel(GPTBigCodePreTrainedModel):
-    def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__(config)
-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
-        self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
-        self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)
-        self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner)
+        self.wte = nn.Embedding(
+            config.vocab_size, config.hidden_size, dtype=dtype, device=device
+        )
+        self.wpe = nn.Embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            dtype=dtype,
+            device=device,
+        )
+        self.h = nn.ModuleList(
+            [
+                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = nn.LayerNorm(
+            config.hidden_size,
+            dtype=dtype,
+            device=device,
+            eps=config.layer_norm_epsilon,
+        )
+        self.inference_runner_type = (
+            InferenceRunnerType.NO_RUNNER
+        )  # InferenceRunnerType(config.inference_runner)
         self.flash_attention = True # config.flash_attention
@@ -402,13 +481,20 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # Causal mask
         self.register_buffer(
-            "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
+            "bias",
+            torch.empty(
+                (config.max_position_embeddings, config.max_position_embeddings),
+                dtype=torch.bool,
+                device=device,
+            ),
         )
     # @torch.profiler.record_function("GPTBigCodeModel._get_causal_mask")
     def _get_causal_mask(self, padding_mask, query_length, key_length):
         # Self-attention mask.
-        attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
+        attention_mask = self.bias[
+            None, key_length - query_length : key_length, :key_length
+        ]
         if padding_mask is not None:
             attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
@@ -416,7 +502,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             )
         pad = -key_length % 8
         if pad > 0:
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
+            attention_mask = torch.nn.functional.pad(
+                attention_mask, (0, pad), mode="constant", value=False
+            )
         # (batch_size, query_length, n_heads, key_length)
         return attention_mask.unsqueeze(2)
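_get_causal_mask above slices a precomputed lower-triangular buffer, combines it with the per-sequence padding mask, and right-pads the key dimension to a multiple of 8 so downstream kernels see aligned shapes. A standalone sketch of that mask construction with made-up sizes:

import torch

max_pos, query_length, key_length = 16, 2, 5

# Precomputed causal buffer, True at and below the diagonal.
bias = torch.ones(max_pos, max_pos, dtype=torch.bool).tril_()

# Rows for the current queries, columns for all keys seen so far.
attention_mask = bias[None, key_length - query_length : key_length, :key_length]

# Combine with per-sequence padding (True = real token), one row per sequence.
padding_mask = torch.tensor(
    [[1, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1]], dtype=torch.bool
)
attention_mask = attention_mask * padding_mask.unsqueeze(1)

# Right-pad the key dimension to a multiple of 8 with "not allowed".
pad = -key_length % 8
if pad > 0:
    attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), value=False)

print(attention_mask.shape)  # (3, 2, 8); unsqueeze(2) then adds the head dimension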
@@ -433,7 +521,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if self.inference_runner is not None and past_key_values is not None:
             if self.config.validate_runner_input:
                 assert past_key_values is not None
-            return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)
+            return self.inference_runner.forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
         batch_size, query_length = input_ids.shape
@@ -444,30 +534,38 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         else:
             past_length = past_key_values[0][1]
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
         # TODO: Unpad earlier (input ids), support unpadded input?
         if flash_attention:
-            hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
-                hidden_states, attention_mask
-            )
-            flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
+            (
+                hidden_states,
+                padding_index,
+                sequence_lengths,
+                max_sequence_length,
+            ) = unpad_input(hidden_states, attention_mask)
+            flash_params = (
+                sequence_lengths,
+                padding_index,
+                batch_size,
+                max_sequence_length,
+            )
             attention_mask = None
         else:
             key_length = past_length + query_length
             # Self-attention mask (padding + causal).
-            attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
+            attention_mask = self._get_causal_mask(
+                attention_mask, query_length, key_length
+            )
             flash_params = None
         presents = []
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             outputs = block(
                 hidden_states,
                 layer_past=layer_past,
                 attention_mask=attention_mask,
-                flash_params=flash_params
+                flash_params=flash_params,
             )
             hidden_states = outputs[0]
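The flash-attention branch above removes padding before the blocks run and restores it afterwards: unpad_input flattens all real tokens into one (total_tokens, hidden) tensor plus the indices and sequence lengths that flash_attn_unpadded_func and pad_input consume. The sketch below mocks that round trip in plain PyTorch instead of calling the flash_attn helpers, so only the bookkeeping is illustrated; the helper names and their exact return values are as given by the diff, not verified here:

import torch

batch, seq_len, hidden = 2, 4, 3
hidden_states = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.bool)

# "unpad": keep only real tokens, remember where they came from.
seqlens = attention_mask.sum(dim=-1)                        # tokens per sequence
indices = torch.nonzero(attention_mask.flatten()).squeeze(1)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0))  # what varlen attention kernels want
unpadded = hidden_states.reshape(-1, hidden)[indices]       # (total_tokens, hidden)

# ... the transformer blocks would run on `unpadded` here ...

# "pad": scatter the tokens back into a zero-filled (batch, seq_len, hidden) tensor.
repadded = torch.zeros(batch * seq_len, hidden)
repadded[indices] = unpadded
repadded = repadded.view(batch, seq_len, hidden)

assert torch.equal(repadded[attention_mask], hidden_states[attention_mask])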
@@ -476,17 +574,26 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         hidden_states = self.ln_f(hidden_states)
         if flash_attention:
-            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
+            hidden_states = pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )
         return hidden_states, presents
 class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
-    def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__(config)
         meta = torch.device("meta")
         self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
+        self.lm_head = nn.Linear(
+            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta
+        )
         self.to_empty(device=device)
@@ -503,12 +610,11 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         position_ids: torch.Tensor,
         predict_all_tokens: bool = True,
     ) -> Tuple:
         hidden_states, presents = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids
+            position_ids=position_ids,
         )
         # with torch.profiler.record_function("GPTBigCodeForCausalLM.head"):
         if not predict_all_tokens:


@@ -20,14 +20,13 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
-import dropout_layer_norm
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
     GPTBigCodeConfig,
 )
 class InferenceRunnerType(IntEnum):
     NO_RUNNER = 0
     # Use the inference runner without cuda graphs.
@@ -41,7 +40,6 @@ class InferenceRunnerType(IntEnum):
     FULL_GRAPH = 3
 try:
     from flash_attn.bert_padding import pad_input, unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
@@ -56,7 +54,11 @@ logger = logging.get_logger(__name__)
 @torch.jit.script
 def upcast_masked_softmax(
-    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+    x: torch.Tensor,
+    mask: torch.Tensor,
+    mask_value: torch.Tensor,
+    scale: float,
+    softmax_dtype: torch.dtype,
 ):
     input_dtype = x.dtype
     x = x.to(softmax_dtype) * scale
@@ -108,7 +110,13 @@ def softmax_function(
 class GPTBigCodeAttention(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
         self.mask_value = None
         self.embed_dim = config.hidden_size
@@ -118,13 +126,26 @@ class GPTBigCodeAttention(nn.Module):
         # KV caching and padding
-        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
-        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
+        self.c_attn = nn.Linear(
+            self.embed_dim,
+            self.embed_dim + 2 * self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
+        self.c_proj = nn.Linear(
+            self.embed_dim, self.embed_dim, dtype=dtype, device=device
+        )
     def _get_mask_value(self, device, dtype):
         # torch.where expects a tensor. We use a cache to avoid recreating it every time.
-        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
-            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
+        if (
+            self.mask_value is None
+            or self.mask_value.dtype != dtype
+            or self.mask_value.device != device
+        ):
+            self.mask_value = torch.full(
+                [], torch.finfo(dtype).min, dtype=dtype, device=device
+            )
         return self.mask_value
     def _attn(self, query, key, value, attention_mask):
@@ -157,12 +178,16 @@ class GPTBigCodeAttention(nn.Module):
             beta = 1
         else:
             beta = 0
-        attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
+        attn_weights = torch.baddbmm(
+            attn_weights, query, key, beta=beta, alpha=scale_factor
+        ).view(attn_shape)
         attn_weights = softmax_function(
             attn_weights,
             attention_mask,
-            None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype),
+            None
+            if attention_mask is None
+            else self._get_mask_value(attn_weights.device, softmax_dtype),
             unscale,
             softmax_dtype,
             upcast,
@@ -172,7 +197,6 @@ class GPTBigCodeAttention(nn.Module):
         return attn_output
     def _attn_flash(self, query, key, value, flash_params):
         query_shape = query.shape
         attn_shape = query_shape[0], self.num_heads, self.head_dim
         query = query.view(attn_shape)
@@ -198,11 +222,12 @@ class GPTBigCodeAttention(nn.Module):
         return attn_output
     def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
         # Convert to standard KV cache format.
         if flash_params is not None:
             _, padding_index, batch_size, max_sequence_length = flash_params
-            current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
+            current_kv_cache = pad_input(
+                key_value, padding_index, batch_size, max_sequence_length
+            )
             return key_value, (current_kv_cache, max_sequence_length)
         current_kv_cache = key_value
@@ -255,12 +280,16 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
-        query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.head_dim), dim=-1
+        )
         # present = (allocated_kv_cache, key_length)
-        key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
+        key_value, present = self._merge_kv_caches(
+            key_value, layer_past, attention_mask, flash_params
+        )
         key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
@@ -284,15 +313,35 @@ class GPTBigCodeMLP(nn.Module):
     # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
     def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
-        return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
+        return self.c_proj(
+            nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")
+        )
 class GPTBigCodeBlock(nn.Module):
-    def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        layer_idx=None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
-        self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
-        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
+        self.ln_1 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
+        self.attn = GPTBigCodeAttention(
+            config, layer_idx=layer_idx, dtype=dtype, device=device
+        )
+        self.ln_2 = nn.LayerNorm(
+            config.hidden_size,
+            eps=config.layer_norm_epsilon,
+            dtype=dtype,
+            device=device,
+        )
         self.mlp = GPTBigCodeMLP(config)
     def forward(
@@ -300,7 +349,7 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: torch.Tensor,
         layer_past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        flash_params: Optional[Tuple] = None
+        flash_params: Optional[Tuple] = None,
     ) -> Tuple[torch.Tensor, Any]:
         attn_output, present = self.attn(
             self.ln_1(hidden_states),
@@ -313,7 +362,6 @@ class GPTBigCodeBlock(nn.Module):
         return hidden_states, present
 class GPTBigCodePreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -340,7 +388,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
             #
             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
             module.c_proj.weight.data.normal_(
-                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+                mean=0.0,
+                std=(
+                    self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
+                ),
             )
             module.c_proj._is_hf_initialized = True
         elif isinstance(module, nn.Linear):
@@ -359,16 +410,40 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
 class GPTBigCodeModel(GPTBigCodePreTrainedModel):
-    def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__(config)
-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
-        self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
-        self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)
-        self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner)
+        self.wte = nn.Embedding(
+            config.vocab_size, config.hidden_size, dtype=dtype, device=device
+        )
+        self.wpe = nn.Embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            dtype=dtype,
+            device=device,
+        )
+        self.h = nn.ModuleList(
+            [
+                GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = nn.LayerNorm(
+            config.hidden_size,
+            dtype=dtype,
+            device=device,
+            eps=config.layer_norm_epsilon,
+        )
+        self.inference_runner_type = (
+            InferenceRunnerType.NO_RUNNER
+        )  # InferenceRunnerType(config.inference_runner)
         self.flash_attention = True # config.flash_attention
@@ -388,12 +463,19 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # Causal mask
         self.register_buffer(
-            "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
+            "bias",
+            torch.empty(
+                (config.max_position_embeddings, config.max_position_embeddings),
+                dtype=torch.bool,
+                device=device,
+            ),
        )
     def _get_causal_mask(self, padding_mask, query_length, key_length):
         # Self-attention mask.
-        attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
+        attention_mask = self.bias[
+            None, key_length - query_length : key_length, :key_length
+        ]
         if padding_mask is not None:
             attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
@@ -401,7 +483,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             )
         pad = -key_length % 8
         if pad > 0:
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
+            attention_mask = torch.nn.functional.pad(
+                attention_mask, (0, pad), mode="constant", value=False
+            )
         # (batch_size, query_length, n_heads, key_length)
         return attention_mask.unsqueeze(2)
@@ -417,7 +501,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if self.inference_runner is not None and past_key_values is not None:
             if self.config.validate_runner_input:
                 assert past_key_values is not None
-            return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)
+            return self.inference_runner.forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
         batch_size, query_length = input_ids.shape
@@ -428,30 +514,38 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         else:
             past_length = past_key_values[0][1]
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
         # TODO: Unpad earlier (input ids), support unpadded input?
         if flash_attention:
-            hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
-                hidden_states, attention_mask
-            )
-            flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
+            (
+                hidden_states,
+                padding_index,
+                sequence_lengths,
+                max_sequence_length,
+            ) = unpad_input(hidden_states, attention_mask)
+            flash_params = (
+                sequence_lengths,
+                padding_index,
+                batch_size,
+                max_sequence_length,
+            )
             attention_mask = None
         else:
             key_length = past_length + query_length
             # Self-attention mask (padding + causal).
-            attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
+            attention_mask = self._get_causal_mask(
+                attention_mask, query_length, key_length
+            )
             flash_params = None
         presents = []
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             outputs = block(
                 hidden_states,
                 layer_past=layer_past,
                 attention_mask=attention_mask,
-                flash_params=flash_params
+                flash_params=flash_params,
            )
             hidden_states = outputs[0]
@@ -460,17 +554,26 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         hidden_states = self.ln_f(hidden_states)
         if flash_attention:
-            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
+            hidden_states = pad_input(
+                hidden_states, padding_index, batch_size, query_length
+            )
         return hidden_states, presents
 class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
-    def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
+    def __init__(
+        self,
+        config,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ):
         super().__init__(config)
         meta = torch.device("meta")
         self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
+        self.lm_head = nn.Linear(
+            config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta
+        )
         self.to_empty(device=device)
@@ -486,12 +589,11 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         position_ids: torch.Tensor,
         predict_all_tokens: bool = True,
     ) -> Tuple:
         hidden_states, presents = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids
+            position_ids=position_ids,
         )
         if not predict_all_tokens:


@ -6,8 +6,13 @@ from opentelemetry import trace
from transformers import AutoTokenizer from transformers import AutoTokenizer
from typing import Optional, Type from typing import Optional, Type
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch from text_generation_server.models.vectorized_causal_lm import (
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeForCausalLM VectorizedCausalLM,
VectorizedCausalLMBatch,
)
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import (
GPTBigCodeForCausalLM,
)
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -24,7 +29,9 @@ class BigcodeBatch(VectorizedCausalLMBatch):
layer_kv.data = layer_kv[keep_indices, sequence_slice] layer_kv.data = layer_kv[keep_indices, sequence_slice]
@classmethod @classmethod
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): def _concatenate_key_values(
cls, batches, start_indices, end_indices, left_indices, max_input_length
):
device = batches[0].input_ids.device device = batches[0].input_ids.device
batch_size = sum([len(batch.requests) for batch in batches]) batch_size = sum([len(batch.requests) for batch in batches])
@ -35,13 +42,21 @@ class BigcodeBatch(VectorizedCausalLMBatch):
past_key_values = [] past_key_values = []
for kv_caches in zip(*(batch.past_key_values for batch in batches)): for kv_caches in zip(*(batch.past_key_values for batch in batches)):
key_values, seq_lengths = zip(*kv_caches) key_values, seq_lengths = zip(*kv_caches)
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths)) assert all(
left_index + seq_length == max_input_length
for left_index, seq_length in zip(left_indices, seq_lengths)
)
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values)) allocate_seq_len = max(
left_index + key_value.size(1)
for left_index, key_value in zip(left_indices, key_values)
)
allocate_seq_len += -allocate_seq_len % 8 allocate_seq_len += -allocate_seq_len % 8
kv_cache = torch.empty( kv_cache = torch.empty(
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device (batch_size, allocate_seq_len, *key_values[0].shape[2:]),
dtype=key_values[0].dtype,
device=device,
) )
for key_value, start_index, end_index, left_index in zip( for key_value, start_index, end_index, left_index in zip(
key_values, key_values,
@ -49,7 +64,9 @@ class BigcodeBatch(VectorizedCausalLMBatch):
end_indices, end_indices,
left_indices, left_indices,
): ):
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value) kv_cache[start_index:end_index, left_index:max_input_length].copy_(
key_value
)
# Set padding to zero to avoid propagating nans. # Set padding to zero to avoid propagating nans.
kv_cache[start_index:end_index, :left_index].fill_(0) kv_cache[start_index:end_index, :left_index].fill_(0)
kv_cache[start_index:end_index, max_input_length:].fill_(0) kv_cache[start_index:end_index, max_input_length:].fill_(0)
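Note: the concatenation above left-aligns every source cache so all sequences end at max_input_length, rounds the allocated key length up to a multiple of 8 with the x += -x % m idiom, and zeroes the padded regions so uninitialized memory cannot propagate NaNs. A simplified, self-contained version of that copy with invented sizes (one layer only):

import torch

pad_to = 8
max_input_length = 11
# Per-source-batch caches of shape [batch, seq, kv_dim] and their left padding.
key_values = [torch.randn(2, 11, 16), torch.randn(3, 6, 16)]
left_indices = [0, 5]                        # left padding needed by each cache
start_indices, end_indices = [0, 2], [2, 5]  # row ranges in the merged cache

allocate_seq_len = max(left + kv.size(1) for left, kv in zip(left_indices, key_values))
allocate_seq_len += -allocate_seq_len % pad_to   # round up to a multiple of 8

kv_cache = torch.empty(5, allocate_seq_len, 16, dtype=key_values[0].dtype)
for kv, start, end, left in zip(key_values, start_indices, end_indices, left_indices):
    kv_cache[start:end, left:max_input_length].copy_(kv)
    # Zero the padding so uninitialized values cannot propagate NaNs.
    kv_cache[start:end, :left].fill_(0)
    kv_cache[start:end, max_input_length:].fill_(0)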
@ -58,6 +75,7 @@ class BigcodeBatch(VectorizedCausalLMBatch):
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
class BigcodeCausalLM(VectorizedCausalLM): class BigcodeCausalLM(VectorizedCausalLM):
def __init__( def __init__(
self, self,
@ -113,7 +131,7 @@ class BigcodeCausalLM(VectorizedCausalLM):
attention_mask=batch.attention_mask[:, :key_length], attention_mask=batch.attention_mask[:, :key_length],
position_ids=batch.position_ids[:, key_length - query_length : key_length], position_ids=batch.position_ids[:, key_length - query_length : key_length],
past_key_values=batch.past_key_values, past_key_values=batch.past_key_values,
use_cache=True, predict_all_tokens=batch.details,
) )
next_token_ids, logprobs = batch.next_token_chooser( next_token_ids, logprobs = batch.next_token_chooser(
input_ids, logits, batch.details input_ids, logits, batch.details
@ -128,8 +146,18 @@ class BigcodeCausalLM(VectorizedCausalLM):
def mock_kv_cache(self, batch: BigcodeBatch, dtype: Optional[torch.dtype]): def mock_kv_cache(self, batch: BigcodeBatch, dtype: Optional[torch.dtype]):
allocate_length = batch.max_input_length + -batch.max_input_length % 8 allocate_length = batch.max_input_length + -batch.max_input_length % 8
return [(torch.empty( return [
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], (
torch.empty(
[
len(batch),
allocate_length - 1,
2 * self.model.config.n_embd // self.model.config.n_head,
],
dtype=dtype, dtype=dtype,
device=batch.input_ids.device, device=batch.input_ids.device,
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)] ),
batch.max_input_length - 1,
)
for _ in range(self.model.config.n_layer)
]
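Note: mock_kv_cache fabricates per-layer caches for the fast-forward benchmark path. GPTBigCode uses multi-query attention, so each layer caches a single key/value head and the last dimension is 2 * head_dim (key and value concatenated), not 2 * num_heads * head_dim. A small sketch of the shape arithmetic with toy config values, roughly mirroring the (tensor, cached_length) tuple layout above:

import torch

# Toy stand-ins for config.n_embd / config.n_head / config.n_layer.
n_embd, n_head, n_layer = 2048, 16, 4
head_dim = n_embd // n_head

batch_size, allocate_length = 3, 16
mock_cache = [
    (
        # [batch, allocated_key_length, 2 * head_dim] for multi-query attention.
        torch.empty(batch_size, allocate_length - 1, 2 * head_dim, dtype=torch.float16),
        allocate_length - 1,  # how many cached positions are considered filled
    )
    for _ in range(n_layer)
]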

@ -4,10 +4,17 @@ from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer from transformers import AutoTokenizer
from typing import Optional, Type from typing import Optional, Type, List
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import GPTBigCodeForCausalLM from text_generation_server.models.vectorized_causal_lm import (
VectorizedCausalLM,
VectorizedCausalLMBatch,
)
from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import (
GPTBigCodeForCausalLM,
)
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -20,14 +27,35 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
# Prefill the attention mask for padded key length. # Prefill the attention mask for padded key length.
attention_mask_fill_value = False attention_mask_fill_value = False
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "Bigcode2Batch":
batch = super().from_pb(pb, tokenizer, device)
batch.attention_mask[:, batch.max_input_length :].fill_(False)
return batch
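Note: Bigcode2Batch resets the mask beyond max_input_length to False because the decode path pads the key length up to pad_key_length_to_multiple; padded key positions must stay masked, and each decode step re-enables exactly one new position (attention_mask[:, key_length - 1].fill_(True) further down). A schematic of that bookkeeping with invented sizes, ignoring left padding of shorter prompts:

import torch

batch_size, max_input_length, allocated = 2, 5, 16
pad_key_length_to_multiple = 8

attention_mask = torch.zeros(batch_size, allocated, dtype=torch.bool)
attention_mask[:, :max_input_length] = True   # real prompt tokens
# Positions past the prompt stay False until their token is generated.

key_length = max_input_length
for _ in range(3):  # three decode steps
    key_length += 1
    attention_mask[:, key_length - 1] = True  # the new token becomes a valid key
    padded_key_length = key_length + -key_length % pad_key_length_to_multiple
    mask_slice = attention_mask[:, :padded_key_length]
    # mask_slice is what the fused decode kernel sees:
    # True for real keys, False for the rounded-up padding.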
def _filter_kv_caches(self, keep_indices, sequence_slice): def _filter_kv_caches(self, keep_indices, sequence_slice):
if self.past_key_values is not None: if self.past_key_values is not None:
for layer_kv, _ in self.past_key_values: for layer_kv in self.past_key_values:
# Update tensors in-place to allow incremental garbage collection # Update tensors in-place to allow incremental garbage collection
layer_kv.data = layer_kv[keep_indices, sequence_slice] layer_kv.data = layer_kv[keep_indices, sequence_slice]
@classmethod @classmethod
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): def concatenate(cls, batches: List["Bigcode2Batch"]) -> "Bigcode2Batch":
batch = super().concatenate(batches)
# Replace the attention mask with zeros to support padded key length.
# They are already filled with ones in super, but duplication is needed to generate the position ids.
batch.attention_mask[:, batch.max_input_length :].fill_(False)
return batch
@classmethod
def _concatenate_key_values(
cls, batches, start_indices, end_indices, left_indices, max_input_length
):
device = batches[0].input_ids.device device = batches[0].input_ids.device
batch_size = sum([len(batch.requests) for batch in batches]) batch_size = sum([len(batch.requests) for batch in batches])
@ -36,15 +64,19 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
raise ValueError("Only concatenate prefilled batches") raise ValueError("Only concatenate prefilled batches")
past_key_values = [] past_key_values = []
for kv_caches in zip(*(batch.past_key_values for batch in batches)): for key_values in zip(*(batch.past_key_values for batch in batches)):
key_values, seq_lengths = zip(*kv_caches) allocate_seq_len = max(
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths)) left_index + key_value.size(1)
for left_index, key_value in zip(left_indices, key_values)
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values)) )
allocate_seq_len += - allocate_seq_len % batches[0].pad_key_length_to_multiple allocate_seq_len += (
-allocate_seq_len % batches[0].pad_key_length_to_multiple
)
kv_cache = torch.empty( kv_cache = torch.empty(
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device (batch_size, allocate_seq_len, *key_values[0].shape[2:]),
dtype=key_values[0].dtype,
device=device,
) )
for key_value, start_index, end_index, left_index in zip( for key_value, start_index, end_index, left_index in zip(
key_values, key_values,
@ -52,16 +84,21 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
end_indices, end_indices,
left_indices, left_indices,
): ):
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value) kv_cache[start_index:end_index, left_index:max_input_length].copy_(
key_value
)
# Set padding to zero to avoid propagating nans. # Set padding to zero to avoid propagating nans.
kv_cache[start_index:end_index, :left_index].fill_(0) kv_cache[start_index:end_index, :left_index].fill_(0)
kv_cache[start_index:end_index, max_input_length:].fill_(0) kv_cache[start_index:end_index, max_input_length:].fill_(0)
past_key_values.append((kv_cache, max_input_length)) past_key_values.append(kv_cache)
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
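Note: the last line of _concatenate_key_values also changes the cache layout: BigcodeBatch stores (cache, sequence_length) pairs per layer, while Bigcode2Batch appends bare tensors and lets max_input_length plus the padded key length say how much of each cache is valid. A type-level sketch of the two layouts (names are illustrative only):

from typing import List, Tuple
import torch

# BigcodeBatch: each layer remembers how many positions of its cache are filled.
PastKeyValuesV1 = List[Tuple[torch.Tensor, int]]  # (cache[batch, alloc, 2*head_dim], seq_len)

# Bigcode2Batch: the cache tensor alone; the valid length is tracked by the batch.
PastKeyValuesV2 = List[torch.Tensor]              # cache[batch, alloc, 2*head_dim]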
class Bigcode2CausalLM(VectorizedCausalLM): class Bigcode2CausalLM(VectorizedCausalLM):
model: GPTBigCodeForCausalLM
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -85,9 +122,12 @@ class Bigcode2CausalLM(VectorizedCausalLM):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
) )
model.post_load_weights()
tokenizer.pad_token_id = ( tokenizer.pad_token_id = (
model.config.pad_token_id model.config.pad_token_id
if model.config.pad_token_id is not None if model.config.pad_token_id is not None
@ -110,25 +150,29 @@ class Bigcode2CausalLM(VectorizedCausalLM):
key_length = batch.max_input_length key_length = batch.max_input_length
if batch.past_key_values is None: if batch.past_key_values is None:
# Prefill (flash attn, unpadded key length) # Prefill (flash attn, unpadded key length)
batch.pad_key_length_to_multiple=self.model.pad_key_length_to_multiple input_ids = batch.input_ids[:, :key_length]
padded_key_length=key_length logits, batch.past_key_values = self.model.prefill(
query_length=key_length input_ids=input_ids,
attention_mask=batch.attention_mask[:, :key_length],
position_ids=batch.position_ids[:, :key_length],
predict_all_tokens=batch.details,
)
else: else:
# Decode (fused attn, padded key length) # Decode (fused attn, padded key length)
batch.attention_mask[:, key_length - 1].fill_(True) batch.attention_mask[:, key_length - 1].fill_(True)
padded_key_length=key_length+-key_length%batch.pad_key_length_to_multiple padded_key_length = (
query_length=1 key_length + -key_length % batch.pad_key_length_to_multiple
)
input_ids = batch.input_ids[:, key_length - query_length : key_length] input_ids = batch.input_ids[:, key_length - 1 : key_length]
# Model Forward # Model Forward
logits, batch.past_key_values = self.model.forward( logits, batch.past_key_values = self.model.decode(
input_ids=input_ids, input_ids=input_ids,
attention_mask=batch.attention_mask[:, :padded_key_length], attention_mask=batch.attention_mask[:, :padded_key_length],
position_ids=batch.position_ids[:, key_length - query_length : key_length], position_ids=batch.position_ids[:, key_length - 1 : key_length],
past_key_values=batch.past_key_values, past_key_values=batch.past_key_values,
key_length=key_length, key_length=key_length,
predict_all_tokens=batch.details
) )
next_token_ids, logprobs = batch.next_token_chooser( next_token_ids, logprobs = batch.next_token_chooser(
input_ids, logits, batch.details input_ids, logits, batch.details
) )
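Note: this hunk is the core behavioural change in the file: the single model.forward call is split into model.prefill (flash attention over the true prompt length) and model.decode (fused attention over a key length rounded up to pad_key_length_to_multiple, with a single query token). A condensed sketch of the dispatch; prefill and decode are the methods added in gpt_bigcode2_modeling and are only stubbed here:

def forward_step(model, batch):
    # Sketch only: `model` and `batch` follow the attributes used in the diff above.
    key_length = batch.max_input_length
    if batch.past_key_values is None:
        # Prefill: flash attention over the real (unpadded) key length.
        input_ids = batch.input_ids[:, :key_length]
        logits, batch.past_key_values = model.prefill(
            input_ids=input_ids,
            attention_mask=batch.attention_mask[:, :key_length],
            position_ids=batch.position_ids[:, :key_length],
            predict_all_tokens=batch.details,
        )
    else:
        # Decode: one query token, key length rounded up so the cache allocation
        # is reused across several steps.
        batch.attention_mask[:, key_length - 1] = True
        padded_key_length = key_length + -key_length % batch.pad_key_length_to_multiple
        input_ids = batch.input_ids[:, key_length - 1 : key_length]
        logits, batch.past_key_values = model.decode(
            input_ids=input_ids,
            attention_mask=batch.attention_mask[:, :padded_key_length],
            position_ids=batch.position_ids[:, key_length - 1 : key_length],
            past_key_values=batch.past_key_values,
            key_length=key_length,
        )
    return logits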
@ -141,9 +185,19 @@ class Bigcode2CausalLM(VectorizedCausalLM):
return next_token_ids, logprobs return next_token_ids, logprobs
def mock_kv_cache(self, batch: Bigcode2Batch, dtype: Optional[torch.dtype]): def mock_kv_cache(self, batch: Bigcode2Batch, dtype: Optional[torch.dtype]):
allocate_length=batch.max_input_length+-batch.max_input_length%batch.pad_key_length_to_multiple allocate_length = (
return [(torch.empty( batch.max_input_length
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], + -batch.max_input_length % batch.pad_key_length_to_multiple
)
return [
torch.randn(
[
len(batch),
allocate_length - 1,
2 * self.model.config.n_embd // self.model.config.n_head,
],
dtype=dtype, dtype=dtype,
device=batch.input_ids.device, device=batch.input_ids.device,
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)] )
for _ in range(self.model.config.n_layer)
]

@ -58,9 +58,6 @@ class VectorizedCausalLMBatch(Batch):
kv_cache_seq_dim: int = 2 kv_cache_seq_dim: int = 2
# Prefill the attention mask for the generated tokens
attention_mask_fill_value=True
# TODO: Get from requests (should these be lists?) # TODO: Get from requests (should these be lists?)
details: bool = os.environ.get("RETURN_DETAILS") is not None details: bool = os.environ.get("RETURN_DETAILS") is not None
generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None
@ -116,7 +113,7 @@ class VectorizedCausalLMBatch(Batch):
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
# Copy tokenizer attention_mask into fully allocated attention_mask # Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"]) attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"])
attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value) attention_mask[:, max_input_length:].fill_(True)
position_ids = attention_mask.cumsum(-1).sub_(1) position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids[:, :max_input_length].relu_() position_ids[:, :max_input_length].relu_()
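Note: position ids are derived straight from the attention mask: a cumulative sum over the boolean mask counts real tokens, subtracting one makes the positions zero-based, and relu_ clamps the left-padded prefix (which would otherwise be negative) to zero; the region past max_input_length is pre-filled with ones so the count keeps growing for future tokens. A small worked example:

import torch

# 1 = real token, 0 = left padding; the second prompt is shorter and left-padded.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]], dtype=torch.bool)

position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids.relu_()   # clamp the -1s produced in the padded prefix to 0
# tensor([[0, 1, 2, 3, 4],
#         [0, 0, 0, 1, 2]])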
@ -271,7 +268,10 @@ class VectorizedCausalLMBatch(Batch):
# Allocate maximum attention_mask # Allocate maximum attention_mask
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
attention_mask[:, :max_input_length].fill_(0) attention_mask[:, :max_input_length].fill_(0)
attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value) attention_mask[:, max_input_length:].fill_(True)
position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids[:, :max_input_length].relu_()
input_ids = torch.empty(input_shape, dtype=torch.int64, device=device) input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
# TODO : only needed for prefill # TODO : only needed for prefill
@ -287,16 +287,15 @@ class VectorizedCausalLMBatch(Batch):
batch.input_ids[:, : batch.max_input_length] batch.input_ids[:, : batch.max_input_length]
) )
position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids[:, :max_input_length].relu_()
max_tokens = sum( max_tokens = sum(
batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch)
for batch in batches for batch in batches
) )
kv_cache_seq_dim = batches[0].kv_cache_seq_dim kv_cache_seq_dim = batches[0].kv_cache_seq_dim
past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices, max_input_length) past_key_values = cls._concatenate_key_values(
batches, start_indices, end_indices, left_indices, max_input_length
)
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
@ -317,7 +316,9 @@ class VectorizedCausalLMBatch(Batch):
) )
@classmethod @classmethod
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): def _concatenate_key_values(
cls, batches, start_indices, end_indices, left_indices, max_input_length
):
device = batches[0].input_ids.device device = batches[0].input_ids.device
batch_size = sum([len(batch.requests) for batch in batches]) batch_size = sum([len(batch.requests) for batch in batches])
@ -386,10 +387,10 @@ class VectorizedCausalLMBatch(Batch):
return return
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
class VectorizedCausalLM(Model): class VectorizedCausalLM(Model):
def __init__( def __init__(
self, self,
@ -558,23 +559,42 @@ class VectorizedCausalLM(Model):
return generations, next_batch return generations, next_batch
def mock_kv_cache(self, batch: VectorizedCausalLMBatch, dtype:Optional[torch.dtype]): def mock_kv_cache(
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM self, batch: VectorizedCausalLMBatch, dtype: Optional[torch.dtype]
):
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeForCausalLM,
)
if not isinstance(self.model, GPTBigCodeForCausalLM): if not isinstance(self.model, GPTBigCodeForCausalLM):
raise NotImplementedError() raise NotImplementedError()
return [torch.empty( return [
[len(batch), batch.max_input_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], torch.empty(
[
len(batch),
batch.max_input_length - 1,
2 * self.model.config.n_embd // self.model.config.n_head,
],
dtype=dtype, dtype=dtype,
device=batch.input_ids.device, device=batch.input_ids.device,
) for _ in range(self.model.config.n_layer)] )
for _ in range(self.model.config.n_layer)
]
def fast_forward(self, batch: VectorizedCausalLMBatch, max_input_length: int, cache_dtype:Optional[torch.dtype]): def fast_forward(
self,
batch: VectorizedCausalLMBatch,
max_input_length: int,
cache_dtype: Optional[torch.dtype],
):
diff = max_input_length - batch.max_input_length diff = max_input_length - batch.max_input_length
batch.input_ids[:, batch.max_input_length:max_input_length].fill_(self.tokenizer.pad_token_id) batch.input_ids[:, batch.max_input_length : max_input_length].fill_(
self.tokenizer.pad_token_id
)
batch.input_lengths = [length + diff for length in batch.input_lengths] batch.input_lengths = [length + diff for length in batch.input_lengths]
batch.max_input_length += diff batch.max_input_length += diff
for stopping_criteria in batch.stopping_criterias: for stopping_criteria in batch.stopping_criterias:
stopping_criteria.current_tokens += diff stopping_criteria.current_tokens += diff
batch.past_key_values = None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype) batch.past_key_values = (
None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype)
)
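Note: fast_forward exists for benchmarking only: it advances a batch by diff positions without generating anything, fills the skipped inputs with the pad token, bumps the per-request lengths and stopping criteria, and either drops the KV cache (forcing a fresh prefill) or installs a mock cache so a decode step at an arbitrary length can be timed in isolation. A condensed restatement of that flow; model_runner stands in for the VectorizedCausalLM instance and its mock_kv_cache helper defined above:

from typing import Optional
import torch

def fast_forward_sketch(model_runner, batch, max_input_length: int,
                        cache_dtype: Optional[torch.dtype]) -> None:
    """Skip ahead to max_input_length tokens per sequence, for benchmarking only."""
    diff = max_input_length - batch.max_input_length
    # Fake the skipped tokens with padding; their values never affect the timing.
    batch.input_ids[:, batch.max_input_length:max_input_length].fill_(
        model_runner.tokenizer.pad_token_id
    )
    batch.input_lengths = [length + diff for length in batch.input_lengths]
    batch.max_input_length += diff
    for stopping_criteria in batch.stopping_criterias:
        stopping_criteria.current_tokens += diff
    # cache_dtype=None -> the next step is a prefill; otherwise pretend prefill already ran.
    batch.past_key_values = (
        None if cache_dtype is None else model_runner.mock_kv_cache(batch, cache_dtype)
    )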