mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Optimized model
This commit is contained in:
parent
a6dd19b042
commit
0921fe6a2a
@ -17,6 +17,8 @@ from text_generation_server.models.galactica import Galactica, GalacticaSharded
|
||||
from text_generation_server.models.santacoder import SantaCoder
|
||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||
from text_generation_server.models.t5 import T5Sharded
|
||||
from text_generation_server.models.gpt_bigcode import BigcodeCausalLM
|
||||
from text_generation_server.models.gpt_bigcode2 import Bigcode2CausalLM
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available() and os.environ.get("NO_FLASH_ATTENTION") is None:
|
||||
@ -91,10 +93,25 @@ torch.backends.cudnn.allow_tf32 = True
|
||||
# Disable gradients
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
model_type_map={
|
||||
"flash":FlashSantacoder,
|
||||
"santa":SantaCoder,
|
||||
"causal":CausalLM,
|
||||
"vector":VectorizedCausalLM,
|
||||
"bigcode":BigcodeCausalLM,
|
||||
"bigcode2":Bigcode2CausalLM,
|
||||
}
|
||||
|
||||
|
||||
def get_model(
|
||||
model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
|
||||
) -> Model:
|
||||
model_type=os.environ.get("MODEL_TYPE")
|
||||
if model_type is not None:
|
||||
if model_type not in model_type_map:
|
||||
raise NotImplementedError(model_type)
|
||||
return model_type_map[model_type](model_id, revision, quantize=quantize)
|
||||
|
||||
if "facebook/galactica" in model_id:
|
||||
if sharded:
|
||||
return GalacticaSharded(model_id, revision, quantize=quantize)
|
||||
@ -167,8 +184,6 @@ def get_model(
|
||||
raise ValueError("sharded is not supported for AutoModel")
|
||||
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
if os.environ.get("VECTORIZED_LM") is not None:
|
||||
return VectorizedCausalLM(model_id, revision, quantize=quantize)
|
||||
return CausalLM(model_id, revision, quantize=quantize)
|
||||
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
||||
return Seq2SeqLM(model_id, revision, quantize=quantize)
|
||||
|
@ -0,0 +1,406 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch GPTBigCode model."""
|
||||
import math
|
||||
from typing import Optional, Tuple, Any, Union, List
|
||||
from enum import IntEnum
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from dropout_layer_norm import dropout_layer_norm
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
|
||||
GPTBigCodeConfig,
|
||||
)
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
# TODO: Validate dimension
|
||||
def forward(self, hidden_states, residual=None):
|
||||
return dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
class FastLinear(nn.Linear):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return torch.addmm(self.bias, input, self.weight)
|
||||
|
||||
class FastLinearNoBias(nn.Linear):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return torch.mm(input, self.weight)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@torch.jit.script
|
||||
def upcast_masked_softmax(
|
||||
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float
|
||||
):
|
||||
input_dtype = x.dtype
|
||||
x = x.to(torch.float32) * scale
|
||||
x = torch.where(mask, x, mask_value)
|
||||
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
||||
return x
|
||||
|
||||
@torch.jit.script
|
||||
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
|
||||
x = torch.where(mask, x, mask_value)
|
||||
x = torch.nn.functional.softmax(x, dim=-1)
|
||||
return x
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.layer_idx = layer_idx
|
||||
# Note: Does not support module dtype conversion.
|
||||
self.register_buffer("mask_value", torch.empty((), dtype=torch.float32, device="meta"))
|
||||
|
||||
self.c_attn = FastLinear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device="meta")
|
||||
self.c_proj = FastLinear(self.embed_dim, self.embed_dim, dtype=dtype, device="meta")
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sequence_lengths,
|
||||
key_length:int,
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
hidden_shape = hidden_states.shape
|
||||
query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
|
||||
query = query.view(hidden_shape[0], self.num_heads, self.head_dim)
|
||||
key, value = key_value.unsqueeze(1).expand(hidden_shape[0], self.num_heads, 2*self.head_dim).split((self.head_dim, self.head_dim), dim=-1)
|
||||
|
||||
# attn_output: (sum_seq_len, num_heads * head_dim)
|
||||
attn_output = flash_attn_unpadded_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
sequence_lengths,
|
||||
sequence_lengths,
|
||||
key_length,
|
||||
key_length,
|
||||
0.0,
|
||||
softmax_scale=self.head_dim**-0.5,
|
||||
causal=True,
|
||||
).view(hidden_shape)
|
||||
|
||||
attn_output = self.c_proj.forward(attn_output)
|
||||
|
||||
return attn_output, key_value
|
||||
|
||||
def decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
layer_past: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size:int,
|
||||
key_length:int,
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
|
||||
|
||||
# Calculate dimensions and recover layer_past
|
||||
padded_key_length = attention_mask.size(-1)
|
||||
allocated_key_length=layer_past.size(-2)
|
||||
|
||||
# TODO: Allow pre-allocation with size > padded_key_length
|
||||
if padded_key_length > allocated_key_length:
|
||||
# Re-allocate kv cache and copy last value
|
||||
allocated_kv_cache = torch.empty(
|
||||
[batch_size, padded_key_length, 2 * self.head_dim],
|
||||
dtype=key_value.dtype,
|
||||
device=key_value.device,
|
||||
)
|
||||
allocated_kv_cache[:, :key_length-1].copy_(layer_past[:, :key_length-1])
|
||||
# Nans in `value` can propagate through the matrix multiplication,
|
||||
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
|
||||
allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_()
|
||||
layer_past=allocated_kv_cache
|
||||
|
||||
# Copy the new values.
|
||||
layer_past[:, key_length-1:key_length].copy_(key_value)
|
||||
|
||||
key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
|
||||
|
||||
# TODO: Upcasting needed for bf16?
|
||||
upcast = query.dtype != torch.float32
|
||||
unscale = self.layer_idx + 1 if upcast else 1
|
||||
scale_factor = unscale ** -1 / self.head_dim ** 0.5
|
||||
|
||||
hidden_states = torch.baddbmm(
|
||||
torch.empty((batch_size, self.num_heads, padded_key_length), device=query.device, dtype=query.dtype),
|
||||
query.view(batch_size, self.num_heads, self.head_dim),
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=scale_factor
|
||||
).unsqueeze_(1)
|
||||
|
||||
if self.mask_value is None or self.mask_value.device != hidden_states.device:
|
||||
self.mask_value = torch.full([], torch.finfo(torch.float32).min, dtype=torch.float32, device=hidden_states.device)
|
||||
|
||||
if upcast:
|
||||
hidden_states = upcast_masked_softmax(hidden_states, attention_mask, self.mask_value, unscale)
|
||||
else:
|
||||
hidden_states = masked_softmax(hidden_states, attention_mask, self.mask_value)
|
||||
|
||||
hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape)
|
||||
|
||||
hidden_states = self.c_proj.forward(hidden_states)
|
||||
|
||||
return hidden_states, layer_past
|
||||
|
||||
class GPTBigCodeMLP(nn.Module):
|
||||
# TODO: Merge into GPTBigCodeBlock (needs renaming in state dict)
|
||||
def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
|
||||
self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta")
|
||||
self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta")
|
||||
|
||||
class GPTBigCodeBlock(nn.Module):
|
||||
def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype):
|
||||
super().__init__()
|
||||
self.ln_1 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
|
||||
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype)
|
||||
self.ln_2 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta")
|
||||
self.mlp = GPTBigCodeMLP(config, dtype=dtype)
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
sequence_lengths,
|
||||
key_length:int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
|
||||
hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
|
||||
hidden_states, present = self.attn.prefill(
|
||||
hidden_states,
|
||||
sequence_lengths=sequence_lengths,
|
||||
key_length=key_length,
|
||||
)
|
||||
hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
|
||||
hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
|
||||
return hidden_states, residual, present
|
||||
|
||||
def decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
layer_past: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size:int,
|
||||
key_length:int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
|
||||
hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual)
|
||||
hidden_states, present = self.attn.decode(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
batch_size=batch_size,
|
||||
key_length=key_length,
|
||||
)
|
||||
hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual)
|
||||
hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh"))
|
||||
return hidden_states, residual, present
|
||||
|
||||
|
||||
class GPTBigCodePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = GPTBigCodeConfig
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["GPTBigCodeBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, GPTBigCodeModel):
|
||||
module.bias.fill_(True).tril_()
|
||||
elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
|
||||
if isinstance(module, GPTBigCodeAttention):
|
||||
module.mask_value.fill_(torch.finfo(torch.float32).min)
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
module.c_proj.weight.data.normal_(
|
||||
mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
|
||||
)
|
||||
module.c_proj._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
||||
# TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict)
|
||||
def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype):
|
||||
super().__init__(config)
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device="meta")
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device="meta")
|
||||
|
||||
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) for i in range(config.num_hidden_layers)])
|
||||
self.ln_f = FastLayerNorm(config.hidden_size, dtype=dtype, device="meta", eps=config.layer_norm_epsilon)
|
||||
|
||||
# Causal mask
|
||||
self.register_buffer(
|
||||
"causal_mask", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device="meta")
|
||||
)
|
||||
|
||||
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
||||
pad_key_length_to_multiple=8
|
||||
|
||||
def __init__(self, config, dtype:torch.dtype, device:torch.device=torch.device("cuda")):
|
||||
super().__init__(config)
|
||||
if device.type!="cuda":
|
||||
raise NotImplementedError(f"Device {device} not supported")
|
||||
|
||||
self.transformer = GPTBigCodeModel(config, dtype=dtype)
|
||||
self.lm_head = FastLinearNoBias(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta")
|
||||
|
||||
self.to_empty(device=device)
|
||||
|
||||
self._apply=self._apply_not_allowed
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _apply_not_allowed(self):
|
||||
# Dtype or device conversion would break the model.
|
||||
raise NotImplementedError("Device or dtype conversion not supported!")
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
*,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor = None,
|
||||
position_ids: torch.Tensor,
|
||||
predict_all_tokens: bool=True,
|
||||
) -> Tuple:
|
||||
batch_size, query_length = input_ids.shape
|
||||
|
||||
hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
|
||||
|
||||
# Prefill (flash attn)
|
||||
# TODO: Unpad earlier (input ids)?
|
||||
hidden_states, padding_index, sequence_lengths, key_length = unpad_input(hidden_states, attention_mask)
|
||||
assert key_length==query_length
|
||||
|
||||
residual = None
|
||||
past_key_values = []
|
||||
block:GPTBigCodeBlock
|
||||
for block in self.transformer.h:
|
||||
hidden_states, residual, key_value = block.prefill(
|
||||
hidden_states,
|
||||
residual=residual,
|
||||
sequence_lengths=sequence_lengths,
|
||||
key_length=query_length,
|
||||
)
|
||||
past_key_values.append(pad_input(key_value, padding_index, batch_size, query_length))
|
||||
|
||||
hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
|
||||
|
||||
# Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible.
|
||||
del residual
|
||||
|
||||
if predict_all_tokens:
|
||||
hidden_states = self.lm_head.forward(hidden_states)
|
||||
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
|
||||
else:
|
||||
# TODO: Index directly instead
|
||||
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)[:, -1]
|
||||
hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
|
||||
|
||||
return hidden_states, past_key_values
|
||||
|
||||
def decode(
|
||||
self,
|
||||
*,
|
||||
input_ids: torch.Tensor,
|
||||
past_key_values: List[torch.Tensor],
|
||||
attention_mask: [torch.Tensor],
|
||||
position_ids: torch.Tensor,
|
||||
key_length:int,
|
||||
) -> Tuple:
|
||||
|
||||
batch_size, query_length = input_ids.shape
|
||||
assert query_length == 1
|
||||
|
||||
hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids)
|
||||
|
||||
# Standardize shape to (batch_size, hidden_size)
|
||||
hidden_states.squeeze_(1)
|
||||
|
||||
# Self-attention mask (padding + causal).
|
||||
# TODO: Avoid unsqueeze
|
||||
attention_mask = self.transformer.causal_mask[None, key_length - 1: key_length,
|
||||
:key_length] * attention_mask.unsqueeze(1)
|
||||
attention_mask.unsqueeze_(2)
|
||||
|
||||
residual = None
|
||||
block:GPTBigCodeBlock
|
||||
for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
|
||||
hidden_states, residual, past_key_values[i] = block.decode(
|
||||
hidden_states,
|
||||
residual=residual,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
batch_size=batch_size,
|
||||
key_length=key_length,
|
||||
)
|
||||
|
||||
hidden_states = self.transformer.ln_f.forward(hidden_states, residual)
|
||||
hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1)
|
||||
|
||||
return hidden_states, past_key_values
|
@ -0,0 +1,520 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch GPTBigCode model."""
|
||||
import math
|
||||
from typing import Optional, Tuple, Any, Union, List
|
||||
from enum import IntEnum
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
|
||||
GPTBigCodeConfig,
|
||||
)
|
||||
|
||||
class InferenceRunnerType(IntEnum):
|
||||
NO_RUNNER = 0
|
||||
# Use the inference runner without cuda graphs.
|
||||
BASE_RUNNER = 1
|
||||
# Use cuda graphs in the inference runner. Leave out the attention which has a variable shape.
|
||||
# This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models.
|
||||
PARTIAL_GRAPH = 2
|
||||
# Turn the whole model into a cuda graph. One graph for each sequence length.
|
||||
# Note: only useful for small batches and models, graphs take some time to generate, flaky.
|
||||
# Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
|
||||
FULL_GRAPH = 3
|
||||
|
||||
try:
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
||||
except ImportError:
|
||||
flash_attn_unpadded_func = None
|
||||
pad_input = None
|
||||
unpad_input = None
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def upcast_masked_softmax(
|
||||
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
|
||||
):
|
||||
input_dtype = x.dtype
|
||||
x = x.to(softmax_dtype) * scale
|
||||
x = torch.where(mask, x, mask_value)
|
||||
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
||||
return x
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
|
||||
input_dtype = x.dtype
|
||||
x = x.to(softmax_dtype) * scale
|
||||
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
||||
return x
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
|
||||
x = torch.where(mask, x, mask_value)
|
||||
x = torch.nn.functional.softmax(x, dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
@torch.profiler.record_function("softmax_function")
|
||||
def softmax_function(
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
mask_value: torch.Tensor,
|
||||
scale: float,
|
||||
softmax_dtype: torch.dtype,
|
||||
upcast: bool = True,
|
||||
):
|
||||
"""
|
||||
This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case
|
||||
needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting
|
||||
and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely
|
||||
inefficient.
|
||||
"""
|
||||
#assert x.size(-1) % 8 == 0
|
||||
if upcast:
|
||||
if mask is None:
|
||||
return upcast_softmax(x, scale, softmax_dtype)
|
||||
else:
|
||||
return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
|
||||
else:
|
||||
if mask is None:
|
||||
return torch.nn.functional.softmax(x, dim=-1)
|
||||
else:
|
||||
return masked_softmax(x, mask, mask_value)
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__()
|
||||
self.mask_value = None
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# KV caching and padding
|
||||
|
||||
self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
|
||||
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
|
||||
|
||||
@torch.profiler.record_function("GPTBigCodeAttention._get_mask_value")
|
||||
def _get_mask_value(self, device, dtype):
|
||||
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
|
||||
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
|
||||
self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
|
||||
return self.mask_value
|
||||
|
||||
@torch.profiler.record_function("GPTBigCodeAttention._attn")
|
||||
def _attn(self, query, key, value, attention_mask):
|
||||
softmax_dtype = torch.float32
|
||||
upcast = query.dtype != softmax_dtype
|
||||
|
||||
unscale = self.layer_idx + 1 if upcast else 1
|
||||
scale_factor = unscale**-1 / self.head_dim**0.5
|
||||
|
||||
# (batch_size, query_length, num_heads * head_dim)
|
||||
query_shape = query.shape
|
||||
batch_size = query_shape[0]
|
||||
key_length = key.size(-2)
|
||||
|
||||
key = key.transpose(-1, -2)
|
||||
# (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
|
||||
# -> (batch_size, query_length, num_heads, key_length)
|
||||
query_length = query_shape[1]
|
||||
attn_shape = (batch_size, query_length, self.num_heads, key_length)
|
||||
attn_view = (batch_size, query_length * self.num_heads, key_length)
|
||||
# No copy needed for MQA 2, or when layer_past is provided.
|
||||
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
|
||||
|
||||
attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
|
||||
if query.device.type == "cpu":
|
||||
# This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
|
||||
# The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
|
||||
# but the fix has not been released as of pytorch version 2.0.0.
|
||||
attn_weights.zero_()
|
||||
beta = 1
|
||||
else:
|
||||
beta = 0
|
||||
attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
|
||||
|
||||
attn_weights = softmax_function(
|
||||
attn_weights,
|
||||
attention_mask,
|
||||
None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype),
|
||||
unscale,
|
||||
softmax_dtype,
|
||||
upcast,
|
||||
)
|
||||
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
|
||||
|
||||
return attn_output
|
||||
|
||||
@torch.profiler.record_function("GPTBigCodeAttention._attn_flash")
|
||||
def _attn_flash(self, query, key, value, flash_params):
|
||||
|
||||
query_shape = query.shape
|
||||
attn_shape = query_shape[0], self.num_heads, self.head_dim
|
||||
query = query.view(attn_shape)
|
||||
key = key.unsqueeze(1).expand(attn_shape)
|
||||
value = value.unsqueeze(1).expand(attn_shape)
|
||||
|
||||
sequence_lengths, padding_index, _, max_sequence_length = flash_params
|
||||
|
||||
# attn_output: (sum_seq_len, num_heads * head_dim)
|
||||
attn_output = flash_attn_unpadded_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
sequence_lengths,
|
||||
sequence_lengths,
|
||||
max_sequence_length,
|
||||
max_sequence_length,
|
||||
0.0,
|
||||
softmax_scale=self.head_dim**-0.5,
|
||||
causal=True,
|
||||
).view(query_shape)
|
||||
|
||||
return attn_output
|
||||
|
||||
@torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches")
|
||||
def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
|
||||
|
||||
# Convert to standard KV cache format.
|
||||
if flash_params is not None:
|
||||
_, padding_index, batch_size, max_sequence_length = flash_params
|
||||
current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
|
||||
return key_value, (current_kv_cache, max_sequence_length)
|
||||
|
||||
current_kv_cache = key_value
|
||||
|
||||
# Calculate dimensions and recover layer_past
|
||||
batch_size = current_kv_cache.size(0)
|
||||
query_length = current_kv_cache.size(-2)
|
||||
if layer_past is None:
|
||||
allocated_kv_cache, last_key_length = None, 0
|
||||
last_kv_cache = None
|
||||
key_length = query_length
|
||||
allocated_key_length = key_length
|
||||
else:
|
||||
allocated_kv_cache, last_key_length = layer_past
|
||||
last_kv_cache = allocated_kv_cache[:, :last_key_length]
|
||||
key_length = query_length + last_key_length
|
||||
allocated_key_length = allocated_kv_cache.size(-2)
|
||||
|
||||
padded_key_length = attention_mask.size(-1)
|
||||
|
||||
# Re-allocate kv cache and copy last value
|
||||
if padded_key_length > allocated_key_length:
|
||||
allocated_kv_cache = torch.empty(
|
||||
[batch_size, padded_key_length, 2 * self.head_dim],
|
||||
dtype=current_kv_cache.dtype,
|
||||
device=current_kv_cache.device,
|
||||
)
|
||||
if layer_past is not None:
|
||||
allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
|
||||
if padded_key_length > key_length:
|
||||
# Nans in `value` can propagate through the matrix multiplication,
|
||||
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
|
||||
allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
|
||||
|
||||
# Copy the new values.
|
||||
if padded_key_length > allocated_key_length or layer_past is not None:
|
||||
allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
|
||||
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
|
||||
# Use the merged KV cache.
|
||||
# Not needed when layer_past is None but frees some memory.
|
||||
key_value = padded_kv_cache
|
||||
|
||||
if allocated_kv_cache is None:
|
||||
allocated_kv_cache = current_kv_cache
|
||||
present = allocated_kv_cache, key_length
|
||||
return key_value, present
|
||||
|
||||
@torch.profiler.record_function("GPTBigCodeAttention.forward")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
layer_past: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
flash_params: Optional[Tuple] = None
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
|
||||
|
||||
# present = (allocated_kv_cache, key_length)
|
||||
key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
|
||||
|
||||
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
||||
|
||||
if flash_params is None:
|
||||
attn_output=self._attn(query, key, value, attention_mask)
|
||||
else:
|
||||
attn_output=self._attn_flash(query, key, value, flash_params)
|
||||
|
||||
attn_output = self.c_proj(attn_output)
|
||||
|
||||
return attn_output, present
|
||||
|
||||
|
||||
class GPTBigCodeMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
|
||||
self.c_fc = nn.Linear(embed_dim, inner_dim)
|
||||
self.c_proj = nn.Linear(inner_dim, embed_dim)
|
||||
|
||||
@torch.profiler.record_function("GPTBigCodeMLP.forward")
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
|
||||
def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
|
||||
return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
|
||||
|
||||
|
||||
class GPTBigCodeBlock(nn.Module):
|
||||
def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__()
|
||||
self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
|
||||
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
|
||||
self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
|
||||
self.mlp = GPTBigCodeMLP(config)
|
||||
|
||||
@torch.profiler.record_function("GPTBigCodeBlock.forward")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
layer_past: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
flash_params: Optional[Tuple] = None
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
with torch.profiler.record_function("GPTBigCodeAttention.ln"):
|
||||
ai=self.ln_1(hidden_states)
|
||||
attn_output, present = self.attn(
|
||||
ai,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
flash_params=flash_params,
|
||||
)
|
||||
with torch.profiler.record_function("GPTBigCodeAttention.residual"):
|
||||
hidden_states.add_(attn_output)
|
||||
|
||||
with torch.profiler.record_function("GPTBigCodeAttention.dummy"):
|
||||
pass
|
||||
with torch.profiler.record_function("GPTBigCodeAttention.ln"):
|
||||
ai=self.ln_2(hidden_states)
|
||||
ai=self.mlp(ai)
|
||||
with torch.profiler.record_function("GPTBigCodeAttention.residual"):
|
||||
hidden_states.add_(ai)
|
||||
return hidden_states, present
|
||||
|
||||
|
||||
|
||||
class GPTBigCodePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = GPTBigCodeConfig
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["GPTBigCodeBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, GPTBigCodeModel):
|
||||
module.bias.fill_(True).tril_()
|
||||
elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
module.c_proj.weight.data.normal_(
|
||||
mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
|
||||
)
|
||||
module.c_proj._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
||||
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__(config)
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
|
||||
self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner)
|
||||
|
||||
self.flash_attention = True #config.flash_attention
|
||||
|
||||
if self.flash_attention:
|
||||
if flash_attn_unpadded_func is None:
|
||||
raise RuntimeError(
|
||||
"Flash Attention requires `flash_attn` and `einops`. "
|
||||
"To install, run `pip install flash-attn einops`."
|
||||
)
|
||||
|
||||
if self.inference_runner_type == InferenceRunnerType.NO_RUNNER:
|
||||
self.inference_runner = None
|
||||
else:
|
||||
from .inference_runner import GPTBigCodeInferenceRunner
|
||||
|
||||
self.inference_runner = GPTBigCodeInferenceRunner(config, self)
|
||||
|
||||
# Causal mask
|
||||
self.register_buffer(
|
||||
"bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
|
||||
)
|
||||
|
||||
#@torch.profiler.record_function("GPTBigCodeModel._get_causal_mask")
|
||||
def _get_causal_mask(self, padding_mask, query_length, key_length):
|
||||
# Self-attention mask.
|
||||
attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
|
||||
|
||||
if padding_mask is not None:
|
||||
attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
|
||||
dtype=torch.bool, device=attention_mask.device
|
||||
)
|
||||
pad = -key_length % 8
|
||||
if pad > 0:
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
|
||||
|
||||
# (batch_size, query_length, n_heads, key_length)
|
||||
return attention_mask.unsqueeze(2)
|
||||
|
||||
#@torch.profiler.record_function("GPTBigCodeModel.forward")
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
input_ids: torch.Tensor,
|
||||
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: torch.Tensor,
|
||||
) -> Tuple:
|
||||
if self.inference_runner is not None and past_key_values is not None:
|
||||
if self.config.validate_runner_input:
|
||||
assert past_key_values is not None
|
||||
return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)
|
||||
|
||||
batch_size, query_length = input_ids.shape
|
||||
|
||||
flash_attention = self.flash_attention and past_key_values is None
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][1]
|
||||
|
||||
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
|
||||
# TODO: Unpad earlier (input ids), support unpadded input?
|
||||
if flash_attention:
|
||||
hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
|
||||
hidden_states, attention_mask
|
||||
)
|
||||
flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
|
||||
attention_mask=None
|
||||
else:
|
||||
key_length = past_length + query_length
|
||||
# Self-attention mask (padding + causal).
|
||||
attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
|
||||
flash_params=None
|
||||
|
||||
presents = []
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
flash_params=flash_params
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
presents.append(outputs[1])
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
if flash_attention:
|
||||
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
|
||||
|
||||
return hidden_states, presents
|
||||
|
||||
|
||||
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
||||
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__(config)
|
||||
meta=torch.device("meta")
|
||||
self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
|
||||
|
||||
self.to_empty(device=device)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
#@torch.profiler.record_function("GPTBigCodeForCausalLM.forward")
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
input_ids: torch.Tensor,
|
||||
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: torch.Tensor,
|
||||
predict_all_tokens: bool=True,
|
||||
) -> Tuple:
|
||||
|
||||
hidden_states, presents=self.transformer(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids
|
||||
)
|
||||
#with torch.profiler.record_function("GPTBigCodeForCausalLM.head"):
|
||||
if not predict_all_tokens:
|
||||
# We only care about the last token.
|
||||
hidden_states = hidden_states[:, -1:]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
return lm_logits, presents
|
@ -13,25 +13,34 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch GPTBigCode model."""
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Any, Union, List
|
||||
from enum import IntEnum
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
import dropout_layer_norm
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
|
||||
GPTBigCodeConfig,
|
||||
InferenceRunnerType,
|
||||
)
|
||||
|
||||
class InferenceRunnerType(IntEnum):
|
||||
NO_RUNNER = 0
|
||||
# Use the inference runner without cuda graphs.
|
||||
BASE_RUNNER = 1
|
||||
# Use cuda graphs in the inference runner. Leave out the attention which has a variable shape.
|
||||
# This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models.
|
||||
PARTIAL_GRAPH = 2
|
||||
# Turn the whole model into a cuda graph. One graph for each sequence length.
|
||||
# Note: only useful for small batches and models, graphs take some time to generate, flaky.
|
||||
# Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100.
|
||||
FULL_GRAPH = 3
|
||||
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
@ -44,6 +53,8 @@ except ImportError:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def upcast_masked_softmax(
|
||||
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
|
||||
):
|
||||
@ -55,12 +66,6 @@ def upcast_masked_softmax(
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def upcast_masked_softmax_fused(
|
||||
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
|
||||
):
|
||||
return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
|
||||
|
||||
|
||||
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
|
||||
input_dtype = x.dtype
|
||||
x = x.to(softmax_dtype) * scale
|
||||
@ -69,21 +74,12 @@ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def upcast_softmax_fused(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
|
||||
return upcast_softmax(x, scale, softmax_dtype)
|
||||
|
||||
|
||||
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
|
||||
x = torch.where(mask, x, mask_value)
|
||||
x = torch.nn.functional.softmax(x, dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def masked_softmax_fused(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
|
||||
return masked_softmax(x, mask, mask_value)
|
||||
|
||||
|
||||
def softmax_function(
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
@ -91,78 +87,39 @@ def softmax_function(
|
||||
scale: float,
|
||||
softmax_dtype: torch.dtype,
|
||||
upcast: bool = True,
|
||||
fused_softmax: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case
|
||||
needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting
|
||||
and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely
|
||||
inefficient. TODO: Could have better fused kernels depending on scaling, dropout and head mask.
|
||||
Is it doable without writing 32 functions?
|
||||
inefficient.
|
||||
"""
|
||||
if fused_softmax is None:
|
||||
fused_softmax = x.size(-1) % 8 == 0
|
||||
#assert x.size(-1) % 8 == 0
|
||||
if upcast:
|
||||
if mask is None:
|
||||
return (upcast_softmax_fused if fused_softmax else upcast_softmax)(x, scale, softmax_dtype)
|
||||
return upcast_softmax(x, scale, softmax_dtype)
|
||||
else:
|
||||
return (upcast_masked_softmax_fused if fused_softmax else upcast_masked_softmax)(
|
||||
x, mask, mask_value, scale, softmax_dtype
|
||||
)
|
||||
return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype)
|
||||
else:
|
||||
if mask is None:
|
||||
return torch.nn.functional.softmax(x, dim=-1)
|
||||
else:
|
||||
return (masked_softmax_fused if fused_softmax else masked_softmax)(x, mask, mask_value)
|
||||
return masked_softmax(x, mask, mask_value)
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
||||
def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__()
|
||||
self.mask_value = None
|
||||
assert config.multi_query
|
||||
assert config.attention_softmax_in_fp32
|
||||
assert config.scale_attention_softmax_in_fp32
|
||||
|
||||
self.flash_attention = config.flash_attention
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.split_size = self.embed_dim
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
|
||||
self.scale_attn_weights = config.scale_attn_weights
|
||||
self.is_cross_attention = is_cross_attention
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.fused_softmax = config.fused_softmax
|
||||
|
||||
# KV caching and padding
|
||||
self.pre_allocate_kv_cache = (
|
||||
config.n_embd if config.pre_allocate_kv_cache is True else config.pre_allocate_kv_cache
|
||||
)
|
||||
pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length
|
||||
self._tuple_cache_format = self.pre_allocate_kv_cache or pad_key_length or self.flash_attention
|
||||
|
||||
if self.is_cross_attention:
|
||||
raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
|
||||
self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)
|
||||
|
||||
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
|
||||
if self.flash_attention:
|
||||
if flash_attn_unpadded_func is None:
|
||||
raise RuntimeError(
|
||||
"Flash Attention requires `flash_attn` and `einops`. "
|
||||
"To install, run `pip install flash-attn einops`."
|
||||
)
|
||||
self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device)
|
||||
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def _get_mask_value(self, device, dtype):
|
||||
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
|
||||
@ -171,14 +128,11 @@ class GPTBigCodeAttention(nn.Module):
|
||||
return self.mask_value
|
||||
|
||||
def _attn(self, query, key, value, attention_mask):
|
||||
dtype = query.dtype
|
||||
softmax_dtype = torch.float32
|
||||
upcast = dtype != softmax_dtype
|
||||
upcast = query.dtype != softmax_dtype
|
||||
|
||||
unscale = self.layer_idx + 1 if upcast else 1
|
||||
scale_factor = unscale**-1
|
||||
if self.scale_attn_weights:
|
||||
scale_factor /= self.head_dim**0.5
|
||||
scale_factor = unscale**-1 / self.head_dim**0.5
|
||||
|
||||
# (batch_size, query_length, num_heads * head_dim)
|
||||
query_shape = query.shape
|
||||
@ -212,16 +166,12 @@ class GPTBigCodeAttention(nn.Module):
|
||||
unscale,
|
||||
softmax_dtype,
|
||||
upcast,
|
||||
self.fused_softmax,
|
||||
)
|
||||
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
|
||||
|
||||
return attn_output, attn_weights
|
||||
return attn_output
|
||||
|
||||
def _attn_flash(self, query, key, value, attention_mask):
|
||||
def _attn_flash(self, query, key, value, flash_params):
|
||||
|
||||
query_shape = query.shape
|
||||
attn_shape = query_shape[0], self.num_heads, self.head_dim
|
||||
@ -229,7 +179,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
key = key.unsqueeze(1).expand(attn_shape)
|
||||
value = value.unsqueeze(1).expand(attn_shape)
|
||||
|
||||
sequence_lengths, padding_index, _, max_sequence_length = attention_mask
|
||||
sequence_lengths, padding_index, _, max_sequence_length = flash_params
|
||||
|
||||
# attn_output: (sum_seq_len, num_heads * head_dim)
|
||||
attn_output = flash_attn_unpadded_func(
|
||||
@ -240,31 +190,21 @@ class GPTBigCodeAttention(nn.Module):
|
||||
sequence_lengths,
|
||||
max_sequence_length,
|
||||
max_sequence_length,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1,
|
||||
0.0,
|
||||
softmax_scale=self.head_dim**-0.5,
|
||||
causal=True,
|
||||
).view(query_shape)
|
||||
|
||||
return attn_output, None
|
||||
return attn_output
|
||||
|
||||
def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length):
|
||||
batch_size = kv_cache.size(-1)
|
||||
assert not self.training
|
||||
allocated_kv_cache = torch.empty(
|
||||
[batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
|
||||
)
|
||||
allocated_kv_cache[:, :key_length].copy_(kv_cache)
|
||||
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
|
||||
return allocated_kv_cache, padded_kv_cache
|
||||
|
||||
def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask):
|
||||
flash_attention = self.flash_attention and layer_past is None
|
||||
def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params):
|
||||
|
||||
# Convert to standard KV cache format.
|
||||
if flash_attention and use_cache:
|
||||
_, padding_index, batch_size, max_sequence_length = attention_mask
|
||||
if flash_params is not None:
|
||||
_, padding_index, batch_size, max_sequence_length = flash_params
|
||||
current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
|
||||
else:
|
||||
return key_value, (current_kv_cache, max_sequence_length)
|
||||
|
||||
current_kv_cache = key_value
|
||||
|
||||
# Calculate dimensions and recover layer_past
|
||||
@ -281,38 +221,33 @@ class GPTBigCodeAttention(nn.Module):
|
||||
key_length = query_length + last_key_length
|
||||
allocated_key_length = allocated_kv_cache.size(-2)
|
||||
|
||||
padded_key_length = key_length if flash_attention else attention_mask.size(-1)
|
||||
allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length)
|
||||
padded_key_length = attention_mask.size(-1)
|
||||
|
||||
# Re-allocate kv cache and copy last value
|
||||
if allocate_key_length > allocated_key_length:
|
||||
if padded_key_length > allocated_key_length:
|
||||
allocated_kv_cache = torch.empty(
|
||||
[batch_size, allocate_key_length, 2 * self.head_dim],
|
||||
[batch_size, padded_key_length, 2 * self.head_dim],
|
||||
dtype=current_kv_cache.dtype,
|
||||
device=current_kv_cache.device,
|
||||
)
|
||||
if layer_past is not None:
|
||||
allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
|
||||
if allocate_key_length > key_length:
|
||||
if padded_key_length > key_length:
|
||||
# Nans in `value` can propagate through the matrix multiplication,
|
||||
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
|
||||
allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
|
||||
|
||||
# Copy the new values.
|
||||
if allocate_key_length > allocated_key_length or layer_past is not None:
|
||||
if padded_key_length > allocated_key_length or layer_past is not None:
|
||||
allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
|
||||
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
|
||||
if not flash_attention:
|
||||
# Use the merged KV cache.
|
||||
# Not needed when layer_past is None but frees some memory.
|
||||
key_value = padded_kv_cache
|
||||
|
||||
if use_cache:
|
||||
if allocated_kv_cache is None:
|
||||
allocated_kv_cache = current_kv_cache
|
||||
present = allocated_kv_cache, key_length
|
||||
else:
|
||||
present = None
|
||||
return key_value, present
|
||||
|
||||
def forward(
|
||||
@ -320,118 +255,63 @@ class GPTBigCodeAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
layer_past: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
|
||||
]:
|
||||
flash_attention = self.flash_attention and layer_past is None
|
||||
flash_params: Optional[Tuple] = None
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1)
|
||||
|
||||
if self._tuple_cache_format:
|
||||
# present = (allocated_kv_cache, key_length)
|
||||
key_value, present = self._merge_kv_caches(key_value, use_cache, layer_past, attention_mask)
|
||||
else:
|
||||
# present = key_value
|
||||
if layer_past is not None:
|
||||
key_value = torch.cat((layer_past, key_value), dim=-2)
|
||||
present = key_value if use_cache else None
|
||||
key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params)
|
||||
|
||||
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
||||
|
||||
attn_output, attn_weights = (self._attn_flash if flash_attention else self._attn)(
|
||||
query, key, value, attention_mask
|
||||
)
|
||||
if flash_params is None:
|
||||
attn_output=self._attn(query, key, value, attention_mask)
|
||||
else:
|
||||
attn_output=self._attn_flash(query, key, value, flash_params)
|
||||
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
if flash_attention:
|
||||
raise ValueError("`output_attentions` is not supported with Flash Attention.")
|
||||
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
|
||||
attn_weights = attn_weights.transpose(1, 2)
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs # a, present, (attentions)
|
||||
return attn_output, present
|
||||
|
||||
|
||||
class GPTBigCodeMLP(nn.Module):
|
||||
def __init__(self, intermediate_size, config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
self.c_fc = nn.Linear(embed_dim, intermediate_size)
|
||||
self.c_proj = nn.Linear(intermediate_size, embed_dim)
|
||||
self.act = ACT2FN[config.activation_function]
|
||||
self.dropout = nn.Dropout(config.resid_pdrop)
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim
|
||||
self.c_fc = nn.Linear(embed_dim, inner_dim)
|
||||
self.c_proj = nn.Linear(inner_dim, embed_dim)
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
|
||||
def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.c_proj(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh"))
|
||||
|
||||
|
||||
class GPTBigCodeBlock(nn.Module):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
if config.add_cross_attention:
|
||||
raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
|
||||
|
||||
self.mlp = GPTBigCodeMLP(self.inner_dim, config)
|
||||
self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
|
||||
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device)
|
||||
self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device)
|
||||
self.mlp = GPTBigCodeMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Optional[Tuple[torch.Tensor]],
|
||||
hidden_states: torch.Tensor,
|
||||
layer_past: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
]:
|
||||
if encoder_hidden_states is not None or encoder_attention_mask is not None:
|
||||
raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_outputs = self.attn(
|
||||
hidden_states,
|
||||
flash_params: Optional[Tuple] = None
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
attn_output, present = self.attn(
|
||||
self.ln_1(hidden_states),
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
flash_params=flash_params,
|
||||
)
|
||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attn_outputs[1:]
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
hidden_states.add_(attn_output)
|
||||
hidden_states.add_(self.mlp(self.ln_2(hidden_states)))
|
||||
return hidden_states, present
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
|
||||
|
||||
class GPTBigCodePreTrainedModel(PreTrainedModel):
|
||||
@ -442,7 +322,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = GPTBigCodeConfig
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = True
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["GPTBigCodeBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
@ -450,7 +330,9 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
|
||||
if isinstance(module, GPTBigCodeModel):
|
||||
module.bias.fill_(True).tril_()
|
||||
elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)):
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
@ -475,35 +357,27 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, GPTBigCodeModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__(config)
|
||||
assert config.multi_query
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
if config.add_cross_attention:
|
||||
raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
|
||||
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])
|
||||
self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner)
|
||||
|
||||
self.pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length
|
||||
self._tuple_cache_format = config.pre_allocate_kv_cache or self.pad_key_length or config.flash_attention
|
||||
self.inference_runner_type = InferenceRunnerType(config.inference_runner)
|
||||
self.flash_attention = True #config.flash_attention
|
||||
|
||||
self.flash_attention = config.flash_attention
|
||||
if self.flash_attention:
|
||||
if flash_attn_unpadded_func is None:
|
||||
raise RuntimeError(
|
||||
"Flash Attention requires `flash_attn` and `einops`. "
|
||||
"To install, run `pip install flash-attn einops`."
|
||||
)
|
||||
|
||||
if self.inference_runner_type == InferenceRunnerType.NO_RUNNER:
|
||||
self.inference_runner = None
|
||||
@ -512,23 +386,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
||||
|
||||
self.inference_runner = GPTBigCodeInferenceRunner(config, self)
|
||||
|
||||
max_positions = config.max_position_embeddings
|
||||
# Causal mask
|
||||
self.register_buffer(
|
||||
"bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
|
||||
"bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device)
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.wte
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.wte = new_embeddings
|
||||
|
||||
def _get_causal_mask(self, padding_mask, query_length, key_length):
|
||||
# Self-attention mask.
|
||||
attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
|
||||
@ -537,306 +399,105 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
||||
attention_mask = attention_mask * padding_mask.unsqueeze(1).to(
|
||||
dtype=torch.bool, device=attention_mask.device
|
||||
)
|
||||
pad = -key_length % 8
|
||||
if pad > 0:
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
|
||||
|
||||
# MQA models: (batch_size, query_length, n_heads, key_length)
|
||||
# MHA models: (batch_size, n_heads, query_length, key_length)
|
||||
# (batch_size, query_length, n_heads, key_length)
|
||||
return attention_mask.unsqueeze(2)
|
||||
|
||||
def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device):
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.to(device)
|
||||
elif padding_mask is not None and padding_mask.ndim == 2:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = padding_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(padding_mask == 0, 1)
|
||||
if key_length > query_length:
|
||||
position_ids = position_ids[:, key_length - query_length : key_length :]
|
||||
else:
|
||||
position_ids = torch.arange(key_length - query_length, key_length, dtype=torch.long, device=device)
|
||||
return position_ids.view(-1, query_length)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
input_ids: torch.Tensor,
|
||||
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
position_ids: torch.Tensor,
|
||||
) -> Tuple:
|
||||
if self.inference_runner is not None and past_key_values is not None:
|
||||
if self.config.validate_runner_input:
|
||||
assert input_ids is not None
|
||||
assert past_key_values is not None
|
||||
assert attention_mask is not None
|
||||
assert token_type_ids is None
|
||||
assert position_ids is not None
|
||||
assert inputs_embeds is None
|
||||
assert encoder_hidden_states is None
|
||||
assert encoder_attention_mask is None
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
assert use_cache is True
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
assert output_attentions is False
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
assert output_hidden_states is False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
assert return_dict is True
|
||||
return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = self.config.use_cache if use_cache is None else use_cache
|
||||
return_dict = self.config.use_return_dict if return_dict is None else return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
if inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
input_shape = input_ids.shape
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size, query_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.shape[:-1]
|
||||
inputs_embeds = inputs_embeds.view(-1, input_shape[-2:])
|
||||
batch_size, query_length = inputs_embeds.shape[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
|
||||
flash_attention = self.flash_attention and past_key_values is None
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
elif self._tuple_cache_format:
|
||||
past_length = past_key_values[0][1]
|
||||
else:
|
||||
past_length = past_key_values[0].size(-2)
|
||||
key_length = past_length + query_length
|
||||
past_length = past_key_values[0][1]
|
||||
|
||||
position_ids = self._get_position_ids(position_ids, attention_mask, query_length, key_length, input_ids.device)
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, query_length)
|
||||
|
||||
if not flash_attention:
|
||||
# Self-attention mask (padding + causal).
|
||||
attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
|
||||
if self.pad_key_length:
|
||||
pad = -key_length % 8
|
||||
if pad > 0:
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False)
|
||||
|
||||
if encoder_hidden_states is not None or encoder_attention_mask is not None:
|
||||
raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
|
||||
# TODO: Unpad earlier (input ids), support unpadded input?
|
||||
if flash_attention:
|
||||
hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
|
||||
hidden_states, attention_mask
|
||||
)
|
||||
# Pass the required parameters through the attention_mask argument
|
||||
attention_mask = (sequence_lengths, padding_index, batch_size, max_sequence_length)
|
||||
|
||||
presents = [] if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
)
|
||||
flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length)
|
||||
attention_mask=None
|
||||
else:
|
||||
key_length = past_length + query_length
|
||||
# Self-attention mask (padding + causal).
|
||||
attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
|
||||
flash_params=None
|
||||
|
||||
presents = []
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
flash_params=flash_params
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache:
|
||||
presents.append(outputs[1])
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
if flash_attention:
|
||||
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)
|
||||
|
||||
hidden_states = hidden_states.view(input_shape + (hidden_states.size(-1),))
|
||||
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
return hidden_states, presents
|
||||
|
||||
|
||||
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")):
|
||||
super().__init__(config)
|
||||
self.transformer = GPTBigCodeModel(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.predict_last_token = config.predict_last_token
|
||||
meta=torch.device("meta")
|
||||
self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta)
|
||||
|
||||
self.to_empty(device=device)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": kwargs.get("position_ids", None),
|
||||
"attention_mask": kwargs.get("attention_mask", None),
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
*,
|
||||
input_ids: torch.Tensor,
|
||||
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
position_ids: torch.Tensor,
|
||||
predict_all_tokens: bool=True,
|
||||
) -> Tuple:
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
hidden_states, presents=self.transformer(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
position_ids=position_ids
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
if self.predict_last_token and not self.training:
|
||||
if not predict_all_tokens:
|
||||
# We only care about the last token.
|
||||
hidden_states = hidden_states[:, -1:]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
return lm_logits, presents
|
||||
|
@ -1,13 +1,12 @@
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import GPTBigCodeConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
|
||||
InferenceRunnerType,
|
||||
)
|
||||
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, softmax_function
|
||||
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeBlock, softmax_function
|
||||
|
||||
|
||||
def _align_tensor(x):
|
||||
@ -291,7 +290,7 @@ class GPTBigCodeInferenceRunner:
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
past_key_values: Union[List[torch.Tensor], int],
|
||||
) -> BaseModelOutputWithPastAndCrossAttentions:
|
||||
) -> Tuple:
|
||||
batch_size, query_length = input_ids.shape
|
||||
assert query_length == 1
|
||||
if self.batch_size is None:
|
||||
@ -333,7 +332,4 @@ class GPTBigCodeInferenceRunner:
|
||||
else:
|
||||
hidden_states = self._forward(key_length)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=key_length,
|
||||
)
|
||||
return hidden_states, key_length
|
||||
|
@ -449,6 +449,49 @@ class FlashCausalLM(Model):
|
||||
pre_allocate_past_size=pre_allocate_past_size,
|
||||
)
|
||||
|
||||
def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, use_cache: Optional[torch.dtype]):
|
||||
diff = max_input_length - max(batch.input_lengths)
|
||||
for i in range(len(batch)):
|
||||
batch.input_lengths[i] += diff
|
||||
batch.prefix_offsets[i] = 0
|
||||
batch.read_offsets[i] = 0
|
||||
batch.all_input_ids[i] = batch.all_input_ids[i] + [self.tokenizer.pad_token_id] * diff if diff >= 0 else batch.all_input_ids[i][:diff]
|
||||
# TODO: Bug!?!
|
||||
batch.stopping_criterias[i].current_tokens += diff
|
||||
|
||||
if use_cache:
|
||||
assert len(batch.all_input_ids_tensor)>0, "Must run prefill first"
|
||||
batch.input_ids.fill_(self.tokenizer.pad_token_id)
|
||||
batch.position_ids += diff
|
||||
batch.cu_seqlens += diff * batch.cu_seqlens_q
|
||||
# TODO: Bug!?!
|
||||
batch.max_seqlen += batch.max_seqlen + diff*len(batch)
|
||||
|
||||
for i in range(len(batch)):
|
||||
batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(self.tokenizer.pad_token_id)
|
||||
|
||||
batch.past_key_values = batch.past_key_values = torch.randn(
|
||||
(
|
||||
batch.past_key_values.shape[0],
|
||||
batch.past_key_values.shape[1] + len(batch.requests),
|
||||
*batch.past_key_values.shape[2:],
|
||||
), device=batch.past_key_values.device, dtype= batch.past_key_values.dtype
|
||||
)
|
||||
else:
|
||||
batch.max_seqlen = max(batch.input_lengths)
|
||||
batch.all_input_ids_tensor=[]
|
||||
|
||||
batch.input_ids = torch.tensor(
|
||||
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
|
||||
)
|
||||
batch.position_ids = torch.tensor(
|
||||
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
|
||||
)
|
||||
batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
|
||||
batch.past_key_values=None
|
||||
|
||||
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
self, batch: FlashCausalLMBatch
|
||||
|
@ -223,6 +223,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
|
135
server/text_generation_server/models/gpt_bigcode.py
Normal file
135
server/text_generation_server/models/gpt_bigcode.py
Normal file
@ -0,0 +1,135 @@
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer
|
||||
from typing import Optional, Type
|
||||
|
||||
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch
|
||||
from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeForCausalLM
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BigcodeBatch(VectorizedCausalLMBatch):
|
||||
kv_cache_seq_dim: int = 1
|
||||
|
||||
def _filter_kv_caches(self, keep_indices, sequence_slice):
|
||||
if self.past_key_values is not None:
|
||||
for layer_kv, _ in self.past_key_values:
|
||||
# Update tensors in-place to allow incremental garbage collection
|
||||
layer_kv.data = layer_kv[keep_indices, sequence_slice]
|
||||
|
||||
@classmethod
|
||||
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
|
||||
device = batches[0].input_ids.device
|
||||
batch_size = sum([len(batch.requests) for batch in batches])
|
||||
|
||||
for batch in batches:
|
||||
if batch.past_key_values is None:
|
||||
raise ValueError("Only concatenate prefilled batches")
|
||||
|
||||
past_key_values = []
|
||||
for kv_caches in zip(*(batch.past_key_values for batch in batches)):
|
||||
key_values, seq_lengths = zip(*kv_caches)
|
||||
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths))
|
||||
|
||||
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values))
|
||||
allocate_seq_len += - allocate_seq_len % 8
|
||||
|
||||
kv_cache = torch.empty(
|
||||
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device
|
||||
)
|
||||
for key_value, start_index, end_index, left_index in zip(
|
||||
key_values,
|
||||
start_indices,
|
||||
end_indices,
|
||||
left_indices,
|
||||
):
|
||||
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value)
|
||||
# Set padding to zero to avoid propagating nans.
|
||||
kv_cache[start_index:end_index, :left_index].fill_(0)
|
||||
kv_cache[start_index:end_index, max_input_length:].fill_(0)
|
||||
past_key_values.append((kv_cache, max_input_length))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
class BigcodeCausalLM(VectorizedCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||
)
|
||||
model = GPTBigCodeForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto" if torch.cuda.is_available() else None,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
)
|
||||
tokenizer.pad_token_id = (
|
||||
model.config.pad_token_id
|
||||
if model.config.pad_token_id is not None
|
||||
else model.config.eos_token_id
|
||||
)
|
||||
|
||||
super(VectorizedCausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[BigcodeBatch]:
|
||||
return BigcodeBatch
|
||||
|
||||
def forward(self, batch:BigcodeBatch):
|
||||
key_length = batch.max_input_length
|
||||
query_length = key_length if batch.past_key_values is None else 1
|
||||
input_ids = batch.input_ids[:, key_length - query_length : key_length]
|
||||
# Model Forward
|
||||
logits, batch.past_key_values = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=batch.attention_mask[:, :key_length],
|
||||
position_ids=batch.position_ids[:, key_length - query_length : key_length],
|
||||
past_key_values=batch.past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
next_token_ids, logprobs = batch.next_token_chooser(
|
||||
input_ids, logits, batch.details
|
||||
)
|
||||
# Update batch
|
||||
# TODO: Why do we need all input ids?
|
||||
batch.input_ids[:, key_length].copy_(next_token_ids)
|
||||
batch.input_lengths = [length + 1 for length in batch.input_lengths]
|
||||
batch.max_input_length += 1
|
||||
|
||||
return next_token_ids, logprobs
|
||||
|
||||
def mock_kv_cache(self, batch: BigcodeBatch, dtype:Optional[torch.dtype]):
|
||||
allocate_length=batch.max_input_length+-batch.max_input_length%8
|
||||
return [(torch.empty(
|
||||
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
|
||||
dtype=dtype,
|
||||
device=batch.input_ids.device,
|
||||
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)]
|
149
server/text_generation_server/models/gpt_bigcode2.py
Normal file
149
server/text_generation_server/models/gpt_bigcode2.py
Normal file
@ -0,0 +1,149 @@
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer
|
||||
from typing import Optional, Type
|
||||
|
||||
from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch
|
||||
from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import GPTBigCodeForCausalLM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Bigcode2Batch(VectorizedCausalLMBatch):
|
||||
kv_cache_seq_dim: int = 1
|
||||
pad_key_length_to_multiple:int=8
|
||||
|
||||
# Prefill the attention mask for padded key length.
|
||||
attention_mask_fill_value=False
|
||||
|
||||
def _filter_kv_caches(self, keep_indices, sequence_slice):
|
||||
if self.past_key_values is not None:
|
||||
for layer_kv, _ in self.past_key_values:
|
||||
# Update tensors in-place to allow incremental garbage collection
|
||||
layer_kv.data = layer_kv[keep_indices, sequence_slice]
|
||||
|
||||
@classmethod
|
||||
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
|
||||
device = batches[0].input_ids.device
|
||||
batch_size = sum([len(batch.requests) for batch in batches])
|
||||
|
||||
for batch in batches:
|
||||
if batch.past_key_values is None:
|
||||
raise ValueError("Only concatenate prefilled batches")
|
||||
|
||||
past_key_values = []
|
||||
for kv_caches in zip(*(batch.past_key_values for batch in batches)):
|
||||
key_values, seq_lengths = zip(*kv_caches)
|
||||
assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths))
|
||||
|
||||
allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values))
|
||||
allocate_seq_len += - allocate_seq_len % batches[0].pad_key_length_to_multiple
|
||||
|
||||
kv_cache = torch.empty(
|
||||
(batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device
|
||||
)
|
||||
for key_value, start_index, end_index, left_index in zip(
|
||||
key_values,
|
||||
start_indices,
|
||||
end_indices,
|
||||
left_indices,
|
||||
):
|
||||
kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value)
|
||||
# Set padding to zero to avoid propagating nans.
|
||||
kv_cache[start_index:end_index, :left_index].fill_(0)
|
||||
kv_cache[start_index:end_index, max_input_length:].fill_(0)
|
||||
past_key_values.append((kv_cache, max_input_length))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
class Bigcode2CausalLM(VectorizedCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||
)
|
||||
model = GPTBigCodeForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto" if torch.cuda.is_available() else None,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
)
|
||||
tokenizer.pad_token_id = (
|
||||
model.config.pad_token_id
|
||||
if model.config.pad_token_id is not None
|
||||
else model.config.eos_token_id
|
||||
)
|
||||
|
||||
super(VectorizedCausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[Bigcode2Batch]:
|
||||
return Bigcode2Batch
|
||||
|
||||
def forward(self, batch:Bigcode2Batch):
|
||||
key_length = batch.max_input_length
|
||||
if batch.past_key_values is None:
|
||||
# Prefill (flash attn, unpadded key length)
|
||||
batch.pad_key_length_to_multiple=self.model.pad_key_length_to_multiple
|
||||
padded_key_length=key_length
|
||||
query_length=key_length
|
||||
else:
|
||||
# Decode (fused attn, padded key length)
|
||||
batch.attention_mask[:, key_length-1].fill_(True)
|
||||
padded_key_length=key_length+-key_length%batch.pad_key_length_to_multiple
|
||||
query_length=1
|
||||
|
||||
input_ids = batch.input_ids[:, key_length - query_length : key_length]
|
||||
# Model Forward
|
||||
logits, batch.past_key_values = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=batch.attention_mask[:, :padded_key_length],
|
||||
position_ids=batch.position_ids[:, key_length - query_length : key_length],
|
||||
past_key_values=batch.past_key_values,
|
||||
key_length=key_length,
|
||||
predict_all_tokens=batch.details
|
||||
)
|
||||
next_token_ids, logprobs = batch.next_token_chooser(
|
||||
input_ids, logits, batch.details
|
||||
)
|
||||
# Update batch
|
||||
# TODO: Why do we need all input ids?
|
||||
batch.input_ids[:, key_length].copy_(next_token_ids)
|
||||
batch.input_lengths = [length + 1 for length in batch.input_lengths]
|
||||
batch.max_input_length += 1
|
||||
|
||||
return next_token_ids, logprobs
|
||||
|
||||
def mock_kv_cache(self, batch: Bigcode2Batch, dtype:Optional[torch.dtype]):
|
||||
allocate_length=batch.max_input_length+-batch.max_input_length%batch.pad_key_length_to_multiple
|
||||
return [(torch.empty(
|
||||
[len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
|
||||
dtype=dtype,
|
||||
device=batch.input_ids.device,
|
||||
),batch.max_input_length-1) for _ in range(self.model.config.n_layer)]
|
@ -58,6 +58,9 @@ class VectorizedCausalLMBatch(Batch):
|
||||
|
||||
kv_cache_seq_dim: int = 2
|
||||
|
||||
# Prefill the attention mask for the generated tokens
|
||||
attention_mask_fill_value=True
|
||||
|
||||
# TODO: Get from requests (should these be lists?)
|
||||
details: bool = os.environ.get("RETURN_DETAILS") is not None
|
||||
generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None
|
||||
@ -113,7 +116,7 @@ class VectorizedCausalLMBatch(Batch):
|
||||
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
|
||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"])
|
||||
attention_mask[:, max_input_length:].fill_(1)
|
||||
attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value)
|
||||
|
||||
position_ids = attention_mask.cumsum(-1).sub_(1)
|
||||
position_ids[:, :max_input_length].relu_()
|
||||
@ -268,7 +271,7 @@ class VectorizedCausalLMBatch(Batch):
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
|
||||
attention_mask[:, :max_input_length].fill_(0)
|
||||
attention_mask[:, max_input_length:].fill_(1)
|
||||
attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value)
|
||||
|
||||
input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
|
||||
# TODO : only needed for prefill
|
||||
@ -293,7 +296,7 @@ class VectorizedCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
kv_cache_seq_dim = batches[0].kv_cache_seq_dim
|
||||
past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices)
|
||||
past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices, max_input_length)
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
@ -314,7 +317,7 @@ class VectorizedCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices):
|
||||
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length):
|
||||
device = batches[0].input_ids.device
|
||||
batch_size = sum([len(batch.requests) for batch in batches])
|
||||
|
||||
@ -355,29 +358,23 @@ class VectorizedCausalLMBatch(Batch):
|
||||
else batch.past_key_values[i][j]
|
||||
for batch in batches
|
||||
]
|
||||
# Generally `max_input_length`, unless the model allocates more than needed.
|
||||
right_indices = [
|
||||
left_index + tensor.size(kv_cache_seq_dim)
|
||||
for tensor, left_index in zip(tensors_to_merge, left_indices)
|
||||
]
|
||||
combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:])
|
||||
combined_shape[kv_cache_seq_dim] = max(right_indices)
|
||||
combined_shape[kv_cache_seq_dim] = max_input_length
|
||||
# Set to zero to avoid propagating nans in padded values.
|
||||
kv_cache = torch.zeros(
|
||||
combined_shape, dtype=tensors_to_merge[0].dtype, device=device
|
||||
)
|
||||
for tensor, start_index, end_index, left_index, right_index in zip(
|
||||
for tensor, start_index, end_index, left_index in zip(
|
||||
tensors_to_merge,
|
||||
start_indices,
|
||||
end_indices,
|
||||
left_indices,
|
||||
right_indices,
|
||||
):
|
||||
kv_cache[
|
||||
[
|
||||
slice(start_index, end_index),
|
||||
*(slice(None) for _ in range(1, kv_cache_seq_dim)),
|
||||
slice(left_index, right_index),
|
||||
slice(left_index, max_input_length),
|
||||
]
|
||||
].copy_(tensor)
|
||||
if kv_format is None:
|
||||
@ -398,13 +395,11 @@ class VectorizedCausalLM(Model):
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: bool = False,
|
||||
decode_buffer: int = 3,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
# TODO: Choose dtype (fp16?)
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
dtype = torch.float16
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -415,26 +410,26 @@ class VectorizedCausalLM(Model):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto" if torch.cuda.is_available() else None,
|
||||
load_in_8bit=quantize,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=True,
|
||||
).eval()
|
||||
)
|
||||
tokenizer.pad_token_id = (
|
||||
self.model.config.pad_token_id
|
||||
if self.model.config.pad_token_id is not None
|
||||
else self.model.config.eos_token_id
|
||||
model.config.pad_token_id
|
||||
if model.config.pad_token_id is not None
|
||||
else model.config.eos_token_id
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
decode_buffer=decode_buffer,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -446,6 +441,30 @@ class VectorizedCausalLM(Model):
|
||||
generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
|
||||
)
|
||||
|
||||
def forward(self, batch:VectorizedCausalLMBatch):
|
||||
key_length = batch.max_input_length
|
||||
query_length = key_length if batch.past_key_values is None else 1
|
||||
input_ids = batch.input_ids[:, key_length - query_length : key_length]
|
||||
# Model Forward
|
||||
logits, batch.past_key_values, *_ = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=batch.attention_mask[:, :key_length],
|
||||
position_ids=batch.position_ids[:, key_length - query_length : key_length],
|
||||
past_key_values=batch.past_key_values,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)
|
||||
next_token_ids, logprobs = batch.next_token_chooser(
|
||||
input_ids, logits, batch.details
|
||||
)
|
||||
# Update batch
|
||||
# TODO: Why do we need all input ids?
|
||||
batch.input_ids[:, key_length].copy_(next_token_ids)
|
||||
batch.input_lengths = [length + 1 for length in batch.input_lengths]
|
||||
batch.max_input_length += 1
|
||||
|
||||
return next_token_ids, logprobs
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
self, batch: VectorizedCausalLMBatch
|
||||
@ -453,20 +472,9 @@ class VectorizedCausalLM(Model):
|
||||
key_length = batch.max_input_length
|
||||
if key_length > batch.input_ids.size(1):
|
||||
raise RuntimeError("Cannot generate more than `max_tokens`.")
|
||||
is_prefill = batch.past_key_values is None
|
||||
|
||||
query_length = key_length if batch.past_key_values is None else 1
|
||||
input_ids = batch.input_ids[:, key_length - query_length : key_length]
|
||||
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=batch.attention_mask[:, :key_length],
|
||||
position_ids=batch.position_ids[:, key_length - query_length : key_length],
|
||||
past_key_values=batch.past_key_values,
|
||||
)
|
||||
# TODO: Post-processing
|
||||
next_token_ids, logprobs = batch.next_token_chooser(
|
||||
input_ids, outputs.logits, batch.details
|
||||
)
|
||||
next_token_ids, logprobs = self.forward(batch)
|
||||
|
||||
if batch.generate_stream:
|
||||
# TODO: self.decode_token, offsets?
|
||||
@ -479,7 +487,6 @@ class VectorizedCausalLM(Model):
|
||||
.squeeze(1)
|
||||
.tolist()
|
||||
)
|
||||
is_prefill = batch.past_key_values is None
|
||||
if is_prefill:
|
||||
prefill_token_ids = batch.input_ids[:, :key_length].tolist()
|
||||
prefill_logprobs = (
|
||||
@ -491,7 +498,8 @@ class VectorizedCausalLM(Model):
|
||||
for prefill_token_ids_, prefill_logprobs_, input_length in zip(
|
||||
prefill_token_ids, prefill_logprobs, batch.input_lengths
|
||||
):
|
||||
prefill_token_ids_ = prefill_token_ids_[-input_length:]
|
||||
# Input length has already been incremented so we subtract 1.
|
||||
prefill_token_ids_ = prefill_token_ids_[-(input_length-1):]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids_,
|
||||
clean_up_tokenization_spaces=False,
|
||||
@ -505,13 +513,6 @@ class VectorizedCausalLM(Model):
|
||||
)
|
||||
)
|
||||
|
||||
# Update batch
|
||||
# TODO: Why do we need all input ids?
|
||||
batch.input_ids[:, key_length].copy_(next_token_ids)
|
||||
batch.past_key_values = outputs.past_key_values
|
||||
batch.input_lengths = [length + 1 for length in batch.input_lengths]
|
||||
batch.max_input_length += 1
|
||||
|
||||
# TODO: Vectorize some of this?
|
||||
|
||||
generations: List[Generation] = []
|
||||
@ -556,3 +557,24 @@ class VectorizedCausalLM(Model):
|
||||
generations.append(generation)
|
||||
|
||||
return generations, next_batch
|
||||
|
||||
def mock_kv_cache(self, batch: VectorizedCausalLMBatch, dtype:Optional[torch.dtype]):
|
||||
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
|
||||
if not isinstance(self.model, GPTBigCodeForCausalLM):
|
||||
raise NotImplementedError()
|
||||
return [torch.empty(
|
||||
[len(batch), batch.max_input_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
|
||||
dtype=dtype,
|
||||
device=batch.input_ids.device,
|
||||
) for _ in range(self.model.config.n_layer)]
|
||||
|
||||
def fast_forward(self, batch: VectorizedCausalLMBatch, max_input_length: int, cache_dtype:Optional[torch.dtype]):
|
||||
diff=max_input_length-batch.max_input_length
|
||||
batch.input_ids[:, batch.max_input_length:max_input_length].fill_(self.tokenizer.pad_token_id)
|
||||
batch.input_lengths = [length + diff for length in batch.input_lengths]
|
||||
batch.max_input_length += diff
|
||||
for stopping_criteria in batch.stopping_criterias:
|
||||
stopping_criteria.current_tokens+=diff
|
||||
batch.past_key_values = None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user