Optimized model

Joel Lamy-Poirier 2023-05-24 19:51:53 -04:00
parent a6dd19b042
commit 0921fe6a2a
10 changed files with 1485 additions and 537 deletions

View File

@@ -17,6 +17,8 @@ from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
+from text_generation_server.models.gpt_bigcode import BigcodeCausalLM
+from text_generation_server.models.gpt_bigcode2 import Bigcode2CausalLM
try:
if torch.cuda.is_available() and os.environ.get("NO_FLASH_ATTENTION") is None:
@@ -91,10 +93,25 @@ torch.backends.cudnn.allow_tf32 = True
# Disable gradients
torch.set_grad_enabled(False)
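+# Optional override: setting the MODEL_TYPE environment variable (e.g. MODEL_TYPE=bigcode2)
+# bypasses the model_id based auto-detection in get_model and selects one of the implementations in the map below directly.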
+model_type_map={
+"flash":FlashSantacoder,
+"santa":SantaCoder,
+"causal":CausalLM,
+"vector":VectorizedCausalLM,
+"bigcode":BigcodeCausalLM,
+"bigcode2":Bigcode2CausalLM,
+}
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
) -> Model:
+model_type=os.environ.get("MODEL_TYPE")
+if model_type is not None:
+if model_type not in model_type_map:
+raise NotImplementedError(model_type)
+return model_type_map[model_type](model_id, revision, quantize=quantize)
if "facebook/galactica" in model_id: if "facebook/galactica" in model_id:
if sharded: if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize) return GalacticaSharded(model_id, revision, quantize=quantize)
@ -167,8 +184,6 @@ def get_model(
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
if os.environ.get("VECTORIZED_LM") is not None:
return VectorizedCausalLM(model_id, revision, quantize=quantize)
return CausalLM(model_id, revision, quantize=quantize) return CausalLM(model_id, revision, quantize=quantize)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM(model_id, revision, quantize=quantize) return Seq2SeqLM(model_id, revision, quantize=quantize)

View File

@@ -0,0 +1,406 @@
# coding=utf-8
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch GPTBigCode model."""
import math
from typing import Optional, Tuple, Any, Union, List
from enum import IntEnum
import torch
import torch.utils.checkpoint
from torch import nn
from dropout_layer_norm import dropout_layer_norm
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
GPTBigCodeConfig,
)
class FastLayerNorm(nn.LayerNorm):
# TODO: Validate dimension
def forward(self, hidden_states, residual=None):
return dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
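# Note (assumption about the loading code): FastLinear and FastLinearNoBias below multiply by
# `self.weight` directly, so they expect the checkpoint loader to have stored the weight already
# transposed to (in_features, out_features), unlike a stock nn.Linear.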
class FastLinear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.addmm(self.bias, input, self.weight)
class FastLinearNoBias(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.mm(input, self.weight)
logger = logging.get_logger(__name__)
@torch.jit.script
def upcast_masked_softmax(
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float
):
input_dtype = x.dtype
x = x.to(torch.float32) * scale
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
return x
@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1)
return x
class GPTBigCodeAttention(nn.Module):
def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype):
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.layer_idx = layer_idx
# Note: Does not support module dtype conversion.
self.register_buffer("mask_value", torch.empty((), dtype=torch.float32, device="meta"))
self.c_attn = FastLinear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device="meta")
self.c_proj = FastLinear(self.embed_dim, self.embed_dim, dtype=dtype, device="meta")
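# Multi-query attention: c_attn produces per-head queries (embed_dim features) plus a single
# key/value head of size 2 * head_dim that is shared by all query heads.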
def prefill(
self,
hidden_states: torch.Tensor,
sequence_lengths,
key_length:int,
) -> Tuple[torch.Tensor, Any]:
hidden_shape = hidden_states.shape
query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
key, value = key_value.unsqueeze(1).expand(hidden_shape[0], self.num_heads, 2*self.head_dim).split((self.head_dim, self.head_dim), dim=-1)
# attn_output: (sum_seq_len, num_heads * head_dim)
attn_output = flash_attn_unpadded_func(
query,
key,
value,
sequence_lengths,
sequence_lengths,
key_length,
key_length,
0.0,
softmax_scale=self.head_dim**-0.5,
causal=True,
).view(hidden_shape)
attn_output = self.c_proj.forward(attn_output)
return attn_output, key_value
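# Decode: one new token per sequence, attending to the padded KV cache with a dense
# baddbmm + masked softmax instead of flash attention.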
def decode(
self,
hidden_states: torch.Tensor,
layer_past: torch.Tensor,
attention_mask: torch.Tensor,
batch_size:int,
key_length:int,
) -> Tuple[torch.Tensor, Any]:
query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
# Calculate dimensions and recover layer_past
padded_key_length = attention_mask.size(-1)
allocated_key_length=layer_past.size(-2)
# TODO: Allow pre-allocation with size > padded_key_length
if padded_key_length > allocated_key_length:
# Re-allocate kv cache and copy last value
allocated_kv_cache = torch.empty(
[batch_size, padded_key_length, 2 * self.head_dim],
dtype=key_value.dtype,
device=key_value.device,
)
allocated_kv_cache[:, :key_length-1].copy_(layer_past[:, :key_length-1])
# Nans in `value` can propagate through the matrix multiplication,
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
layer_past=allocated_kv_cache
# Copy the new values.
layer_past[:, key_length-1:key_length].copy_(key_value)
key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
# TODO: Upcasting needed for bf16?
upcast = query.dtype != torch.float32
unscale = self.layer_idx + 1 if upcast else 1
scale_factor = unscale ** -1 / self.head_dim ** 0.5
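# When upcasting, the raw scores are pre-divided by (layer_idx + 1) to reduce overflow risk in the
# model dtype; upcast_masked_softmax multiplies the same factor back in after casting to float32.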
hidden_states = torch.baddbmm(
torch.empty((batch_size, self.num_heads, padded_key_length), device=query.device, dtype=query.dtype),
query.view(batch_size, self.num_heads, self.head_dim),
key.transpose(-1, -2),
beta=0,
alpha=scale_factor
).unsqueeze_(1)
if self.mask_value is None or self.mask_value.device != hidden_states.device:
self.mask_value = torch.full([], torch.finfo(torch.float32).min, dtype=torch.float32, device=hidden_states.device)
if upcast:
hidden_states = upcast_masked_softmax(hidden_states, attention_mask, self.mask_value, unscale)
else:
hidden_states = masked_softmax(hidden_states, attention_mask, self.mask_value)
hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)
hidden_states = self.c_proj.forward(hidden_states)
return hidden_states, layer_past
class GPTBigCodeMLP(nn.Module):
# TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype):
super().__init__()
embed_dim = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta")
self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta")
class GPTBigCodeBlock(nn.Module):
def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype):
super().__init__()
self.ln_1 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
self.ln_2 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
self.mlp = GPTBigCodeMLP(config, dtype=dtype)
def prefill(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
sequence_lengths,
key_length:int,
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
hidden_states, present = self.attn.prefill(
hidden_states,
sequence_lengths=sequence_lengths,
key_length=key_length,
)
hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
return hidden_states, residual, present
def decode(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
layer_past: torch.Tensor,
attention_mask: torch.Tensor,
batch_size:int,
key_length:int,
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
hidden_states, present = self.attn.decode(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
batch_size=batch_size,
key_length=key_length,
)
hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
return hidden_states, residual, present
class GPTBigCodePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = GPTBigCodeConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = False
_no_split_modules = ["GPTBigCodeBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, GPTBigCodeModel):
module.bias.fill_(True).tril_()
elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
if isinstance(module, GPTBigCodeAttention):
module.mask_value.fill_(torch.finfo(torch.float32).min)
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
module.c_proj.weight.data.normal_(
mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
)
module.c_proj._is_hf_initialized = True
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
# TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict)
def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype):
super().__init__(config)
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device="meta")
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device="meta")
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) for i in range(config.num_hidden_layers)])
self.ln_f = FastLayerNorm(config.hidden_size, dtype=dtype, device="meta", eps=config.layer_norm_epsilon)
# Causal mask
self.register_buffer(
"causal_mask", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device="meta")
)
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
pad_key_length_to_multiple=8
def __init__(self, config, dtype:torch.dtype, device:torch.device=torch.device("cuda")):
super().__init__(config)
if device.type!="cuda":
raise NotImplementedError(f"Device {device} not supported")
self.transformer = GPTBigCodeModel(config, dtype=dtype)
self.lm_head = FastLinearNoBias(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta")
self.to_empty(device=device)
self._apply=self._apply_not_allowed
# Initialize weights and apply final processing
self.post_init()
def _apply_not_allowed(self):
# Dtype or device conversion would break the model.
raise NotImplementedError("Device or dtype conversion not supported!")
def prefill(
self,
*,
input_ids: torch.Tensor,
attention_mask: torch.Tensor = None,
position_ids: torch.Tensor,
predict_all_tokens: bool=True,
) -> Tuple:
batch_size, query_length = input_ids.shape
hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
# Prefill (flash attn)
# TODO: Unpad earlier (input ids)?
hidden_states, padding_index, sequence_lengths, key_length = unpad_input(hidden_states, attention_mask)
assert key_length==query_length
residual = None
past_key_values = []
block:GPTBigCodeBlock
for block in self.transformer.h:
hidden_states, residual, key_value = block.prefill(
hidden_states,
residual=residual,
sequence_lengths=sequence_lengths,
key_length=query_length,
)
past_key_values.append(pad_input(key_value, padding_index, batch_size, query_length))
hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
# Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
del residual
if predict_all_tokens:
hidden_states = self.lm_head.forward(hidden_states)
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
else:
# TODO: Index directly instead
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)[:, -1]
hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
return hidden_states, past_key_values
def decode(
self,
*,
input_ids: torch.Tensor,
past_key_values: List[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
key_length:int,
) -> Tuple:
batch_size, query_length = input_ids.shape
assert query_length == 1
hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
# Standardize shape to (batch_size, hidden_size)
hidden_states.squeeze_(1)
# Self-attention mask (padding + causal).
# TODO: Avoid unsqueeze
attention_mask = self.transformer.causal_mask[None, key_length - 1: key_length,
:key_length] * attention_mask.unsqueeze(1)
attention_mask.unsqueeze_(2)
residual = None
block:GPTBigCodeBlock
for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
hidden_states, residual, past_key_values[i] = block.decode(
hidden_states,
residual=residual,
layer_past=layer_past,
attention_mask=attention_mask,
batch_size=batch_size,
key_length=key_length,
)
hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
return hidden_states, past_key_values
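A minimal sketch of how the prefill/decode pair above can be driven during greedy generation, assuming `model` is a GPTBigCodeForCausalLM already loaded with real weights in float16 on CUDA, that all prompts in `input_ids` have the same length, and that `max_new_tokens` comes from the caller (these names are not part of the commit):

batch_size, prompt_len = input_ids.shape
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
position_ids = attention_mask.long().cumsum(-1) - 1
# Prefill: flash attention over the whole prompt, returns logits and per-layer KV caches.
logits, past = model.prefill(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    predict_all_tokens=False,
)
next_ids = logits[:, -1].argmax(-1, keepdim=True)
for step in range(1, max_new_tokens + 1):
    key_length = prompt_len + step
    # Decode: one token per sequence against the (re)allocated KV cache.
    logits, past = model.decode(
        input_ids=next_ids,
        past_key_values=past,
        attention_mask=torch.ones((batch_size, key_length), dtype=torch.bool, device=input_ids.device),
        position_ids=torch.full((batch_size, 1), key_length - 1, device=input_ids.device),
        key_length=key_length,
    )
    next_ids = logits[:, -1].argmax(-1, keepdim=True)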

View File

@@ -0,0 +1,520 @@
# coding=utf-8
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch GPTBigCode model."""
import math
from typing import Optional, Tuple, Any, Union, List
from enum import IntEnum
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
GPTBigCodeConfig,
)
class InferenceRunnerType(IntEnum):
NO_RUNNER = 0
# Use the inference runner without cuda graphs.
BASE_RUNNER = 1
# Use cuda graphs in the inference runner. Leave out the attention which has a variable shape.
# This significantly lowers the cpu time and prevents a cpu bottleneck for smaller batches and models.
PARTIAL_GRAPH = 2
# Turn the whole model into a cuda graph. One graph for each sequence length.
# Note: only useful for small batches and models, graphs take some time to generate, flaky.
# Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
FULL_GRAPH = 3
try:
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None
pad_input = None
unpad_input = None
logger = logging.get_logger(__name__)
@torch.jit.script
def upcast_masked_softmax(
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
input_dtype = x.dtype
x = x.to(softmax_dtype) * scale
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
return x
@torch.jit.script
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
input_dtype = x.dtype
x = x.to(softmax_dtype) * scale
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
return x
@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1)
return x
@torch.profiler.record_function("softmax_function")
def softmax_function(
x: torch.Tensor,
mask: torch.Tensor,
mask_value: torch.Tensor,
scale: float,
softmax_dtype: torch.dtype,
upcast: bool = True,
):
"""
This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case
needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting
and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely
inefficient.
"""
#assert x.size(-1) % 8 == 0
if upcast:
if mask is None:
return upcast_softmax(x, scale, softmax_dtype)
else:
return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
else:
if mask is None:
return torch.nn.functional.softmax(x, dim=-1)
else:
return masked_softmax(x, mask, mask_value)
class GPTBigCodeAttention(nn.Module):
def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
super().__init__()
self.mask_value = None
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.layer_idx = layer_idx
# KV caching and padding
self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
@torch.profiler.record_function("GPTBigCodeAttention._get_mask_value")
def _get_mask_value(self, device, dtype):
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
return self.mask_value
@torch.profiler.record_function("GPTBigCodeAttention._attn")
def _attn(self, query, key, value, attention_mask):
softmax_dtype = torch.float32
upcast = query.dtype != softmax_dtype
unscale = self.layer_idx + 1 if upcast else 1
scale_factor = unscale**-1 / self.head_dim**0.5
# (batch_size, query_length, num_heads * head_dim)
query_shape = query.shape
batch_size = query_shape[0]
key_length = key.size(-2)
key = key.transpose(-1, -2)
# (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
# -> (batch_size, query_length, num_heads, key_length)
query_length = query_shape[1]
attn_shape = (batch_size, query_length, self.num_heads, key_length)
attn_view = (batch_size, query_length * self.num_heads, key_length)
# No copy needed for MQA 2, or when layer_past is provided.
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
if query.device.type == "cpu":
# This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
# The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
# but the fix has not been released as of pytorch version 2.0.0.
attn_weights.zero_()
beta = 1
else:
beta = 0
attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
attn_weights = softmax_function(
attn_weights,
attention_mask,
None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype),
unscale,
softmax_dtype,
upcast,
)
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
return attn_output
@torch.profiler.record_function("GPTBigCodeAttention._attn_flash")
def _attn_flash(self, query, key, value, flash_params):
query_shape = query.shape
attn_shape = query_shape[0], self.num_heads, self.head_dim
query = query.view(attn_shape)
key = key.unsqueeze(1).expand(attn_shape)
value = value.unsqueeze(1).expand(attn_shape)
sequence_lengths, padding_index, _, max_sequence_length = flash_params
# attn_output: (sum_seq_len, num_heads * head_dim)
attn_output = flash_attn_unpadded_func(
query,
key,
value,
sequence_lengths,
sequence_lengths,
max_sequence_length,
max_sequence_length,
0.0,
softmax_scale=self.head_dim**-0.5,
causal=True,
).view(query_shape)
return attn_output
@torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches")
def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
# Convert to standard KV cache format.
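# Each layer caches the single shared key/value head as a tensor of shape
# (batch_size, allocated_key_length, 2 * head_dim); `present` is the pair (allocated_kv_cache, key_length).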
if flash_params is not None:
_, padding_index, batch_size, max_sequence_length = flash_params
current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
return key_value, (current_kv_cache, max_sequence_length)
current_kv_cache = key_value
# Calculate dimensions and recover layer_past
batch_size = current_kv_cache.size(0)
query_length = current_kv_cache.size(-2)
if layer_past is None:
allocated_kv_cache, last_key_length = None, 0
last_kv_cache = None
key_length = query_length
allocated_key_length = key_length
else:
allocated_kv_cache, last_key_length = layer_past
last_kv_cache = allocated_kv_cache[:, :last_key_length]
key_length = query_length + last_key_length
allocated_key_length = allocated_kv_cache.size(-2)
padded_key_length = attention_mask.size(-1)
# Re-allocate kv cache and copy last value
if padded_key_length > allocated_key_length:
allocated_kv_cache = torch.empty(
[batch_size, padded_key_length, 2 * self.head_dim],
dtype=current_kv_cache.dtype,
device=current_kv_cache.device,
)
if layer_past is not None:
allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
if padded_key_length > key_length:
# Nans in `value` can propagate through the matrix multiplication,
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
# Copy the new values.
if padded_key_length > allocated_key_length or layer_past is not None:
allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
# Use the merged KV cache.
# Not needed when layer_past is None but frees some memory.
key_value = padded_kv_cache
if allocated_kv_cache is None:
allocated_kv_cache = current_kv_cache
present = allocated_kv_cache, key_length
return key_value, present
@torch.profiler.record_function("GPTBigCodeAttention.forward")
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
flash_params: Optional[Tuple] = None
) -> Tuple[torch.Tensor, Any]:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
# present = (allocated_kv_cache, key_length)
key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
if flash_params is None:
attn_output=self._attn(query, key, value, attention_mask)
else:
attn_output=self._attn_flash(query, key, value, flash_params)
attn_output = self.c_proj(attn_output)
return attn_output, present
class GPTBigCodeMLP(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
self.c_fc = nn.Linear(embed_dim, inner_dim)
self.c_proj = nn.Linear(inner_dim, embed_dim)
@torch.profiler.record_function("GPTBigCodeMLP.forward")
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
class GPTBigCodeBlock(nn.Module):
def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
super().__init__()
self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
self.mlp = GPTBigCodeMLP(config)
@torch.profiler.record_function("GPTBigCodeBlock.forward")
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
flash_params: Optional[Tuple] = None
) -> Tuple[torch.Tensor, Any]:
with torch.profiler.record_function("GPTBigCodeAttention.ln"):
ai=self.ln_1(hidden_states)
attn_output, present = self.attn(
ai,
layer_past=layer_past,
attention_mask=attention_mask,
flash_params=flash_params,
)
with torch.profiler.record_function("GPTBigCodeAttention.residual"):
hidden_states.add_(attn_output)
with torch.profiler.record_function("GPTBigCodeAttention.dummy"):
pass
with torch.profiler.record_function("GPTBigCodeAttention.ln"):
ai=self.ln_2(hidden_states)
ai=self.mlp(ai)
with torch.profiler.record_function("GPTBigCodeAttention.residual"):
hidden_states.add_(ai)
return hidden_states, present
class GPTBigCodePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = GPTBigCodeConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = False
_no_split_modules = ["GPTBigCodeBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, GPTBigCodeModel):
module.bias.fill_(True).tril_()
elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
module.c_proj.weight.data.normal_(
mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
)
module.c_proj._is_hf_initialized = True
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
super().__init__(config)
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)
self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner)
self.flash_attention = True #config.flash_attention
if self.flash_attention:
if flash_attn_unpadded_func is None:
raise RuntimeError(
"Flash Attention requires `flash_attn` and `einops`. "
"To install, run `pip install flash-attn einops`."
)
if self.inference_runner_type == InferenceRunnerType.NO_RUNNER:
self.inference_runner = None
else:
from .inference_runner import GPTBigCodeInferenceRunner
self.inference_runner = GPTBigCodeInferenceRunner(config, self)
# Causal mask
self.register_buffer(
"bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
)
#@torch.profiler.record_function("GPTBigCodeModel._get_causal_mask")
def _get_causal_mask(self, padding_mask, query_length, key_length):
# Self-attention mask.
attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
if padding_mask is not None:
attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
dtype=torch.bool, device=attention_mask.device
)
pad = -key_length % 8
if pad > 0:
attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
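# Padding the mask keeps the key length a multiple of 8, which the jit softmax kernels handle much
# more efficiently (see softmax_function); the padded key positions stay masked out.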
# (batch_size, query_length, n_heads, key_length)
return attention_mask.unsqueeze(2)
#@torch.profiler.record_function("GPTBigCodeModel.forward")
def forward(
self,
*,
input_ids: torch.Tensor,
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: torch.Tensor,
) -> Tuple:
if self.inference_runner is not None and past_key_values is not None:
if self.config.validate_runner_input:
assert past_key_values is not None
return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)
batch_size, query_length = input_ids.shape
flash_attention = self.flash_attention and past_key_values is None
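# Flash attention is only used for the prefill pass (no past); decode steps with a KV cache
# go through the dense masked-softmax path in GPTBigCodeAttention._attn.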
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][1]
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
# TODO: Unpad earlier (input ids), support unpadded input?
if flash_attention:
hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
hidden_states, attention_mask
)
flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
attention_mask=None
else:
key_length = past_length + query_length
# Self-attention mask (padding + causal).
attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
flash_params=None
presents = []
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
flash_params=flash_params
)
hidden_states = outputs[0]
presents.append(outputs[1])
hidden_states = self.ln_f(hidden_states)
if flash_attention:
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
return hidden_states, presents
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
super().__init__(config)
meta=torch.device("meta")
self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
self.to_empty(device=device)
# Initialize weights and apply final processing
self.post_init()
#@torch.profiler.record_function("GPTBigCodeForCausalLM.forward")
def forward(
self,
*,
input_ids: torch.Tensor,
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: torch.Tensor,
predict_all_tokens: bool=True,
) -> Tuple:
hidden_states, presents=self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids
)
#with torch.profiler.record_function("GPTBigCodeForCausalLM.head"):
if not predict_all_tokens:
# We only care about the last token.
hidden_states = hidden_states[:, -1:]
lm_logits = self.lm_head(hidden_states)
return lm_logits, presents

View File

@@ -13,25 +13,34 @@
# limitations under the License.
"""PyTorch GPTBigCode model."""
import math
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Any, Union, List
+from enum import IntEnum
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.activations import ACT2FN
+import dropout_layer_norm
-from transformers.modeling_outputs import (
-BaseModelOutputWithPastAndCrossAttentions,
-CausalLMOutputWithCrossAttentions,
-)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
GPTBigCodeConfig,
-InferenceRunnerType,
)
+class InferenceRunnerType(IntEnum):
+NO_RUNNER = 0
+# Use the inference runner without cuda graphs.
+BASE_RUNNER = 1
+# Use cuda graphs in the inference runner. Leave out the attention which has a variable shape.
+# This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models.
+PARTIAL_GRAPH = 2
+# Turn the whole model into a cuda graph. One graph for each sequence length.
+# Note: only useful for small batches and models, graphs take some time to generate, flaky.
+# Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
+FULL_GRAPH = 3
try:
from flash_attn.bert_padding import pad_input, unpad_input
@@ -44,6 +53,8 @@ except ImportError:
logger = logging.get_logger(__name__)
+@torch.jit.script
def upcast_masked_softmax(
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
@@ -55,12 +66,6 @@ def upcast_masked_softmax(
@torch.jit.script
-def upcast_masked_softmax_fused(
-x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
-):
-return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
input_dtype = x.dtype
x = x.to(softmax_dtype) * scale
@@ -69,21 +74,12 @@ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
@torch.jit.script
-def upcast_softmax_fused(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
-return upcast_softmax(x, scale, softmax_dtype)
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
x = torch.where(mask, x, mask_value)
x = torch.nn.functional.softmax(x, dim=-1)
return x
-@torch.jit.script
-def masked_softmax_fused(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
-return masked_softmax(x, mask, mask_value)
def softmax_function(
x: torch.Tensor,
mask: torch.Tensor,
@@ -91,78 +87,39 @@ def softmax_function(
scale: float,
softmax_dtype: torch.dtype,
upcast: bool = True,
-fused_softmax: Optional[bool] = None,
):
"""
This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case
needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting
and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely
-inefficient. TODO: Could have better fused kernels depending on scaling, dropout and head mask.
-Is it doable without writing 32 functions?
+inefficient.
"""
-if fused_softmax is None:
-fused_softmax = x.size(-1) % 8 == 0
+#assert x.size(-1) % 8 == 0
if upcast:
if mask is None:
-return (upcast_softmax_fused if fused_softmax else upcast_softmax)(x, scale, softmax_dtype)
+return upcast_softmax(x, scale, softmax_dtype)
else:
-return (upcast_masked_softmax_fused if fused_softmax else upcast_masked_softmax)(
-x, mask, mask_value, scale, softmax_dtype
-)
+return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
else:
if mask is None:
return torch.nn.functional.softmax(x, dim=-1)
else:
-return (masked_softmax_fused if fused_softmax else masked_softmax)(x, mask, mask_value)
+return masked_softmax(x, mask, mask_value)
class GPTBigCodeAttention(nn.Module):
-def __init__(self, config, is_cross_attention=False, layer_idx=None):
+def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
super().__init__()
self.mask_value = None
-assert config.multi_query
-assert config.attention_softmax_in_fp32
-assert config.scale_attention_softmax_in_fp32
-self.flash_attention = config.flash_attention
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
-self.split_size = self.embed_dim
-if self.head_dim * self.num_heads != self.embed_dim:
-raise ValueError(
-f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-f" {self.num_heads})."
-)
-self.scale_attn_weights = config.scale_attn_weights
-self.is_cross_attention = is_cross_attention
self.layer_idx = layer_idx
-self.fused_softmax = config.fused_softmax
# KV caching and padding
-self.pre_allocate_kv_cache = (
-config.n_embd if config.pre_allocate_kv_cache is True else config.pre_allocate_kv_cache
-)
-pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length
-self._tuple_cache_format = self.pre_allocate_kv_cache or pad_key_length or self.flash_attention
-if self.is_cross_attention:
-raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
-self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)
-self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
-self.attn_dropout = nn.Dropout(config.attn_pdrop)
-self.resid_dropout = nn.Dropout(config.resid_pdrop)
-if self.flash_attention:
-if flash_attn_unpadded_func is None:
-raise RuntimeError(
-"Flash Attention requires `flash_attn` and `einops`. "
-"To install, run `pip install flash-attn einops`."
-)
+self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
+self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
def _get_mask_value(self, device, dtype):
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
@@ -171,14 +128,11 @@ class GPTBigCodeAttention(nn.Module):
return self.mask_value
def _attn(self, query, key, value, attention_mask):
-dtype = query.dtype
softmax_dtype = torch.float32
-upcast = dtype != softmax_dtype
+upcast = query.dtype != softmax_dtype
unscale = self.layer_idx + 1 if upcast else 1
-scale_factor = unscale**-1
-if self.scale_attn_weights:
-scale_factor /= self.head_dim**0.5
+scale_factor = unscale**-1 / self.head_dim**0.5
# (batch_size, query_length, num_heads * head_dim)
query_shape = query.shape
@@ -212,16 +166,12 @@ class GPTBigCodeAttention(nn.Module):
unscale,
softmax_dtype,
upcast,
-self.fused_softmax,
)
-attn_weights = self.attn_dropout(attn_weights)
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
-return attn_output, attn_weights
+return attn_output
-def _attn_flash(self, query, key, value, attention_mask):
+def _attn_flash(self, query, key, value, flash_params):
query_shape = query.shape
attn_shape = query_shape[0], self.num_heads, self.head_dim
@@ -229,7 +179,7 @@ class GPTBigCodeAttention(nn.Module):
key = key.unsqueeze(1).expand(attn_shape)
value = value.unsqueeze(1).expand(attn_shape)
-sequence_lengths, padding_index, _, max_sequence_length = attention_mask
+sequence_lengths, padding_index, _, max_sequence_length = flash_params
# attn_output: (sum_seq_len, num_heads * head_dim)
attn_output = flash_attn_unpadded_func(
@@ -240,32 +190,22 @@ class GPTBigCodeAttention(nn.Module):
sequence_lengths,
max_sequence_length,
max_sequence_length,
-self.dropout_p if self.training else 0.0,
+0.0,
-softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1,
+softmax_scale=self.head_dim**-0.5,
causal=True,
).view(query_shape)
-return attn_output, None
+return attn_output
-def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length):
-batch_size = kv_cache.size(-1)
-assert not self.training
-allocated_kv_cache = torch.empty(
-[batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
-)
-allocated_kv_cache[:, :key_length].copy_(kv_cache)
-padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-return allocated_kv_cache, padded_kv_cache
-def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask):
-flash_attention = self.flash_attention and layer_past is None
+def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
# Convert to standard KV cache format.
-if flash_attention and use_cache:
-_, padding_index, batch_size, max_sequence_length = attention_mask
+if flash_params is not None:
+_, padding_index, batch_size, max_sequence_length = flash_params
current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
-else:
-current_kv_cache = key_value
+return key_value, (current_kv_cache, max_sequence_length)
+current_kv_cache = key_value
# Calculate dimensions and recover layer_past
batch_size = current_kv_cache.size(0)
@@ -281,38 +221,33 @@ class GPTBigCodeAttention(nn.Module):
key_length = query_length + last_key_length
allocated_key_length = allocated_kv_cache.size(-2)
-padded_key_length = key_length if flash_attention else attention_mask.size(-1)
-allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length)
+padded_key_length = attention_mask.size(-1)
# Re-allocate kv cache and copy last value
-if allocate_key_length > allocated_key_length:
+if padded_key_length > allocated_key_length:
allocated_kv_cache = torch.empty(
-[batch_size, allocate_key_length, 2 * self.head_dim],
+[batch_size, padded_key_length, 2 * self.head_dim],
dtype=current_kv_cache.dtype,
device=current_kv_cache.device,
)
if layer_past is not None:
allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
-if allocate_key_length > key_length:
+if padded_key_length > key_length:
# Nans in `value` can propagate through the matrix multiplication,
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
# Copy the new values.
-if allocate_key_length > allocated_key_length or layer_past is not None:
+if padded_key_length > allocated_key_length or layer_past is not None:
allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
-if not flash_attention:
-# Use the merged KV cache.
-# Not needed when layer_past is None but frees some memory.
-key_value = padded_kv_cache
+# Use the merged KV cache.
+# Not needed when layer_past is None but frees some memory.
+key_value = padded_kv_cache
-if use_cache:
-if allocated_kv_cache is None:
-allocated_kv_cache = current_kv_cache
-present = allocated_kv_cache, key_length
-else:
-present = None
+if allocated_kv_cache is None:
+allocated_kv_cache = current_kv_cache
+present = allocated_kv_cache, key_length
return key_value, present
def forward(
@@ -320,118 +255,63 @@ class GPTBigCodeAttention(nn.Module):
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
-use_cache: Optional[bool] = False,
-output_attentions: Optional[bool] = False,
-) -> Union[
-Tuple[torch.Tensor, Optional[torch.Tensor]],
-Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
-]:
-flash_attention = self.flash_attention and layer_past is None
+flash_params: Optional[Tuple] = None
+) -> Tuple[torch.Tensor, Any]:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
-if self._tuple_cache_format:
-# present = (allocated_kv_cache, key_length)
-key_value, present = self._merge_kv_caches(key_value, use_cache, layer_past, attention_mask)
-else:
-# present = key_value
-if layer_past is not None:
-key_value = torch.cat((layer_past, key_value), dim=-2)
-present = key_value if use_cache else None
+# present = (allocated_kv_cache, key_length)
+key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
-attn_output, attn_weights = (self._attn_flash if flash_attention else self._attn)(
-query, key, value, attention_mask
-)
+if flash_params is None:
+attn_output=self._attn(query, key, value, attention_mask)
+else:
+attn_output=self._attn_flash(query, key, value, flash_params)
attn_output = self.c_proj(attn_output)
-attn_output = self.resid_dropout(attn_output)
-outputs = (attn_output, present)
-if output_attentions:
-if flash_attention:
-raise ValueError("`output_attentions` is not supported with Flash Attention.")
-# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
-attn_weights = attn_weights.transpose(1, 2)
-outputs += (attn_weights,)
-return outputs # a, present, (attentions)
+return attn_output, present
class GPTBigCodeMLP(nn.Module):
-def __init__(self, intermediate_size, config):
+def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
-self.c_fc = nn.Linear(embed_dim, intermediate_size)
-self.c_proj = nn.Linear(intermediate_size, embed_dim)
-self.act = ACT2FN[config.activation_function]
-self.dropout = nn.Dropout(config.resid_pdrop)
+inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
+self.c_fc = nn.Linear(embed_dim, inner_dim)
+self.c_proj = nn.Linear(inner_dim, embed_dim)
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
-hidden_states = self.c_fc(hidden_states)
-hidden_states = self.act(hidden_states)
-hidden_states = self.c_proj(hidden_states)
-hidden_states = self.dropout(hidden_states)
-return hidden_states
+return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
class GPTBigCodeBlock(nn.Module):
-def __init__(self, config, layer_idx=None):
+def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
super().__init__()
-hidden_size = config.hidden_size
-self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
-self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
-self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-if config.add_cross_attention:
-raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
-self.mlp = GPTBigCodeMLP(self.inner_dim, config)
+self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
+self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
+self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
+self.mlp = GPTBigCodeMLP(config)
def forward(
self,
-hidden_states: Optional[Tuple[torch.Tensor]],
+hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
-encoder_hidden_states: Optional[torch.Tensor] = None,
-encoder_attention_mask: Optional[torch.Tensor] = None,
-use_cache: Optional[bool] = False,
-output_attentions: Optional[bool] = False,
-) -> Union[
-Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
-]:
-if encoder_hidden_states is not None or encoder_attention_mask is not None:
-raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
-residual = hidden_states
-hidden_states = self.ln_1(hidden_states)
-attn_outputs = self.attn(
-hidden_states,
+flash_params: Optional[Tuple] = None
+) -> Tuple[torch.Tensor, Any]:
+attn_output, present = self.attn(
+self.ln_1(hidden_states),
layer_past=layer_past,
attention_mask=attention_mask,
-use_cache=use_cache,
-output_attentions=output_attentions,
+flash_params=flash_params,
)
-attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
-outputs = attn_outputs[1:]
-# residual connection
-hidden_states = attn_output + residual
-residual = hidden_states
-hidden_states = self.ln_2(hidden_states)
-feed_forward_hidden_states = self.mlp(hidden_states)
-# residual connection
-hidden_states = residual + feed_forward_hidden_states
-if use_cache:
-outputs = (hidden_states,) + outputs
-else:
-outputs = (hidden_states,) + outputs[1:]
-return outputs # hidden_states, present, (attentions, cross_attentions)
+hidden_states.add_(attn_output)
+hidden_states.add_(self.mlp(self.ln_2(hidden_states)))
+return hidden_states, present
class GPTBigCodePreTrainedModel(PreTrainedModel):
@@ -442,7 +322,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
config_class = GPTBigCodeConfig
base_model_prefix = "transformer"
-supports_gradient_checkpointing = True
+supports_gradient_checkpointing = False
_no_split_modules = ["GPTBigCodeBlock"]
def __init__(self, *inputs, **kwargs):
@@ -450,7 +330,9 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
-if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
+if isinstance(module, GPTBigCodeModel):
+module.bias.fill_(True).tril_()
+elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
@ -475,35 +357,27 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPTBigCodeModel):
module.gradient_checkpointing = value
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
    def __init__(self, config, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")):
        super().__init__(config)
        assert config.multi_query

        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
        self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)

        self.inference_runner_type = InferenceRunnerType.NO_RUNNER  # InferenceRunnerType(config.inference_runner)
        self.flash_attention = True  # config.flash_attention
        if self.flash_attention:
            if flash_attn_unpadded_func is None:
                raise RuntimeError(
                    "Flash Attention requires `flash_attn` and `einops`. "
                    "To install, run `pip install flash-attn einops`."
                )

        if self.inference_runner_type == InferenceRunnerType.NO_RUNNER:
            self.inference_runner = None

@ -512,23 +386,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
            self.inference_runner = GPTBigCodeInferenceRunner(config, self)

        # Causal mask
        self.register_buffer(
            "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
        )

    def _get_causal_mask(self, padding_mask, query_length, key_length):
        # Self-attention mask.
        attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]

@ -537,306 +399,105 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
            attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
                dtype=torch.bool, device=attention_mask.device
            )
        # (batch_size, query_length, n_heads, key_length)
        return attention_mask.unsqueeze(2)
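As a rough illustration of what the `bias` buffer and `_get_causal_mask` compute, a standalone sketch with made-up sizes that builds the causal part with `torch.tril` instead of the registered buffer:

import torch

key_length, query_length = 6, 6
padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.bool)
causal = torch.tril(torch.ones(key_length, key_length, dtype=torch.bool))
mask = causal[None, key_length - query_length : key_length, :key_length] & padding_mask[:, None, :]
# Shape (batch, query_length, key_length); unsqueeze(2) would add the head axis used by MQA.
assert mask.shape == (1, 6, 6)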
    def forward(
        self,
        *,
        input_ids: torch.Tensor,
        past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: torch.Tensor,
    ) -> Tuple:
        if self.inference_runner is not None and past_key_values is not None:
            if self.config.validate_runner_input:
                assert past_key_values is not None
            return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)

        batch_size, query_length = input_ids.shape

        flash_attention = self.flash_attention and past_key_values is None
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][1]

        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        # TODO: Unpad earlier (input ids), support unpadded input?
        if flash_attention:
            hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
                hidden_states, attention_mask
            )
            flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
            attention_mask = None
        else:
            key_length = past_length + query_length
            # Self-attention mask (padding + causal).
            attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
            flash_params = None

        presents = []
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                flash_params=flash_params,
            )
            hidden_states = outputs[0]
            presents.append(outputs[1])

        hidden_states = self.ln_f(hidden_states)

        if flash_attention:
            hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)

        return hidden_states, presents
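The flash-attention path relies on `unpad_input`/`pad_input` from `flash_attn`'s `bert_padding` helpers; a rough pure-PyTorch sketch of the same idea, with the return order assumed from how they are unpacked above:

import torch

def unpad(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # hidden_states: (batch, seq, hidden); attention_mask: (batch, seq) bool
    seq_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lengths, 0, dtype=torch.int32), (1, 0))
    unpadded = hidden_states.flatten(0, 1)[indices]
    return unpadded, indices, cu_seqlens, int(seq_lengths.max())

def pad(unpadded: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    out = torch.zeros(batch * seqlen, unpadded.size(-1), dtype=unpadded.dtype)
    out[indices] = unpadded
    return out.view(batch, seqlen, -1)

x = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
u, idx, cu_seqlens, max_len = unpad(x, mask)
assert cu_seqlens.tolist() == [0, 3, 5] and max_len == 3
assert torch.equal(pad(u, idx, 2, 4)[mask], u)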
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
    def __init__(self, config, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")):
        super().__init__(config)
        meta = torch.device("meta")
        self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
        self.to_empty(device=device)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        *,
        input_ids: torch.Tensor,
        past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: torch.Tensor,
        predict_all_tokens: bool = True,
    ) -> Tuple:
        hidden_states, presents = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids
        )

        if not predict_all_tokens:
            # We only care about the last token.
            hidden_states = hidden_states[:, -1:]

        lm_logits = self.lm_head(hidden_states)

        return lm_logits, presents
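A minimal sketch of the meta-device initialization pattern used in this constructor, assuming standard PyTorch APIs (`device="meta"`, `Module.to_empty`):

import torch
from torch import nn

# Parameters are created on the meta device (no storage allocated), materialized
# as empty tensors on the real device, and only then filled with actual weights.
layer = nn.Linear(4096, 4096, bias=False, device="meta", dtype=torch.float16)
layer = layer.to_empty(device="cpu")  # allocate uninitialized storage
layer.load_state_dict({"weight": torch.zeros(4096, 4096, dtype=torch.float16)})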

View File

@ -1,13 +1,12 @@
from typing import List, Union, Tuple

import torch
from transformers import GPTBigCodeConfig
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
    InferenceRunnerType,
)
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeBlock, softmax_function


def _align_tensor(x):

@ -291,7 +290,7 @@ class GPTBigCodeInferenceRunner:
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: Union[List[torch.Tensor], int],
    ) -> Tuple:
        batch_size, query_length = input_ids.shape
        assert query_length == 1
        if self.batch_size is None:

@ -333,7 +332,4 @@ class GPTBigCodeInferenceRunner:
        else:
            hidden_states = self._forward(key_length)

        return hidden_states, key_length

View File

@ -449,6 +449,49 @@ class FlashCausalLM(Model):
pre_allocate_past_size=pre_allocate_past_size,
)
def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, use_cache: Optional[torch.dtype]):
diff = max_input_length - max(batch.input_lengths)
for i in range(len(batch)):
batch.input_lengths[i] += diff
batch.prefix_offsets[i] = 0
batch.read_offsets[i] = 0
batch.all_input_ids[i] = batch.all_input_ids[i] + [self.tokenizer.pad_token_id] * diff if diff >= 0 else batch.all_input_ids[i][:diff]
# TODO: Bug!?!
batch.stopping_criterias[i].current_tokens += diff
if use_cache:
assert len(batch.all_input_ids_tensor)>0, "Must run prefill first"
batch.input_ids.fill_(self.tokenizer.pad_token_id)
batch.position_ids += diff
batch.cu_seqlens += diff * batch.cu_seqlens_q
# TODO: Bug!?!
batch.max_seqlen += batch.max_seqlen + diff*len(batch)
for i in range(len(batch)):
batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(self.tokenizer.pad_token_id)
batch.past_key_values = torch.randn(
(
batch.past_key_values.shape[0],
batch.past_key_values.shape[1] + len(batch.requests),
*batch.past_key_values.shape[2:],
), device=batch.past_key_values.device, dtype= batch.past_key_values.dtype
)
else:
batch.max_seqlen = max(batch.input_lengths)
batch.all_input_ids_tensor=[]
batch.input_ids = torch.tensor(
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
)
batch.position_ids = torch.tensor(
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
)
batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
batch.past_key_values=None
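For reference, a small sketch of the `cu_seqlens` layout this code rebuilds, i.e. exclusive prefix sums of the per-sequence lengths consumed by the variable-length flash-attention kernels:

import numpy as np

input_lengths = [3, 5, 2]
cu_seqlens = np.pad(np.cumsum(input_lengths), (1, 0))
assert cu_seqlens.tolist() == [0, 3, 8, 10]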
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch

View File

@ -223,6 +223,7 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size,
)

@staticmethod
def load_weights(
model,

View File

@ -0,0 +1,135 @@
import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer
from typing import Optional, Type
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeForCausalLM
tracer = trace.get_tracer(__name__)
@dataclass
class BigcodeBatch(VectorizedCausalLMBatch):
kv_cache_seq_dim: int = 1
def _filter_kv_caches(self, keep_indices, sequence_slice):
if self.past_key_values is not None:
for layer_kv, _ in self.past_key_values:
# Update tensors in-place to allow incremental garbage collection
layer_kv.data = layer_kv[keep_indices, sequence_slice]
@classmethod
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
device = batches[0].input_ids.device
batch_size = sum([len(batch.requests) for batch in batches])
for batch in batches:
if batch.past_key_values is None:
raise ValueError("Only concatenate prefilled batches")
past_key_values = []
for kv_caches in zip(*(batch.past_key_values for batch in batches)):
key_values, seq_lengths = zip(*kv_caches)
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths))
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values))
allocate_seq_len += - allocate_seq_len % 8
kv_cache = torch.empty(
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device
)
for key_value, start_index, end_index, left_index in zip(
key_values,
start_indices,
end_indices,
left_indices,
):
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value)
# Set padding to zero to avoid propagating nans.
kv_cache[start_index:end_index, :left_index].fill_(0)
kv_cache[start_index:end_index, max_input_length:].fill_(0)
past_key_values.append((kv_cache, max_input_length))
return past_key_values
def __len__(self):
return len(self.requests)
class BigcodeCausalLM(VectorizedCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
model = GPTBigCodeForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes",
)
tokenizer.pad_token_id = (
model.config.pad_token_id
if model.config.pad_token_id is not None
else model.config.eos_token_id
)
super(VectorizedCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[BigcodeBatch]:
return BigcodeBatch
def forward(self, batch:BigcodeBatch):
key_length = batch.max_input_length
query_length = key_length if batch.past_key_values is None else 1
input_ids = batch.input_ids[:, key_length - query_length : key_length]
# Model Forward
logits, batch.past_key_values = self.model.forward(
input_ids=input_ids,
attention_mask=batch.attention_mask[:, :key_length],
position_ids=batch.position_ids[:, key_length - query_length : key_length],
past_key_values=batch.past_key_values,
use_cache=True,
)
next_token_ids, logprobs = batch.next_token_chooser(
input_ids, logits, batch.details
)
# Update batch
# TODO: Why do we need all input ids?
batch.input_ids[:, key_length].copy_(next_token_ids)
batch.input_lengths = [length + 1 for length in batch.input_lengths]
batch.max_input_length += 1
return next_token_ids, logprobs
def mock_kv_cache(self, batch: BigcodeBatch, dtype:Optional[torch.dtype]):
allocate_length=batch.max_input_length+-batch.max_input_length%8
return [(torch.empty(
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
dtype=dtype,
device=batch.input_ids.device,
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)]
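The `-x % n` idiom above pads lengths up to a multiple of 8; a tiny sketch of the arithmetic:

def round_up(length: int, multiple: int = 8) -> int:
    # x + (-x % n) rounds x up to the next multiple of n (Python's % is non-negative).
    return length + (-length % multiple)

assert [round_up(n) for n in (1, 8, 13)] == [8, 8, 16]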

View File

@ -0,0 +1,149 @@
import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer
from typing import Optional, Type
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch
from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import GPTBigCodeForCausalLM
tracer = trace.get_tracer(__name__)
@dataclass
class Bigcode2Batch(VectorizedCausalLMBatch):
kv_cache_seq_dim: int = 1
pad_key_length_to_multiple:int=8
# Prefill the attention mask for padded key length.
attention_mask_fill_value=False
def _filter_kv_caches(self, keep_indices, sequence_slice):
if self.past_key_values is not None:
for layer_kv, _ in self.past_key_values:
# Update tensors in-place to allow incremental garbage collection
layer_kv.data = layer_kv[keep_indices, sequence_slice]
@classmethod
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
device = batches[0].input_ids.device
batch_size = sum([len(batch.requests) for batch in batches])
for batch in batches:
if batch.past_key_values is None:
raise ValueError("Only concatenate prefilled batches")
past_key_values = []
for kv_caches in zip(*(batch.past_key_values for batch in batches)):
key_values, seq_lengths = zip(*kv_caches)
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths))
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values))
allocate_seq_len += - allocate_seq_len % batches[0].pad_key_length_to_multiple
kv_cache = torch.empty(
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device
)
for key_value, start_index, end_index, left_index in zip(
key_values,
start_indices,
end_indices,
left_indices,
):
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value)
# Set padding to zero to avoid propagating nans.
kv_cache[start_index:end_index, :left_index].fill_(0)
kv_cache[start_index:end_index, max_input_length:].fill_(0)
past_key_values.append((kv_cache, max_input_length))
return past_key_values
def __len__(self):
return len(self.requests)
class Bigcode2CausalLM(VectorizedCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
model = GPTBigCodeForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes",
)
tokenizer.pad_token_id = (
model.config.pad_token_id
if model.config.pad_token_id is not None
else model.config.eos_token_id
)
super(VectorizedCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[Bigcode2Batch]:
return Bigcode2Batch
def forward(self, batch:Bigcode2Batch):
key_length = batch.max_input_length
if batch.past_key_values is None:
# Prefill (flash attn, unpadded key length)
batch.pad_key_length_to_multiple=self.model.pad_key_length_to_multiple
padded_key_length=key_length
query_length=key_length
else:
# Decode (fused attn, padded key length)
batch.attention_mask[:, key_length-1].fill_(True)
padded_key_length=key_length+-key_length%batch.pad_key_length_to_multiple
query_length=1
input_ids = batch.input_ids[:, key_length - query_length : key_length]
# Model Forward
logits, batch.past_key_values = self.model.forward(
input_ids=input_ids,
attention_mask=batch.attention_mask[:, :padded_key_length],
position_ids=batch.position_ids[:, key_length - query_length : key_length],
past_key_values=batch.past_key_values,
key_length=key_length,
predict_all_tokens=batch.details
)
next_token_ids, logprobs = batch.next_token_chooser(
input_ids, logits, batch.details
)
# Update batch
# TODO: Why do we need all input ids?
batch.input_ids[:, key_length].copy_(next_token_ids)
batch.input_lengths = [length + 1 for length in batch.input_lengths]
batch.max_input_length += 1
return next_token_ids, logprobs
def mock_kv_cache(self, batch: Bigcode2Batch, dtype:Optional[torch.dtype]):
allocate_length=batch.max_input_length+-batch.max_input_length%batch.pad_key_length_to_multiple
return [(torch.empty(
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
dtype=dtype,
device=batch.input_ids.device,
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)]
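A rough standalone sketch of the KV-cache concatenation performed by `_concatenate_key_values`, assuming per-batch caches of shape (batch, seq, 2 * head_dim) that are left-aligned to end at `max_input_length` (hypothetical helper, plain torch):

import torch

def merge_kv(caches, left_indices, max_input_length, pad_to=8):
    batch_size = sum(c.size(0) for c in caches)
    alloc = max(left + c.size(1) for left, c in zip(left_indices, caches))
    alloc += -alloc % pad_to  # pad the allocated key length to a multiple of pad_to
    merged = torch.zeros(batch_size, alloc, caches[0].size(-1), dtype=caches[0].dtype)
    start = 0
    for cache, left in zip(caches, left_indices):
        end = start + cache.size(0)
        merged[start:end, left:max_input_length].copy_(cache[:, : max_input_length - left])
        start = end
    return merged

a = torch.ones(1, 3, 4)  # ends at position 5 with left padding 2
b = torch.ones(2, 5, 4)  # no left padding
merged = merge_kv([a, b], left_indices=[2, 0], max_input_length=5)
assert merged.shape == (3, 8, 4)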

View File

@ -58,6 +58,9 @@ class VectorizedCausalLMBatch(Batch):
    kv_cache_seq_dim: int = 2

    # Prefill the attention mask for the generated tokens
    attention_mask_fill_value = True

    # TODO: Get from requests (should these be lists?)
    details: bool = os.environ.get("RETURN_DETAILS") is not None
    generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None

@ -113,7 +116,7 @@ class VectorizedCausalLMBatch(Batch):
        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"])
        attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value)

        position_ids = attention_mask.cumsum(-1).sub_(1)
        position_ids[:, :max_input_length].relu_()

@ -268,7 +271,7 @@ class VectorizedCausalLMBatch(Batch):
        # Allocate maximum attention_mask
        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
        attention_mask[:, :max_input_length].fill_(0)
        attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value)

        input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
        # TODO : only needed for prefill

@ -293,7 +296,7 @@ class VectorizedCausalLMBatch(Batch):
        )
        kv_cache_seq_dim = batches[0].kv_cache_seq_dim
        past_key_values = cls._concatenate_key_values(batches, start_indices, end_indices, left_indices, max_input_length)

        return cls(
            batch_id=batches[0].batch_id,

@ -314,7 +317,7 @@ class VectorizedCausalLMBatch(Batch):
        )

    @classmethod
    def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
        device = batches[0].input_ids.device
        batch_size = sum([len(batch.requests) for batch in batches])

@ -355,29 +358,23 @@ class VectorizedCausalLMBatch(Batch):
                    else batch.past_key_values[i][j]
                    for batch in batches
                ]
                combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:])
                combined_shape[kv_cache_seq_dim] = max_input_length
                # Set to zero to avoid propagating nans in padded values.
                kv_cache = torch.zeros(
                    combined_shape, dtype=tensors_to_merge[0].dtype, device=device
                )
                for tensor, start_index, end_index, left_index in zip(
                    tensors_to_merge,
                    start_indices,
                    end_indices,
                    left_indices,
                ):
                    kv_cache[
                        [
                            slice(start_index, end_index),
                            *(slice(None) for _ in range(1, kv_cache_seq_dim)),
                            slice(left_index, max_input_length),
                        ]
                    ].copy_(tensor)
                if kv_format is None:

@ -398,13 +395,11 @@ class VectorizedCausalLM(Model):
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

@ -415,26 +410,26 @@ class VectorizedCausalLM(Model):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=True,
        )
        tokenizer.pad_token_id = (
            model.config.pad_token_id
            if model.config.pad_token_id is not None
            else model.config.eos_token_id
        )
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property

@ -446,6 +441,30 @@ class VectorizedCausalLM(Model):
            generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
        )

    def forward(self, batch: VectorizedCausalLMBatch):
        key_length = batch.max_input_length
        query_length = key_length if batch.past_key_values is None else 1
        input_ids = batch.input_ids[:, key_length - query_length : key_length]
        # Model Forward
        logits, batch.past_key_values, *_ = self.model.forward(
            input_ids=input_ids,
            attention_mask=batch.attention_mask[:, :key_length],
            position_ids=batch.position_ids[:, key_length - query_length : key_length],
            past_key_values=batch.past_key_values,
            return_dict=False,
            use_cache=True,
        )
        next_token_ids, logprobs = batch.next_token_chooser(
            input_ids, logits, batch.details
        )
        # Update batch
        # TODO: Why do we need all input ids?
        batch.input_ids[:, key_length].copy_(next_token_ids)
        batch.input_lengths = [length + 1 for length in batch.input_lengths]
        batch.max_input_length += 1
        return next_token_ids, logprobs

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: VectorizedCausalLMBatch

@ -453,20 +472,9 @@ class VectorizedCausalLM(Model):
        key_length = batch.max_input_length
        if key_length > batch.input_ids.size(1):
            raise RuntimeError("Cannot generate more than `max_tokens`.")
        is_prefill = batch.past_key_values is None
        next_token_ids, logprobs = self.forward(batch)

        if batch.generate_stream:
            # TODO: self.decode_token, offsets?

@ -479,7 +487,6 @@ class VectorizedCausalLM(Model):
            .squeeze(1)
            .tolist()
        )
        if is_prefill:
            prefill_token_ids = batch.input_ids[:, :key_length].tolist()
            prefill_logprobs = (

@ -491,7 +498,8 @@ class VectorizedCausalLM(Model):
            for prefill_token_ids_, prefill_logprobs_, input_length in zip(
                prefill_token_ids, prefill_logprobs, batch.input_lengths
            ):
                # Input length has already been incremented so we subtract 1.
                prefill_token_ids_ = prefill_token_ids_[-(input_length - 1):]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids_,
                    clean_up_tokenization_spaces=False,

@ -505,13 +513,6 @@ class VectorizedCausalLM(Model):
                )
            )

        # TODO: Vectorize some of this?
        generations: List[Generation] = []

@ -556,3 +557,24 @@ class VectorizedCausalLM(Model):
            generations.append(generation)

        return generations, next_batch
def mock_kv_cache(self, batch: VectorizedCausalLMBatch, dtype:Optional[torch.dtype]):
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
if not isinstance(self.model, GPTBigCodeForCausalLM):
raise NotImplementedError()
return [torch.empty(
[len(batch), batch.max_input_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
dtype=dtype,
device=batch.input_ids.device,
) for _ in range(self.model.config.n_layer)]
def fast_forward(self, batch: VectorizedCausalLMBatch, max_input_length: int, cache_dtype:Optional[torch.dtype]):
diff=max_input_length-batch.max_input_length
batch.input_ids[:, batch.max_input_length:max_input_length].fill_(self.tokenizer.pad_token_id)
batch.input_lengths = [length + diff for length in batch.input_lengths]
batch.max_input_length += diff
for stopping_criteria in batch.stopping_criterias:
stopping_criteria.current_tokens+=diff
batch.past_key_values = None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype)
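For reference, a small sketch of the per-layer cache shape that `mock_kv_cache` allocates, assuming multi-query attention where the single key and value head are packed along the last dimension (hypothetical config values):

import torch

n_layer, n_embd, n_head = 2, 256, 8
head_dim = n_embd // n_head
batch, past_length = 4, 16

# One cache tensor per layer, with key and value concatenated along the last dim.
mock_cache = [
    torch.empty(batch, past_length, 2 * head_dim)
    for _ in range(n_layer)
]
assert mock_cache[0].shape == (4, 16, 64)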