From 0921fe6a2a200a37053942f56056bc5f7f4805f1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 24 May 2023 19:51:53 -0400 Subject: [PATCH] Optimized model --- .../text_generation_server/models/__init__.py | 19 +- .../custom_modeling/gpt_bigcode2_modeling.py | 406 ++++++++++++ .../custom_modeling/gpt_bigcode3_modeling.py | 520 +++++++++++++++ .../custom_modeling/gpt_bigcode_modeling.py | 623 ++++-------------- .../custom_modeling/inference_runner.py | 12 +- .../models/flash_causal_lm.py | 43 ++ .../models/flash_santacoder.py | 1 + .../models/gpt_bigcode.py | 135 ++++ .../models/gpt_bigcode2.py | 149 +++++ .../models/vectorized_causal_lm.py | 114 ++-- 10 files changed, 1485 insertions(+), 537 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py create mode 100644 server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py create mode 100644 server/text_generation_server/models/gpt_bigcode.py create mode 100644 server/text_generation_server/models/gpt_bigcode2.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d2032353..7fcb4c7c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -17,6 +17,8 @@ from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.t5 import T5Sharded +from text_generation_server.models.gpt_bigcode import BigcodeCausalLM +from text_generation_server.models.gpt_bigcode2 import Bigcode2CausalLM try: if torch.cuda.is_available() and os.environ.get("NO_FLASH_ATTENTION") is None: @@ -91,10 +93,25 @@ torch.backends.cudnn.allow_tf32 = True # Disable gradients torch.set_grad_enabled(False) +model_type_map={ + "flash":FlashSantacoder, + "santa":SantaCoder, + "causal":CausalLM, + "vector":VectorizedCausalLM, + "bigcode":BigcodeCausalLM, + "bigcode2":Bigcode2CausalLM, +} + def get_model( model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str] ) -> Model: + model_type=os.environ.get("MODEL_TYPE") + if model_type is not None: + if model_type not in model_type_map: + raise NotImplementedError(model_type) + return model_type_map[model_type](model_id, revision, quantize=quantize) + if "facebook/galactica" in model_id: if sharded: return GalacticaSharded(model_id, revision, quantize=quantize) @@ -167,8 +184,6 @@ def get_model( raise ValueError("sharded is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - if os.environ.get("VECTORIZED_LM") is not None: - return VectorizedCausalLM(model_id, revision, quantize=quantize) return CausalLM(model_id, revision, quantize=quantize) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: return Seq2SeqLM(model_id, revision, quantize=quantize) diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py new file mode 100644 index 00000000..98beec92 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py @@ -0,0 +1,406 @@ +# coding=utf-8 +# Copyright 2023 The Bigcode team and HuggingFace Inc. team. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch GPTBigCode model.""" +import math +from typing import Optional, Tuple, Any, Union, List +from enum import IntEnum + +import torch +import torch.utils.checkpoint +from torch import nn + +from dropout_layer_norm import dropout_layer_norm +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( + GPTBigCodeConfig, +) + +class FastLayerNorm(nn.LayerNorm): + # TODO: Validate dimension + def forward(self, hidden_states, residual=None): + return dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + + +class FastLinear(nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.addmm(self.bias, input, self.weight) + +class FastLinearNoBias(nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.mm(input, self.weight) + +logger = logging.get_logger(__name__) + +@torch.jit.script +def upcast_masked_softmax( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float +): + input_dtype = x.dtype + x = x.to(torch.float32) * scale + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + +@torch.jit.script +def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1) + return x + +class GPTBigCodeAttention(nn.Module): + def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.layer_idx = layer_idx + # Note: Does not support module dtype conversion. 
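+        # The buffer is allocated on the meta device and filled with the float32 minimum
+        # during weight init (and re-created in decode() if the device changes), so the
+        # usual nn.Module dtype/device conversion would leave it inconsistent.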
+ self.register_buffer("mask_value", torch.empty((), dtype=torch.float32, device="meta")) + + self.c_attn = FastLinear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device="meta") + self.c_proj = FastLinear(self.embed_dim, self.embed_dim, dtype=dtype, device="meta") + + def prefill( + self, + hidden_states: torch.Tensor, + sequence_lengths, + key_length:int, + ) -> Tuple[torch.Tensor, Any]: + hidden_shape = hidden_states.shape + query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) + query = query.view(hidden_shape[0], self.num_heads, self.head_dim) + key, value = key_value.unsqueeze(1).expand(hidden_shape[0], self.num_heads, 2*self.head_dim).split((self.head_dim, self.head_dim), dim=-1) + + # attn_output: (sum_seq_len, num_heads * head_dim) + attn_output = flash_attn_unpadded_func( + query, + key, + value, + sequence_lengths, + sequence_lengths, + key_length, + key_length, + 0.0, + softmax_scale=self.head_dim**-0.5, + causal=True, + ).view(hidden_shape) + + attn_output = self.c_proj.forward(attn_output) + + return attn_output, key_value + + def decode( + self, + hidden_states: torch.Tensor, + layer_past: torch.Tensor, + attention_mask: torch.Tensor, + batch_size:int, + key_length:int, + ) -> Tuple[torch.Tensor, Any]: + query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) + + # Calculate dimensions and recover layer_past + padded_key_length = attention_mask.size(-1) + allocated_key_length=layer_past.size(-2) + + # TODO: Allow pre-allocation with size > padded_key_length + if padded_key_length > allocated_key_length: + # Re-allocate kv cache and copy last value + allocated_kv_cache = torch.empty( + [batch_size, padded_key_length, 2 * self.head_dim], + dtype=key_value.dtype, + device=key_value.device, + ) + allocated_kv_cache[:, :key_length-1].copy_(layer_past[:, :key_length-1]) + # Nans in `value` can propagate through the matrix multiplication, + # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) + allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_() + layer_past=allocated_kv_cache + + # Copy the new values. + layer_past[:, key_length-1:key_length].copy_(key_value) + + key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1) + + # TODO: Upcasting needed for bf16? 
+ upcast = query.dtype != torch.float32 + unscale = self.layer_idx + 1 if upcast else 1 + scale_factor = unscale ** -1 / self.head_dim ** 0.5 + + hidden_states = torch.baddbmm( + torch.empty((batch_size, self.num_heads, padded_key_length), device=query.device, dtype=query.dtype), + query.view(batch_size, self.num_heads, self.head_dim), + key.transpose(-1, -2), + beta=0, + alpha=scale_factor + ).unsqueeze_(1) + + if self.mask_value is None or self.mask_value.device != hidden_states.device: + self.mask_value = torch.full([], torch.finfo(torch.float32).min, dtype=torch.float32, device=hidden_states.device) + + if upcast: + hidden_states = upcast_masked_softmax(hidden_states, attention_mask, self.mask_value, unscale) + else: + hidden_states = masked_softmax(hidden_states, attention_mask, self.mask_value) + + hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape) + + hidden_states = self.c_proj.forward(hidden_states) + + return hidden_states, layer_past + +class GPTBigCodeMLP(nn.Module): + # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict) + def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype): + super().__init__() + embed_dim = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim + self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta") + self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta") + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype): + super().__init__() + self.ln_1 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta") + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype) + self.ln_2 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta") + self.mlp = GPTBigCodeMLP(config, dtype=dtype) + + def prefill( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + sequence_lengths, + key_length:int, + ) -> Tuple[torch.Tensor, torch.Tensor, Any]: + hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual) + hidden_states, present = self.attn.prefill( + hidden_states, + sequence_lengths=sequence_lengths, + key_length=key_length, + ) + hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual) + hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")) + return hidden_states, residual, present + + def decode( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + layer_past: torch.Tensor, + attention_mask: torch.Tensor, + batch_size:int, + key_length:int, + ) -> Tuple[torch.Tensor, torch.Tensor, Any]: + hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual) + hidden_states, present = self.attn.decode( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + batch_size=batch_size, + key_length=key_length, + ) + hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual) + hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")) + return hidden_states, residual, present + + +class GPTBigCodePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = GPTBigCodeConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = False + _no_split_modules = ["GPTBigCodeBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, GPTBigCodeModel): + module.bias.fill_(True).tril_() + elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)): + if isinstance(module, GPTBigCodeAttention): + module.mask_value.fill_(torch.finfo(torch.float32).min) + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + module.c_proj.weight.data.normal_( + mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + ) + module.c_proj._is_hf_initialized = True + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class GPTBigCodeModel(GPTBigCodePreTrainedModel): + # TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict) + def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype): + super().__init__(config) + + self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device="meta") + self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device="meta") + + self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) for i in range(config.num_hidden_layers)]) + self.ln_f = FastLayerNorm(config.hidden_size, dtype=dtype, device="meta", eps=config.layer_norm_epsilon) + + # Causal mask + self.register_buffer( + "causal_mask", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device="meta") + ) + +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): + pad_key_length_to_multiple=8 + + def __init__(self, config, dtype:torch.dtype, device:torch.device=torch.device("cuda")): + super().__init__(config) + if device.type!="cuda": + raise NotImplementedError(f"Device {device} not supported") + + self.transformer = GPTBigCodeModel(config, dtype=dtype) + self.lm_head = FastLinearNoBias(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta") + + self.to_empty(device=device) + + self._apply=self._apply_not_allowed + + # Initialize weights and apply final processing + self.post_init() + + def _apply_not_allowed(self): + # Dtype or device conversion would break the model. 
+ raise NotImplementedError("Device or dtype conversion not supported!") + + def prefill( + self, + *, + input_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + position_ids: torch.Tensor, + predict_all_tokens: bool=True, + ) -> Tuple: + batch_size, query_length = input_ids.shape + + hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids) + + # Prefill (flash attn) + # TODO: Unpad earlier (input ids)? + hidden_states, padding_index, sequence_lengths, key_length = unpad_input(hidden_states, attention_mask) + assert key_length==query_length + + residual = None + past_key_values = [] + block:GPTBigCodeBlock + for block in self.transformer.h: + hidden_states, residual, key_value = block.prefill( + hidden_states, + residual=residual, + sequence_lengths=sequence_lengths, + key_length=query_length, + ) + past_key_values.append(pad_input(key_value, padding_index, batch_size, query_length)) + + hidden_states = self.transformer.ln_f.forward(hidden_states, residual) + + # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible. + del residual + + if predict_all_tokens: + hidden_states = self.lm_head.forward(hidden_states) + hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) + else: + # TODO: Index directly instead + hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)[:, -1] + hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1) + + return hidden_states, past_key_values + + def decode( + self, + *, + input_ids: torch.Tensor, + past_key_values: List[torch.Tensor], + attention_mask: [torch.Tensor], + position_ids: torch.Tensor, + key_length:int, + ) -> Tuple: + + batch_size, query_length = input_ids.shape + assert query_length == 1 + + hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids) + + # Standardize shape to (batch_size, hidden_size) + hidden_states.squeeze_(1) + + # Self-attention mask (padding + causal). + # TODO: Avoid unsqueeze + attention_mask = self.transformer.causal_mask[None, key_length - 1: key_length, + :key_length] * attention_mask.unsqueeze(1) + attention_mask.unsqueeze_(2) + + residual = None + block:GPTBigCodeBlock + for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)): + hidden_states, residual, past_key_values[i] = block.decode( + hidden_states, + residual=residual, + layer_past=layer_past, + attention_mask=attention_mask, + batch_size=batch_size, + key_length=key_length, + ) + + hidden_states = self.transformer.ln_f.forward(hidden_states, residual) + hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1) + + return hidden_states, past_key_values diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py new file mode 100644 index 00000000..63350843 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py @@ -0,0 +1,520 @@ +# coding=utf-8 +# Copyright 2023 The Bigcode team and HuggingFace Inc. team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch GPTBigCode model.""" +import math +from typing import Optional, Tuple, Any, Union, List +from enum import IntEnum + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( + GPTBigCodeConfig, +) + +class InferenceRunnerType(IntEnum): + NO_RUNNER = 0 + # Use the inference runner without cuda graphs. + BASE_RUNNER = 1 + # Use cuda graphs in the inference runner. Leave out the attention which has a variable shape. + # This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models. + PARTIAL_GRAPH = 2 + # Turn the whole model into a cuda graph. One graph for each sequence length. + # Note: only useful for small batches and models, graphs take some time to generate, flaky. + # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. + FULL_GRAPH = 3 + +try: + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + flash_attn_unpadded_func = None + pad_input = None + unpad_input = None + + +logger = logging.get_logger(__name__) + + +@torch.jit.script +def upcast_masked_softmax( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype +): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1) + return x + + +@torch.profiler.record_function("softmax_function") +def softmax_function( + x: torch.Tensor, + mask: torch.Tensor, + mask_value: torch.Tensor, + scale: float, + softmax_dtype: torch.dtype, + upcast: bool = True, +): + """ + This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case + needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting + and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely + inefficient. 
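+    Callers therefore pad the key dimension to a multiple of 8; see
+    GPTBigCodeModel._get_causal_mask, which pads the attention mask with `False` columns.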
+ """ + #assert x.size(-1) % 8 == 0 + if upcast: + if mask is None: + return upcast_softmax(x, scale, softmax_dtype) + else: + return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype) + else: + if mask is None: + return torch.nn.functional.softmax(x, dim=-1) + else: + return masked_softmax(x, mask, mask_value) + + +class GPTBigCodeAttention(nn.Module): + def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + super().__init__() + self.mask_value = None + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.layer_idx = layer_idx + + # KV caching and padding + + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device) + self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device) + + @torch.profiler.record_function("GPTBigCodeAttention._get_mask_value") + def _get_mask_value(self, device, dtype): + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value + + @torch.profiler.record_function("GPTBigCodeAttention._attn") + def _attn(self, query, key, value, attention_mask): + softmax_dtype = torch.float32 + upcast = query.dtype != softmax_dtype + + unscale = self.layer_idx + 1 if upcast else 1 + scale_factor = unscale**-1 / self.head_dim**0.5 + + # (batch_size, query_length, num_heads * head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key_length = key.size(-2) + + key = key.transpose(-1, -2) + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. + query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + + attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) + if query.device.type == "cpu": + # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. + # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, + # but the fix has not been released as of pytorch version 2.0.0. 
+ attn_weights.zero_() + beta = 1 + else: + beta = 0 + attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + + attn_weights = softmax_function( + attn_weights, + attention_mask, + None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype), + unscale, + softmax_dtype, + upcast, + ) + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + + return attn_output + + @torch.profiler.record_function("GPTBigCodeAttention._attn_flash") + def _attn_flash(self, query, key, value, flash_params): + + query_shape = query.shape + attn_shape = query_shape[0], self.num_heads, self.head_dim + query = query.view(attn_shape) + key = key.unsqueeze(1).expand(attn_shape) + value = value.unsqueeze(1).expand(attn_shape) + + sequence_lengths, padding_index, _, max_sequence_length = flash_params + + # attn_output: (sum_seq_len, num_heads * head_dim) + attn_output = flash_attn_unpadded_func( + query, + key, + value, + sequence_lengths, + sequence_lengths, + max_sequence_length, + max_sequence_length, + 0.0, + softmax_scale=self.head_dim**-0.5, + causal=True, + ).view(query_shape) + + return attn_output + + @torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches") + def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params): + + # Convert to standard KV cache format. + if flash_params is not None: + _, padding_index, batch_size, max_sequence_length = flash_params + current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length) + return key_value, (current_kv_cache, max_sequence_length) + + current_kv_cache = key_value + + # Calculate dimensions and recover layer_past + batch_size = current_kv_cache.size(0) + query_length = current_kv_cache.size(-2) + if layer_past is None: + allocated_kv_cache, last_key_length = None, 0 + last_kv_cache = None + key_length = query_length + allocated_key_length = key_length + else: + allocated_kv_cache, last_key_length = layer_past + last_kv_cache = allocated_kv_cache[:, :last_key_length] + key_length = query_length + last_key_length + allocated_key_length = allocated_kv_cache.size(-2) + + padded_key_length = attention_mask.size(-1) + + # Re-allocate kv cache and copy last value + if padded_key_length > allocated_key_length: + allocated_kv_cache = torch.empty( + [batch_size, padded_key_length, 2 * self.head_dim], + dtype=current_kv_cache.dtype, + device=current_kv_cache.device, + ) + if layer_past is not None: + allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache) + if padded_key_length > key_length: + # Nans in `value` can propagate through the matrix multiplication, + # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) + allocated_kv_cache[:, key_length:, self.head_dim :].zero_() + + # Copy the new values. + if padded_key_length > allocated_key_length or layer_past is not None: + allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache) + padded_kv_cache = allocated_kv_cache[:, :padded_key_length] + # Use the merged KV cache. + # Not needed when layer_past is None but frees some memory. 
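+            # Re-binding key_value to the cache view means later attention reads the merged
+            # cache rather than keeping the per-step projection alive.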
+ key_value = padded_kv_cache + + if allocated_kv_cache is None: + allocated_kv_cache = current_kv_cache + present = allocated_kv_cache, key_length + return key_value, present + + @torch.profiler.record_function("GPTBigCodeAttention.forward") + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + flash_params: Optional[Tuple] = None + ) -> Tuple[torch.Tensor, Any]: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) + + # present = (allocated_kv_cache, key_length) + key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params) + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + if flash_params is None: + attn_output=self._attn(query, key, value, attention_mask) + else: + attn_output=self._attn_flash(query, key, value, flash_params) + + attn_output = self.c_proj(attn_output) + + return attn_output, present + + +class GPTBigCodeMLP(nn.Module): + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim + self.c_fc = nn.Linear(embed_dim, inner_dim) + self.c_proj = nn.Linear(inner_dim, embed_dim) + + @torch.profiler.record_function("GPTBigCodeMLP.forward") + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward + def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor: + return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")) + + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + super().__init__() + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) + self.mlp = GPTBigCodeMLP(config) + + @torch.profiler.record_function("GPTBigCodeBlock.forward") + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + flash_params: Optional[Tuple] = None + ) -> Tuple[torch.Tensor, Any]: + with torch.profiler.record_function("GPTBigCodeAttention.ln"): + ai=self.ln_1(hidden_states) + attn_output, present = self.attn( + ai, + layer_past=layer_past, + attention_mask=attention_mask, + flash_params=flash_params, + ) + with torch.profiler.record_function("GPTBigCodeAttention.residual"): + hidden_states.add_(attn_output) + + with torch.profiler.record_function("GPTBigCodeAttention.dummy"): + pass + with torch.profiler.record_function("GPTBigCodeAttention.ln"): + ai=self.ln_2(hidden_states) + ai=self.mlp(ai) + with torch.profiler.record_function("GPTBigCodeAttention.residual"): + hidden_states.add_(ai) + return hidden_states, present + + + +class GPTBigCodePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = GPTBigCodeConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = False + _no_split_modules = ["GPTBigCodeBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, GPTBigCodeModel): + module.bias.fill_(True).tril_() + elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)): + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + module.c_proj.weight.data.normal_( + mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + ) + module.c_proj._is_hf_initialized = True + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class GPTBigCodeModel(GPTBigCodePreTrainedModel): + def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + super().__init__(config) + + self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device) + self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device) + + self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon) + + self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner) + + self.flash_attention = True #config.flash_attention + + if self.flash_attention: + if flash_attn_unpadded_func is None: + raise RuntimeError( + "Flash Attention requires `flash_attn` and `einops`. " + "To install, run `pip install flash-attn einops`." + ) + + if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: + self.inference_runner = None + else: + from .inference_runner import GPTBigCodeInferenceRunner + + self.inference_runner = GPTBigCodeInferenceRunner(config, self) + + # Causal mask + self.register_buffer( + "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device) + ) + + #@torch.profiler.record_function("GPTBigCodeModel._get_causal_mask") + def _get_causal_mask(self, padding_mask, query_length, key_length): + # Self-attention mask. 
+ attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + + if padding_mask is not None: + attention_mask = attention_mask * padding_mask.unsqueeze(1).to( + dtype=torch.bool, device=attention_mask.device + ) + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + + # (batch_size, query_length, n_heads, key_length) + return attention_mask.unsqueeze(2) + + #@torch.profiler.record_function("GPTBigCodeModel.forward") + def forward( + self, + *, + input_ids: torch.Tensor, + past_key_values: Optional[Union[List[torch.Tensor], int]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: torch.Tensor, + ) -> Tuple: + if self.inference_runner is not None and past_key_values is not None: + if self.config.validate_runner_input: + assert past_key_values is not None + return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) + + batch_size, query_length = input_ids.shape + + flash_attention = self.flash_attention and past_key_values is None + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][1] + + + hidden_states = self.wte(input_ids) + self.wpe(position_ids) + + # TODO: Unpad earlier (input ids), support unpadded input? + if flash_attention: + hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input( + hidden_states, attention_mask + ) + flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length) + attention_mask=None + else: + key_length = past_length + query_length + # Self-attention mask (padding + causal). + attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) + flash_params=None + + presents = [] + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + flash_params=flash_params + ) + + hidden_states = outputs[0] + presents.append(outputs[1]) + + hidden_states = self.ln_f(hidden_states) + + if flash_attention: + hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) + + return hidden_states, presents + + +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): + def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + super().__init__(config) + meta=torch.device("meta") + self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta) + + self.to_empty(device=device) + + # Initialize weights and apply final processing + self.post_init() + + #@torch.profiler.record_function("GPTBigCodeForCausalLM.forward") + def forward( + self, + *, + input_ids: torch.Tensor, + past_key_values: Optional[Union[List[torch.Tensor], int]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: torch.Tensor, + predict_all_tokens: bool=True, + ) -> Tuple: + + hidden_states, presents=self.transformer( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids + ) + #with torch.profiler.record_function("GPTBigCodeForCausalLM.head"): + if not predict_all_tokens: + # We only care about the last token. 
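+            # Computing lm_head only for the final position avoids materializing the full
+            # (batch, seq_len, vocab_size) logits tensor, the memory bottleneck of prefill.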
+ hidden_states = hidden_states[:, -1:] + + lm_logits = self.lm_head(hidden_states) + + return lm_logits, presents diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py index 078a8202..7ea367d7 100644 --- a/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py @@ -13,25 +13,34 @@ # limitations under the License. """PyTorch GPTBigCode model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Any, Union, List +from enum import IntEnum import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, -) +import dropout_layer_norm + from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( GPTBigCodeConfig, - InferenceRunnerType, ) +class InferenceRunnerType(IntEnum): + NO_RUNNER = 0 + # Use the inference runner without cuda graphs. + BASE_RUNNER = 1 + # Use cuda graphs in the inference runner. Leave out the attention which has a variable shape. + # This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models. + PARTIAL_GRAPH = 2 + # Turn the whole model into a cuda graph. One graph for each sequence length. + # Note: only useful for small batches and models, graphs take some time to generate, flaky. + # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. + FULL_GRAPH = 3 + + try: from flash_attn.bert_padding import pad_input, unpad_input @@ -44,6 +53,8 @@ except ImportError: logger = logging.get_logger(__name__) + +@torch.jit.script def upcast_masked_softmax( x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype ): @@ -55,12 +66,6 @@ def upcast_masked_softmax( @torch.jit.script -def upcast_masked_softmax_fused( - x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype -): - return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype) - - def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): input_dtype = x.dtype x = x.to(softmax_dtype) * scale @@ -69,21 +74,12 @@ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): @torch.jit.script -def upcast_softmax_fused(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): - return upcast_softmax(x, scale, softmax_dtype) - - def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): x = torch.where(mask, x, mask_value) x = torch.nn.functional.softmax(x, dim=-1) return x -@torch.jit.script -def masked_softmax_fused(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): - return masked_softmax(x, mask, mask_value) - - def softmax_function( x: torch.Tensor, mask: torch.Tensor, @@ -91,78 +87,39 @@ def softmax_function( scale: float, softmax_dtype: torch.dtype, upcast: bool = True, - fused_softmax: Optional[bool] = None, ): """ This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case needs to be handled through a separate method. 
The fused kernels remove most of the overhead from masking, casting and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely - inefficient. TODO: Could have better fused kernels depending on scaling, dropout and head mask. - Is it doable without writing 32 functions? + inefficient. """ - if fused_softmax is None: - fused_softmax = x.size(-1) % 8 == 0 + #assert x.size(-1) % 8 == 0 if upcast: if mask is None: - return (upcast_softmax_fused if fused_softmax else upcast_softmax)(x, scale, softmax_dtype) + return upcast_softmax(x, scale, softmax_dtype) else: - return (upcast_masked_softmax_fused if fused_softmax else upcast_masked_softmax)( - x, mask, mask_value, scale, softmax_dtype - ) + return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype) else: if mask is None: return torch.nn.functional.softmax(x, dim=-1) else: - return (masked_softmax_fused if fused_softmax else masked_softmax)(x, mask, mask_value) + return masked_softmax(x, mask, mask_value) class GPTBigCodeAttention(nn.Module): - def __init__(self, config, is_cross_attention=False, layer_idx=None): + def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): super().__init__() self.mask_value = None - assert config.multi_query - assert config.attention_softmax_in_fp32 - assert config.scale_attention_softmax_in_fp32 - - self.flash_attention = config.flash_attention self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.split_size = self.embed_dim - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - - self.scale_attn_weights = config.scale_attn_weights - self.is_cross_attention = is_cross_attention - self.layer_idx = layer_idx - self.fused_softmax = config.fused_softmax # KV caching and padding - self.pre_allocate_kv_cache = ( - config.n_embd if config.pre_allocate_kv_cache is True else config.pre_allocate_kv_cache - ) - pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length - self._tuple_cache_format = self.pre_allocate_kv_cache or pad_key_length or self.flash_attention - if self.is_cross_attention: - raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") - self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim) - - self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) - - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - - if self.flash_attention: - if flash_attn_unpadded_func is None: - raise RuntimeError( - "Flash Attention requires `flash_attn` and `einops`. " - "To install, run `pip install flash-attn einops`." - ) + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device) + self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device) def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
@@ -171,14 +128,11 @@ class GPTBigCodeAttention(nn.Module): return self.mask_value def _attn(self, query, key, value, attention_mask): - dtype = query.dtype softmax_dtype = torch.float32 - upcast = dtype != softmax_dtype + upcast = query.dtype != softmax_dtype unscale = self.layer_idx + 1 if upcast else 1 - scale_factor = unscale**-1 - if self.scale_attn_weights: - scale_factor /= self.head_dim**0.5 + scale_factor = unscale**-1 / self.head_dim**0.5 # (batch_size, query_length, num_heads * head_dim) query_shape = query.shape @@ -212,16 +166,12 @@ class GPTBigCodeAttention(nn.Module): unscale, softmax_dtype, upcast, - self.fused_softmax, ) - - attn_weights = self.attn_dropout(attn_weights) - attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) - return attn_output, attn_weights + return attn_output - def _attn_flash(self, query, key, value, attention_mask): + def _attn_flash(self, query, key, value, flash_params): query_shape = query.shape attn_shape = query_shape[0], self.num_heads, self.head_dim @@ -229,7 +179,7 @@ class GPTBigCodeAttention(nn.Module): key = key.unsqueeze(1).expand(attn_shape) value = value.unsqueeze(1).expand(attn_shape) - sequence_lengths, padding_index, _, max_sequence_length = attention_mask + sequence_lengths, padding_index, _, max_sequence_length = flash_params # attn_output: (sum_seq_len, num_heads * head_dim) attn_output = flash_attn_unpadded_func( @@ -240,32 +190,22 @@ class GPTBigCodeAttention(nn.Module): sequence_lengths, max_sequence_length, max_sequence_length, - self.dropout_p if self.training else 0.0, - softmax_scale=self.head_dim**-0.5 if self.scale_attn_weights else 1, + 0.0, + softmax_scale=self.head_dim**-0.5, causal=True, ).view(query_shape) - return attn_output, None + return attn_output - def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length): - batch_size = kv_cache.size(-1) - assert not self.training - allocated_kv_cache = torch.empty( - [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device - ) - allocated_kv_cache[:, :key_length].copy_(kv_cache) - padded_kv_cache = allocated_kv_cache[:, :padded_key_length] - return allocated_kv_cache, padded_kv_cache - - def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): - flash_attention = self.flash_attention and layer_past is None + def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params): # Convert to standard KV cache format. 
- if flash_attention and use_cache: - _, padding_index, batch_size, max_sequence_length = attention_mask + if flash_params is not None: + _, padding_index, batch_size, max_sequence_length = flash_params current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length) - else: - current_kv_cache = key_value + return key_value, (current_kv_cache, max_sequence_length) + + current_kv_cache = key_value # Calculate dimensions and recover layer_past batch_size = current_kv_cache.size(0) @@ -281,38 +221,33 @@ class GPTBigCodeAttention(nn.Module): key_length = query_length + last_key_length allocated_key_length = allocated_kv_cache.size(-2) - padded_key_length = key_length if flash_attention else attention_mask.size(-1) - allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length) + padded_key_length = attention_mask.size(-1) # Re-allocate kv cache and copy last value - if allocate_key_length > allocated_key_length: + if padded_key_length > allocated_key_length: allocated_kv_cache = torch.empty( - [batch_size, allocate_key_length, 2 * self.head_dim], + [batch_size, padded_key_length, 2 * self.head_dim], dtype=current_kv_cache.dtype, device=current_kv_cache.device, ) if layer_past is not None: allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache) - if allocate_key_length > key_length: + if padded_key_length > key_length: # Nans in `value` can propagate through the matrix multiplication, # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) allocated_kv_cache[:, key_length:, self.head_dim :].zero_() # Copy the new values. - if allocate_key_length > allocated_key_length or layer_past is not None: + if padded_key_length > allocated_key_length or layer_past is not None: allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache) padded_kv_cache = allocated_kv_cache[:, :padded_key_length] - if not flash_attention: - # Use the merged KV cache. - # Not needed when layer_past is None but frees some memory. - key_value = padded_kv_cache + # Use the merged KV cache. + # Not needed when layer_past is None but frees some memory. 
+ key_value = padded_kv_cache - if use_cache: - if allocated_kv_cache is None: - allocated_kv_cache = current_kv_cache - present = allocated_kv_cache, key_length - else: - present = None + if allocated_kv_cache is None: + allocated_kv_cache = current_kv_cache + present = allocated_kv_cache, key_length return key_value, present def forward( @@ -320,118 +255,63 @@ class GPTBigCodeAttention(nn.Module): hidden_states: torch.Tensor, layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[ - Tuple[torch.Tensor, Optional[torch.Tensor]], - Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], - ]: - flash_attention = self.flash_attention and layer_past is None + flash_params: Optional[Tuple] = None + ) -> Tuple[torch.Tensor, Any]: query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) - if self._tuple_cache_format: - # present = (allocated_kv_cache, key_length) - key_value, present = self._merge_kv_caches(key_value, use_cache, layer_past, attention_mask) - else: - # present = key_value - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None + # present = (allocated_kv_cache, key_length) + key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params) key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - attn_output, attn_weights = (self._attn_flash if flash_attention else self._attn)( - query, key, value, attention_mask - ) + if flash_params is None: + attn_output=self._attn(query, key, value, attention_mask) + else: + attn_output=self._attn_flash(query, key, value, flash_params) attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - if flash_attention: - raise ValueError("`output_attentions` is not supported with Flash Attention.") - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, present class GPTBigCodeMLP(nn.Module): - def __init__(self, intermediate_size, config): + def __init__(self, config): super().__init__() embed_dim = config.hidden_size - self.c_fc = nn.Linear(embed_dim, intermediate_size) - self.c_proj = nn.Linear(intermediate_size, embed_dim) - self.act = ACT2FN[config.activation_function] - self.dropout = nn.Dropout(config.resid_pdrop) + inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim + self.c_fc = nn.Linear(embed_dim, inner_dim) + self.c_proj = nn.Linear(inner_dim, embed_dim) # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor: - hidden_states = self.c_fc(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) - return hidden_states + return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")) class GPTBigCodeBlock(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): super().__init__() - hidden_size = config.hidden_size - 
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size - - self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) - self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - - if config.add_cross_attention: - raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") - - self.mlp = GPTBigCodeMLP(self.inner_dim, config) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) + self.mlp = GPTBigCodeMLP(config) def forward( self, - hidden_states: Optional[Tuple[torch.Tensor]], + hidden_states: torch.Tensor, layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[ - Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] - ]: - if encoder_hidden_states is not None or encoder_attention_mask is not None: - raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") - - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( - hidden_states, + flash_params: Optional[Tuple] = None + ) -> Tuple[torch.Tensor, Any]: + attn_output, present = self.attn( + self.ln_1(hidden_states), layer_past=layer_past, attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + flash_params=flash_params, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] - # residual connection - hidden_states = attn_output + residual + hidden_states.add_(attn_output) + hidden_states.add_(self.mlp(self.ln_2(hidden_states))) + return hidden_states, present - residual = hidden_states - hidden_states = self.ln_2(hidden_states) - feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + feed_forward_hidden_states - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs # hidden_states, present, (attentions, cross_attentions) class GPTBigCodePreTrainedModel(PreTrainedModel): @@ -442,7 +322,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): config_class = GPTBigCodeConfig base_model_prefix = "transformer" - supports_gradient_checkpointing = True + supports_gradient_checkpointing = False _no_split_modules = ["GPTBigCodeBlock"] def __init__(self, *inputs, **kwargs): @@ -450,7 +330,9 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): + if isinstance(module, GPTBigCodeModel): + module.bias.fill_(True).tril_() + elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)): # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
@@ -475,35 +357,27 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GPTBigCodeModel): - module.gradient_checkpointing = value - class GPTBigCodeModel(GPTBigCodePreTrainedModel): - _keys_to_ignore_on_load_missing = ["attn.masked_bias"] - - def __init__(self, config): + def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): super().__init__(config) - assert config.multi_query - self.embed_dim = config.hidden_size - if config.add_cross_attention: - raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") + self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device) + self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device) - self.wte = nn.Embedding(config.vocab_size, self.embed_dim) - self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon) - self.drop = nn.Dropout(config.embd_pdrop) - self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) - self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner) - self.pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length - self._tuple_cache_format = config.pre_allocate_kv_cache or self.pad_key_length or config.flash_attention - self.inference_runner_type = InferenceRunnerType(config.inference_runner) + self.flash_attention = True #config.flash_attention - self.flash_attention = config.flash_attention + if self.flash_attention: + if flash_attn_unpadded_func is None: + raise RuntimeError( + "Flash Attention requires `flash_attn` and `einops`. " + "To install, run `pip install flash-attn einops`." + ) if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: self.inference_runner = None @@ -512,23 +386,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): self.inference_runner = GPTBigCodeInferenceRunner(config, self) - max_positions = config.max_position_embeddings # Causal mask self.register_buffer( - "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False + "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device) ) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - def _get_causal_mask(self, padding_mask, query_length, key_length): # Self-attention mask. 
attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] @@ -537,306 +399,105 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): attention_mask = attention_mask * padding_mask.unsqueeze(1).to( dtype=torch.bool, device=attention_mask.device ) + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) + # (batch_size, query_length, n_heads, key_length) return attention_mask.unsqueeze(2) - def _get_position_ids(self, position_ids, padding_mask, query_length, key_length, device): - if position_ids is not None: - position_ids = position_ids.to(device) - elif padding_mask is not None and padding_mask.ndim == 2: - # create position_ids on the fly for batch generation - position_ids = padding_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(padding_mask == 0, 1) - if key_length > query_length: - position_ids = position_ids[:, key_length - query_length : key_length :] - else: - position_ids = torch.arange(key_length - query_length, key_length, dtype=torch.long, device=device) - return position_ids.view(-1, query_length) - def forward( self, - input_ids: Optional[torch.Tensor] = None, + *, + input_ids: torch.Tensor, past_key_values: Optional[Union[List[torch.Tensor], int]] = None, attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + position_ids: torch.Tensor, + ) -> Tuple: if self.inference_runner is not None and past_key_values is not None: if self.config.validate_runner_input: - assert input_ids is not None assert past_key_values is not None - assert attention_mask is not None - assert token_type_ids is None - assert position_ids is not None - assert inputs_embeds is None - assert encoder_hidden_states is None - assert encoder_attention_mask is None - use_cache = use_cache if use_cache is not None else self.config.use_cache - assert use_cache is True - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - assert output_attentions is False - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - assert output_hidden_states is False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - assert return_dict is True return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = self.config.use_cache if use_cache is None else use_cache - return_dict = self.config.use_return_dict if return_dict is None else return_dict - - if input_ids is not None: - if inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and 
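`_get_causal_mask` now always pads the key dimension to a multiple of 8 and returns the multi-query layout `(batch, query_length, 1, key_length)`. A standalone sketch of that computation; the `causal_mask_padded` helper is hypothetical and `bias` stands for the registered lower-triangular buffer:

```python
import torch


def causal_mask_padded(bias: torch.Tensor, padding_mask: torch.Tensor,
                       query_length: int, key_length: int) -> torch.Tensor:
    """Combine the causal and padding masks, padding the key length to a multiple of 8."""
    mask = bias[None, key_length - query_length : key_length, :key_length]
    mask = mask * padding_mask.unsqueeze(1).to(dtype=torch.bool, device=mask.device)
    # Pad the key dimension up to a multiple of 8 so the fused kernels see aligned shapes.
    pad = -key_length % 8
    if pad > 0:
        mask = torch.nn.functional.pad(mask, (0, pad), mode="constant", value=False)
    # (batch, query_length, 1, padded_key_length): one head axis, shared across MQA heads.
    return mask.unsqueeze(2)


bias = torch.ones(16, 16, dtype=torch.bool).tril_()
padding_mask = torch.ones(2, 10, dtype=torch.bool)
print(causal_mask_padded(bias, padding_mask, query_length=1, key_length=10).shape)
# torch.Size([2, 1, 1, 16])
```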
inputs_embeds at the same time") - input_shape = input_ids.shape - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size, query_length = input_ids.shape - elif inputs_embeds is not None: - input_shape = inputs_embeds.shape[:-1] - inputs_embeds = inputs_embeds.view(-1, input_shape[-2:]) - batch_size, query_length = inputs_embeds.shape[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") + batch_size, query_length = input_ids.shape flash_attention = self.flash_attention and past_key_values is None if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) - elif self._tuple_cache_format: - past_length = past_key_values[0][1] else: - past_length = past_key_values[0].size(-2) - key_length = past_length + query_length + past_length = past_key_values[0][1] - position_ids = self._get_position_ids(position_ids, attention_mask, query_length, key_length, input_ids.device) - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, query_length) - - if not flash_attention: - # Self-attention mask (padding + causal). - attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) - if self.pad_key_length: - pad = -key_length % 8 - if pad > 0: - attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) - - if encoder_hidden_states is not None or encoder_attention_mask is not None: - raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) + hidden_states = self.wte(input_ids) + self.wpe(position_ids) # TODO: Unpad earlier (input ids), support unpadded input? if flash_attention: hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input( hidden_states, attention_mask ) - # Pass the required parameters through the attention_mask argument - attention_mask = (sequence_lengths, padding_index, batch_size, max_sequence_length) + flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length) + attention_mask=None + else: + key_length = past_length + query_length + # Self-attention mask (padding + causal). 
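During prefill the hidden states are unpadded once and the bookkeeping tuple `(sequence_lengths, padding_index, batch_size, max_sequence_length)` is threaded through as `flash_params`. The sketch below uses pure-PyTorch stand-ins for flash_attn's `unpad_input`/`pad_input` so it runs without the library; the `unpad`/`repad` names and return order are assumptions for illustration, not the flash_attn API:

```python
import torch


def unpad(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    """Keep only non-padded tokens; return packed tokens plus re-padding bookkeeping."""
    batch_size, seq_len, _ = hidden_states.shape
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
    packed = hidden_states.reshape(batch_size * seq_len, -1)[indices]
    return packed, indices, cu_seqlens, int(seqlens.max())


def repad(packed: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
    """Inverse of `unpad`: scatter the packed tokens back into a (batch, seq, dim) tensor."""
    out = packed.new_zeros(batch_size * seq_len, packed.size(-1))
    out[indices] = packed
    return out.view(batch_size, seq_len, -1)


hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
packed, idx, cu, max_len = unpad(hidden, mask)
assert repad(packed, idx, 2, 5)[mask].allclose(hidden[mask])
```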
+ attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) + flash_params=None - presents = [] if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None + presents = [] for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + flash_params=flash_params + ) hidden_states = outputs[0] - if use_cache: - presents.append(outputs[1]) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + presents.append(outputs[1]) hidden_states = self.ln_f(hidden_states) if flash_attention: hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) - hidden_states = hidden_states.view(input_shape + (hidden_states.size(-1),)) - - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) + return hidden_states, presents class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): - _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] - - def __init__(self, config): + def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): super().__init__(config) - self.transformer = GPTBigCodeModel(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.predict_last_token = config.predict_last_token + meta=torch.device("meta") + self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta) + + self.to_empty(device=device) # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1:] - - # 
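`GPTBigCodeForCausalLM.__init__` now builds all submodules on the meta device and materializes them with `to_empty(device=...)`, which skips the usual CPU allocation and random initialization before checkpoint weights are loaded on top. A minimal sketch of that pattern with a hypothetical `TinyModel`:

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self, dtype=torch.float32, device=torch.device("meta")):
        super().__init__()
        # Building on the meta device allocates no memory and runs no random init.
        self.wte = nn.Embedding(100, 64, dtype=dtype, device=device)
        self.lm_head = nn.Linear(64, 100, bias=False, dtype=dtype, device=device)


model = TinyModel(dtype=torch.float16)
# Materialize uninitialized storage on the real device; the weights are garbage
# until a state_dict is loaded over them.
model.to_empty(device="cpu")
model.load_state_dict({k: torch.zeros_like(v) for k, v in model.state_dict().items()})
```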
if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": kwargs.get("position_ids", None), - "attention_mask": kwargs.get("attention_mask", None), - "token_type_ids": token_type_ids, - } - ) - return model_inputs - def forward( self, - input_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + *, + input_ids: torch.Tensor, + past_key_values: Optional[Union[List[torch.Tensor], int]] = None, attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + position_ids: torch.Tensor, + predict_all_tokens: bool=True, + ) -> Tuple: - transformer_outputs = self.transformer( - input_ids, + hidden_states, presents=self.transformer( + input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + position_ids=position_ids ) - hidden_states = transformer_outputs[0] - if self.predict_last_token and not self.training: + if not predict_all_tokens: # We only care about the last token. 
hidden_states = hidden_states[:, -1:] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) + return lm_logits, presents diff --git a/server/text_generation_server/models/custom_modeling/inference_runner.py b/server/text_generation_server/models/custom_modeling/inference_runner.py index 9e1fb5ea..19d80b2d 100644 --- a/server/text_generation_server/models/custom_modeling/inference_runner.py +++ b/server/text_generation_server/models/custom_modeling/inference_runner.py @@ -1,13 +1,12 @@ -from typing import List, Union +from typing import List, Union, Tuple import torch from transformers import GPTBigCodeConfig -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( InferenceRunnerType, ) -from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, softmax_function +from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeBlock, softmax_function def _align_tensor(x): @@ -291,7 +290,7 @@ class GPTBigCodeInferenceRunner: attention_mask: torch.Tensor, position_ids: torch.Tensor, past_key_values: Union[List[torch.Tensor], int], - ) -> BaseModelOutputWithPastAndCrossAttentions: + ) -> Tuple: batch_size, query_length = input_ids.shape assert query_length == 1 if self.batch_size is None: @@ -333,7 +332,4 @@ class GPTBigCodeInferenceRunner: else: hidden_states = self._forward(key_length) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=key_length, - ) + return hidden_states, key_length diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index aee0480d..55162186 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -449,6 +449,49 @@ class FlashCausalLM(Model): pre_allocate_past_size=pre_allocate_past_size, ) + def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, use_cache: Optional[torch.dtype]): + diff = max_input_length - max(batch.input_lengths) + for i in range(len(batch)): + batch.input_lengths[i] += diff + batch.prefix_offsets[i] = 0 + batch.read_offsets[i] = 0 + batch.all_input_ids[i] = batch.all_input_ids[i] + [self.tokenizer.pad_token_id] * diff if diff >= 0 else batch.all_input_ids[i][:diff] + # TODO: Bug!?! + batch.stopping_criterias[i].current_tokens += diff + + if use_cache: + assert len(batch.all_input_ids_tensor)>0, "Must run prefill first" + batch.input_ids.fill_(self.tokenizer.pad_token_id) + batch.position_ids += diff + batch.cu_seqlens += diff * batch.cu_seqlens_q + # TODO: Bug!?! 
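The `predict_all_tokens` flag replaces the old `predict_last_token` config option: during decode only the final position is projected through `lm_head`. A small sketch of the shortcut; the `project_logits` helper and the toy sizes are illustrative:

```python
import torch
from torch import nn


def project_logits(hidden_states: torch.Tensor, lm_head: nn.Linear,
                   predict_all_tokens: bool) -> torch.Tensor:
    """During decode only the last position is needed, so the vocabulary projection
    runs on a (batch, 1, dim) slice instead of the full sequence."""
    if not predict_all_tokens:
        hidden_states = hidden_states[:, -1:]
    return lm_head(hidden_states)


lm_head = nn.Linear(64, 1024, bias=False)
hidden = torch.randn(4, 128, 64)
print(project_logits(hidden, lm_head, predict_all_tokens=False).shape)  # torch.Size([4, 1, 1024])
print(project_logits(hidden, lm_head, predict_all_tokens=True).shape)   # torch.Size([4, 128, 1024])
```

For a real vocabulary this removes the dominant matmul from every decode step, since only prefill needs per-token logits.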
+ batch.max_seqlen += batch.max_seqlen + diff*len(batch) + + for i in range(len(batch)): + batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(self.tokenizer.pad_token_id) + + batch.past_key_values = batch.past_key_values = torch.randn( + ( + batch.past_key_values.shape[0], + batch.past_key_values.shape[1] + len(batch.requests), + *batch.past_key_values.shape[2:], + ), device=batch.past_key_values.device, dtype= batch.past_key_values.dtype + ) + else: + batch.max_seqlen = max(batch.input_lengths) + batch.all_input_ids_tensor=[] + + batch.input_ids = torch.tensor( + np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device + ) + batch.position_ids = torch.tensor( + np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device + ) + batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32) + batch.past_key_values=None + + + @tracer.start_as_current_span("generate_token") def generate_token( self, batch: FlashCausalLMBatch diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 5dc31309..2812e043 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -223,6 +223,7 @@ class FlashSantacoderSharded(FlashSantacoder): world_size=world_size, ) + @staticmethod def load_weights( model, diff --git a/server/text_generation_server/models/gpt_bigcode.py b/server/text_generation_server/models/gpt_bigcode.py new file mode 100644 index 00000000..ac321786 --- /dev/null +++ b/server/text_generation_server/models/gpt_bigcode.py @@ -0,0 +1,135 @@ +import torch + +from dataclasses import dataclass + +from opentelemetry import trace +from transformers import AutoTokenizer +from typing import Optional, Type + +from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch +from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeForCausalLM + + +tracer = trace.get_tracer(__name__) + + +@dataclass +class BigcodeBatch(VectorizedCausalLMBatch): + kv_cache_seq_dim: int = 1 + + def _filter_kv_caches(self, keep_indices, sequence_slice): + if self.past_key_values is not None: + for layer_kv, _ in self.past_key_values: + # Update tensors in-place to allow incremental garbage collection + layer_kv.data = layer_kv[keep_indices, sequence_slice] + + @classmethod + def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): + device = batches[0].input_ids.device + batch_size = sum([len(batch.requests) for batch in batches]) + + for batch in batches: + if batch.past_key_values is None: + raise ValueError("Only concatenate prefilled batches") + + past_key_values = [] + for kv_caches in zip(*(batch.past_key_values for batch in batches)): + key_values, seq_lengths = zip(*kv_caches) + assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths)) + + allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values)) + allocate_seq_len += - allocate_seq_len % 8 + + kv_cache = torch.empty( + (batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device + ) + for key_value, 
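`fast_forward` rebuilds the packed flash-attention layout from the per-request lengths: per-token position ids plus the cumulative sequence boundaries (`cu_seqlens`). A runnable sketch of just that bookkeeping, with a hypothetical `flash_batch_layout` helper:

```python
import numpy as np
import torch


def flash_batch_layout(input_lengths, device=torch.device("cpu")):
    """Per-token position ids and cumulative boundaries for a packed (unpadded) batch."""
    position_ids = torch.tensor(
        np.concatenate([np.arange(0, n) for n in input_lengths]),
        dtype=torch.int32, device=device,
    )
    cu_seqlens = torch.tensor(
        np.pad(np.cumsum(input_lengths), (1, 0)),
        dtype=torch.int32, device=device,
    )
    return position_ids, cu_seqlens


pos, cu = flash_batch_layout([3, 5, 2])
print(pos.tolist())  # [0, 1, 2, 0, 1, 2, 3, 4, 0, 1]
print(cu.tolist())   # [0, 3, 8, 10]
```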
start_index, end_index, left_index in zip( + key_values, + start_indices, + end_indices, + left_indices, + ): + kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value) + # Set padding to zero to avoid propagating nans. + kv_cache[start_index:end_index, :left_index].fill_(0) + kv_cache[start_index:end_index, max_input_length:].fill_(0) + past_key_values.append((kv_cache, max_input_length)) + + def __len__(self): + return len(self.requests) + +class BigcodeCausalLM(VectorizedCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left", truncation_side="left" + ) + model = GPTBigCodeForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map="auto" if torch.cuda.is_available() else None, + load_in_8bit=quantize == "bitsandbytes", + ) + tokenizer.pad_token_id = ( + model.config.pad_token_id + if model.config.pad_token_id is not None + else model.config.eos_token_id + ) + + super(VectorizedCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def batch_type(self) -> Type[BigcodeBatch]: + return BigcodeBatch + + def forward(self, batch:BigcodeBatch): + key_length = batch.max_input_length + query_length = key_length if batch.past_key_values is None else 1 + input_ids = batch.input_ids[:, key_length - query_length : key_length] + # Model Forward + logits, batch.past_key_values = self.model.forward( + input_ids=input_ids, + attention_mask=batch.attention_mask[:, :key_length], + position_ids=batch.position_ids[:, key_length - query_length : key_length], + past_key_values=batch.past_key_values, + use_cache=True, + ) + next_token_ids, logprobs = batch.next_token_chooser( + input_ids, logits, batch.details + ) + # Update batch + # TODO: Why do we need all input ids? 
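`_concatenate_key_values` merges the per-batch MQA caches into one left-padded tensor whose sequence dimension sits at index 1, rounding the allocated length up and zeroing the padding so NaNs cannot propagate. A self-contained sketch under the same assumptions; the `merge_kv_layer` name, the explicit slice of the source cache, and the explicit `return` are illustrative:

```python
import torch


def merge_kv_layer(key_values, start_indices, end_indices, left_indices,
                   max_input_length, pad_to=8):
    """Merge one layer's KV caches from several batches (sequence dimension = 1)."""
    batch_size = end_indices[-1]
    allocate_seq_len = max(l + kv.size(1) for l, kv in zip(left_indices, key_values))
    allocate_seq_len += -allocate_seq_len % pad_to
    merged = torch.empty(
        (batch_size, allocate_seq_len, *key_values[0].shape[2:]),
        dtype=key_values[0].dtype, device=key_values[0].device,
    )
    for kv, start, end, left in zip(key_values, start_indices, end_indices, left_indices):
        merged[start:end, left:max_input_length].copy_(kv[:, :max_input_length - left])
        # Zero the padding so it cannot propagate NaNs through the attention kernels.
        merged[start:end, :left].fill_(0)
        merged[start:end, max_input_length:].fill_(0)
    return merged


a = torch.randn(2, 6, 16)   # batch of 2 with 6 cached positions
b = torch.randn(1, 4, 16)   # batch of 1 with 4 cached positions, 2 positions of left padding
merged = merge_kv_layer([a, b], [0, 2], [2, 3], [0, 2], max_input_length=6)
print(merged.shape)  # torch.Size([3, 8, 16])
```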
+ batch.input_ids[:, key_length].copy_(next_token_ids) + batch.input_lengths = [length + 1 for length in batch.input_lengths] + batch.max_input_length += 1 + + return next_token_ids, logprobs + + def mock_kv_cache(self, batch: BigcodeBatch, dtype:Optional[torch.dtype]): + allocate_length=batch.max_input_length+-batch.max_input_length%8 + return [(torch.empty( + [len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], + dtype=dtype, + device=batch.input_ids.device, + ),batch.max_input_length-1) for _ in range(self.model.config.n_layer)] diff --git a/server/text_generation_server/models/gpt_bigcode2.py b/server/text_generation_server/models/gpt_bigcode2.py new file mode 100644 index 00000000..24875c3b --- /dev/null +++ b/server/text_generation_server/models/gpt_bigcode2.py @@ -0,0 +1,149 @@ +import torch + +from dataclasses import dataclass + +from opentelemetry import trace +from transformers import AutoTokenizer +from typing import Optional, Type + +from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch +from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import GPTBigCodeForCausalLM + +tracer = trace.get_tracer(__name__) + + +@dataclass +class Bigcode2Batch(VectorizedCausalLMBatch): + kv_cache_seq_dim: int = 1 + pad_key_length_to_multiple:int=8 + + # Prefill the attention mask for padded key length. + attention_mask_fill_value=False + + def _filter_kv_caches(self, keep_indices, sequence_slice): + if self.past_key_values is not None: + for layer_kv, _ in self.past_key_values: + # Update tensors in-place to allow incremental garbage collection + layer_kv.data = layer_kv[keep_indices, sequence_slice] + + @classmethod + def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): + device = batches[0].input_ids.device + batch_size = sum([len(batch.requests) for batch in batches]) + + for batch in batches: + if batch.past_key_values is None: + raise ValueError("Only concatenate prefilled batches") + + past_key_values = [] + for kv_caches in zip(*(batch.past_key_values for batch in batches)): + key_values, seq_lengths = zip(*kv_caches) + assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths)) + + allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values)) + allocate_seq_len += - allocate_seq_len % batches[0].pad_key_length_to_multiple + + kv_cache = torch.empty( + (batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device + ) + for key_value, start_index, end_index, left_index in zip( + key_values, + start_indices, + end_indices, + left_indices, + ): + kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value) + # Set padding to zero to avoid propagating nans. 
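`mock_kv_cache` fabricates a cache of the right shape for benchmarking without running prefill: with multi-query attention each layer stores a single `(batch, seq, 2 * head_dim)` tensor plus the number of filled positions, padded up to a multiple of 8. A sketch with a hypothetical `mock_mqa_kv_cache` helper:

```python
import torch


def mock_mqa_kv_cache(batch_size: int, max_input_length: int, n_layer: int,
                      n_embd: int, n_head: int, dtype=torch.float16,
                      device=torch.device("cpu"), pad_to: int = 8):
    """Fabricate an MQA key/value cache of the right shape (one shared KV head,
    so the last dim is 2 * head_dim) without running prefill."""
    allocate_length = max_input_length + -max_input_length % pad_to
    head_dim = n_embd // n_head
    return [
        (
            torch.empty(batch_size, allocate_length - 1, 2 * head_dim,
                        dtype=dtype, device=device),
            max_input_length - 1,  # number of positions actually filled
        )
        for _ in range(n_layer)
    ]


cache = mock_mqa_kv_cache(batch_size=4, max_input_length=13, n_layer=2,
                          n_embd=2048, n_head=16)
print(cache[0][0].shape, cache[0][1])  # torch.Size([4, 15, 256]) 12
```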
+ kv_cache[start_index:end_index, :left_index].fill_(0) + kv_cache[start_index:end_index, max_input_length:].fill_(0) + past_key_values.append((kv_cache, max_input_length)) + + def __len__(self): + return len(self.requests) + +class Bigcode2CausalLM(VectorizedCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left", truncation_side="left" + ) + model = GPTBigCodeForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map="auto" if torch.cuda.is_available() else None, + load_in_8bit=quantize == "bitsandbytes", + ) + tokenizer.pad_token_id = ( + model.config.pad_token_id + if model.config.pad_token_id is not None + else model.config.eos_token_id + ) + + super(VectorizedCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def batch_type(self) -> Type[Bigcode2Batch]: + return Bigcode2Batch + + def forward(self, batch:Bigcode2Batch): + key_length = batch.max_input_length + if batch.past_key_values is None: + # Prefill (flash attn, unpadded key length) + batch.pad_key_length_to_multiple=self.model.pad_key_length_to_multiple + padded_key_length=key_length + query_length=key_length + else: + # Decode (fused attn, padded key length) + batch.attention_mask[:, key_length-1].fill_(True) + padded_key_length=key_length+-key_length%batch.pad_key_length_to_multiple + query_length=1 + + input_ids = batch.input_ids[:, key_length - query_length : key_length] + # Model Forward + logits, batch.past_key_values = self.model.forward( + input_ids=input_ids, + attention_mask=batch.attention_mask[:, :padded_key_length], + position_ids=batch.position_ids[:, key_length - query_length : key_length], + past_key_values=batch.past_key_values, + key_length=key_length, + predict_all_tokens=batch.details + ) + next_token_ids, logprobs = batch.next_token_chooser( + input_ids, logits, batch.details + ) + # Update batch + # TODO: Why do we need all input ids? + batch.input_ids[:, key_length].copy_(next_token_ids) + batch.input_lengths = [length + 1 for length in batch.input_lengths] + batch.max_input_length += 1 + + return next_token_ids, logprobs + + def mock_kv_cache(self, batch: Bigcode2Batch, dtype:Optional[torch.dtype]): + allocate_length=batch.max_input_length+-batch.max_input_length%batch.pad_key_length_to_multiple + return [(torch.empty( + [len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], + dtype=dtype, + device=batch.input_ids.device, + ),batch.max_input_length-1) for _ in range(self.model.config.n_layer)] diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py index b2194b14..e3a5b61b 100644 --- a/server/text_generation_server/models/vectorized_causal_lm.py +++ b/server/text_generation_server/models/vectorized_causal_lm.py @@ -58,6 +58,9 @@ class VectorizedCausalLMBatch(Batch): kv_cache_seq_dim: int = 2 + # Prefill the attention mask for the generated tokens + attention_mask_fill_value=True + # TODO: Get from requests (should these be lists?) 
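In the Bigcode2 decode path the freshly generated position is marked visible and the attention mask handed to the kernel is sliced to a key length rounded up to `pad_key_length_to_multiple`. A small sketch of that step; the `decode_attention_slice` helper is illustrative:

```python
import torch


def decode_attention_slice(attention_mask: torch.Tensor, key_length: int,
                           pad_key_length_to_multiple: int = 8):
    """Mark the newly generated position as visible, then return a mask whose
    key length is rounded up to the requested multiple."""
    attention_mask[:, key_length - 1].fill_(True)
    padded_key_length = key_length + -key_length % pad_key_length_to_multiple
    return attention_mask[:, :padded_key_length]


mask = torch.zeros(2, 32, dtype=torch.bool)
mask[:, :10] = True          # 10 real tokens so far
sliced = decode_attention_slice(mask, key_length=11)
print(sliced.shape)          # torch.Size([2, 16]); positions 11..15 stay False (padding)
```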
details: bool = os.environ.get("RETURN_DETAILS") is not None generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None @@ -113,7 +116,7 @@ class VectorizedCausalLMBatch(Batch): attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) # Copy tokenizer attention_mask into fully allocated attention_mask attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"]) - attention_mask[:, max_input_length:].fill_(1) + attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value) position_ids = attention_mask.cumsum(-1).sub_(1) position_ids[:, :max_input_length].relu_() @@ -268,7 +271,7 @@ class VectorizedCausalLMBatch(Batch): # Allocate maximum attention_mask attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) attention_mask[:, :max_input_length].fill_(0) - attention_mask[:, max_input_length:].fill_(1) + attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value) input_ids = torch.empty(input_shape, dtype=torch.int64, device=device) # TODO : only needed for prefill @@ -293,7 +296,7 @@ class VectorizedCausalLMBatch(Batch): ) kv_cache_seq_dim = batches[0].kv_cache_seq_dim - past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices) + past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices, max_input_length) return cls( batch_id=batches[0].batch_id, @@ -314,7 +317,7 @@ class VectorizedCausalLMBatch(Batch): ) @classmethod - def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices): + def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): device = batches[0].input_ids.device batch_size = sum([len(batch.requests) for batch in batches]) @@ -355,29 +358,23 @@ class VectorizedCausalLMBatch(Batch): else batch.past_key_values[i][j] for batch in batches ] - # Generally `max_input_length`, unless the model allocates more than needed. - right_indices = [ - left_index + tensor.size(kv_cache_seq_dim) - for tensor, left_index in zip(tensors_to_merge, left_indices) - ] combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:]) - combined_shape[kv_cache_seq_dim] = max(right_indices) + combined_shape[kv_cache_seq_dim] = max_input_length # Set to zero to avoid propagating nans in padded values. kv_cache = torch.zeros( combined_shape, dtype=tensors_to_merge[0].dtype, device=device ) - for tensor, start_index, end_index, left_index, right_index in zip( + for tensor, start_index, end_index, left_index in zip( tensors_to_merge, start_indices, end_indices, left_indices, - right_indices, ): kv_cache[ [ slice(start_index, end_index), *(slice(None) for _ in range(1, kv_cache_seq_dim)), - slice(left_index, right_index), + slice(left_index, max_input_length), ] ].copy_(tensor) if kv_format is None: @@ -398,13 +395,11 @@ class VectorizedCausalLM(Model): self, model_id: str, revision: Optional[str] = None, - quantize: bool = False, - decode_buffer: int = 3, + quantize: Optional[str] = None, ): if torch.cuda.is_available(): device = torch.device("cuda") - # TODO: Choose dtype (fp16?) 
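The class-level `attention_mask_fill_value` controls what the pre-allocated mask contains beyond the prompt: `True` for the stock path, where generated positions are always attended, and `False` for the padded-key-length variant, which reveals positions one step at a time. A sketch of the allocation, assuming a hypothetical `allocate_attention_mask` helper:

```python
import torch


def allocate_attention_mask(tokenizer_mask: torch.Tensor, max_tokens: int,
                            fill_value: bool, device=torch.device("cpu")):
    """Copy the tokenizer's mask for the prompt and pre-fill the future
    (to-be-generated) positions with the class's fill value."""
    batch_size, max_input_length = tokenizer_mask.shape
    mask = torch.empty(batch_size, max_tokens, dtype=torch.bool, device=device)
    mask[:, :max_input_length].copy_(tokenizer_mask)
    mask[:, max_input_length:].fill_(fill_value)
    return mask


tok_mask = torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.bool)
print(allocate_attention_mask(tok_mask, max_tokens=6, fill_value=True))
# tensor([[False,  True,  True,  True,  True,  True],
#         [ True,  True,  True,  True,  True,  True]])
```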
- dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: if quantize: raise ValueError("quantization is not available on CPU") @@ -415,26 +410,26 @@ class VectorizedCausalLM(Model): tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) - self.model = AutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_id, revision=revision, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, - load_in_8bit=quantize, + load_in_8bit=quantize == "bitsandbytes", trust_remote_code=True, - ).eval() + ) tokenizer.pad_token_id = ( - self.model.config.pad_token_id - if self.model.config.pad_token_id is not None - else self.model.config.eos_token_id + model.config.pad_token_id + if model.config.pad_token_id is not None + else model.config.eos_token_id ) super().__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, device=device, - decode_buffer=decode_buffer, ) @property @@ -446,6 +441,30 @@ class VectorizedCausalLM(Model): generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False ) + def forward(self, batch:VectorizedCausalLMBatch): + key_length = batch.max_input_length + query_length = key_length if batch.past_key_values is None else 1 + input_ids = batch.input_ids[:, key_length - query_length : key_length] + # Model Forward + logits, batch.past_key_values, *_ = self.model.forward( + input_ids=input_ids, + attention_mask=batch.attention_mask[:, :key_length], + position_ids=batch.position_ids[:, key_length - query_length : key_length], + past_key_values=batch.past_key_values, + return_dict=False, + use_cache=True, + ) + next_token_ids, logprobs = batch.next_token_chooser( + input_ids, logits, batch.details + ) + # Update batch + # TODO: Why do we need all input ids? + batch.input_ids[:, key_length].copy_(next_token_ids) + batch.input_lengths = [length + 1 for length in batch.input_lengths] + batch.max_input_length += 1 + + return next_token_ids, logprobs + @tracer.start_as_current_span("generate_token") def generate_token( self, batch: VectorizedCausalLMBatch @@ -453,20 +472,9 @@ class VectorizedCausalLM(Model): key_length = batch.max_input_length if key_length > batch.input_ids.size(1): raise RuntimeError("Cannot generate more than `max_tokens`.") + is_prefill = batch.past_key_values is None - query_length = key_length if batch.past_key_values is None else 1 - input_ids = batch.input_ids[:, key_length - query_length : key_length] - - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=batch.attention_mask[:, :key_length], - position_ids=batch.position_ids[:, key_length - query_length : key_length], - past_key_values=batch.past_key_values, - ) - # TODO: Post-processing - next_token_ids, logprobs = batch.next_token_chooser( - input_ids, outputs.logits, batch.details - ) + next_token_ids, logprobs = self.forward(batch) if batch.generate_stream: # TODO: self.decode_token, offsets? 
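The new `forward` keeps the per-step batch update in one place: the sampled tokens are written into the pre-allocated `input_ids` buffer and the recorded lengths advance by one. A sketch of just that update; the `update_batch` helper is illustrative:

```python
import torch


def update_batch(input_ids: torch.Tensor, input_lengths, max_input_length: int,
                 next_token_ids: torch.Tensor):
    """Write the next tokens into the pre-allocated buffer and bump every length by one."""
    input_ids[:, max_input_length].copy_(next_token_ids)
    input_lengths = [length + 1 for length in input_lengths]
    return input_lengths, max_input_length + 1


ids = torch.zeros(2, 8, dtype=torch.int64)
lengths, max_len = update_batch(ids, [3, 4], max_input_length=4,
                                next_token_ids=torch.tensor([7, 9]))
print(ids[:, 4].tolist(), lengths, max_len)  # [7, 9] [4, 5] 5
```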
@@ -479,7 +487,6 @@ class VectorizedCausalLM(Model): .squeeze(1) .tolist() ) - is_prefill = batch.past_key_values is None if is_prefill: prefill_token_ids = batch.input_ids[:, :key_length].tolist() prefill_logprobs = ( @@ -491,7 +498,8 @@ class VectorizedCausalLM(Model): for prefill_token_ids_, prefill_logprobs_, input_length in zip( prefill_token_ids, prefill_logprobs, batch.input_lengths ): - prefill_token_ids_ = prefill_token_ids_[-input_length:] + # Input length has already been incremented so we subtract 1. + prefill_token_ids_ = prefill_token_ids_[-(input_length-1):] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids_, clean_up_tokenization_spaces=False, @@ -505,13 +513,6 @@ class VectorizedCausalLM(Model): ) ) - # Update batch - # TODO: Why do we need all input ids? - batch.input_ids[:, key_length].copy_(next_token_ids) - batch.past_key_values = outputs.past_key_values - batch.input_lengths = [length + 1 for length in batch.input_lengths] - batch.max_input_length += 1 - # TODO: Vectorize some of this? generations: List[Generation] = [] @@ -556,3 +557,24 @@ class VectorizedCausalLM(Model): generations.append(generation) return generations, next_batch + + def mock_kv_cache(self, batch: VectorizedCausalLMBatch, dtype:Optional[torch.dtype]): + from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM + if not isinstance(self.model, GPTBigCodeForCausalLM): + raise NotImplementedError() + return [torch.empty( + [len(batch), batch.max_input_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], + dtype=dtype, + device=batch.input_ids.device, + ) for _ in range(self.model.config.n_layer)] + + def fast_forward(self, batch: VectorizedCausalLMBatch, max_input_length: int, cache_dtype:Optional[torch.dtype]): + diff=max_input_length-batch.max_input_length + batch.input_ids[:, batch.max_input_length:max_input_length].fill_(self.tokenizer.pad_token_id) + batch.input_lengths = [length + diff for length in batch.input_lengths] + batch.max_input_length += diff + for stopping_criteria in batch.stopping_criterias: + stopping_criteria.current_tokens+=diff + batch.past_key_values = None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype) + +
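Because `batch.input_lengths` is incremented before the prefill details are assembled, the prefill slice has to discount that extra token, hence `-(input_length-1)`. A tiny sketch of the off-by-one, with a hypothetical `prefill_slice` helper:

```python
def prefill_slice(prefill_token_ids, input_length_after_step):
    """The recorded input length already includes the first generated token,
    so drop that +1 when recovering the prompt tokens."""
    return prefill_token_ids[-(input_length_after_step - 1):]


# Prompt had 3 tokens, left-padded to length 6; after the first decode step
# the recorded input length is 4.
tokens = [0, 0, 0, 11, 12, 13]
print(prefill_slice(tokens, input_length_after_step=4))  # [11, 12, 13]
```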