From 7a94845ebaf05ee103cf064f4401d6029fb27235 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 May 2023 15:22:20 -0400 Subject: [PATCH] Fixes, cpu-optimized model, misc --- .../text_generation_server/models/__init__.py | 2 + .../models/causal_lm.py | 60 ++ .../custom_modeling/gpt_bigcode2_modeling.py | 31 +- .../custom_modeling/gpt_bigcode3_modeling.py | 866 ++++++++---------- .../custom_modeling/gpt_bigcode4_modeling.py | 565 ++++++++++++ .../models/flash_causal_lm.py | 20 +- .../models/gpt_bigcode2.py | 35 +- 7 files changed, 1068 insertions(+), 511 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/gpt_bigcode4_modeling.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 7fcb4c7c..8f48fdb7 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -100,6 +100,8 @@ model_type_map={ "vector":VectorizedCausalLM, "bigcode":BigcodeCausalLM, "bigcode2":Bigcode2CausalLM, + "bigcode3":Bigcode2CausalLM, + "bigcode4":Bigcode2CausalLM, } diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index c0572a72..26b2687f 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -508,6 +508,63 @@ class CausalLM(Model): ) return outputs.logits, outputs.past_key_values + + def fast_forward( + self, + batch: CausalLMBatch, + max_input_length: int, + cache_dtype: Optional[torch.dtype], + ): + diff = max_input_length - batch.max_input_length + batch.max_input_length+=diff + fill_value=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id + + for i in range(len(batch)): + batch.all_input_ids[i]=torch.nn.functional.pad(batch.all_input_ids[i], (0, 0, 0, diff), value=fill_value)[:batch.max_input_length] + batch.input_lengths[i]+=diff + batch.prefix_offsets[i] = 0 + batch.read_offsets[i] = 0 + + batch.attention_mask[:, -batch.padding_right_offset:batch.padding_right_offset+diff] = 1 + batch.padding_right_offset -= diff + + if cache_dtype is None: + assert batch.input_ids.shape==(len(batch),batch.max_input_length-diff) + batch.input_ids=torch.nn.functional.pad(batch.input_ids, (0,diff), value=fill_value)[:, :batch.max_input_length] + batch.position_ids = batch.attention_mask[:, :batch.max_input_length].long().cumsum(-1) - 1 + batch.past_key_values=None + else: + from transformers import GPTBigCodeForCausalLM + batch.input_ids=batch.input_ids[:,:1].fill_(fill_value) + batch.position_ids = batch.position_ids[:, -1:] + diff + if isinstance(self.model, GPTBigCodeForCausalLM): + batch.past_key_values=[ + torch.randn( + [ + len(batch), + batch.max_input_length - 1, + 2 * self.model.config.n_embd // self.model.config.n_head, + ], + dtype=cache_dtype, + device=batch.input_ids.device, + ) + for _ in range(self.model.config.n_layer) + ] + else: + batch.past_key_values=[ + [torch.randn( + [ + len(batch), + batch.max_input_length - 1, + self.model.config.n_embd // self.model.config.n_head, + ], + dtype=cache_dtype, + device=batch.input_ids.device, + ) + for _ in range(2)] for _ in range(self.model.config.n_layer) + ] + + @tracer.start_as_current_span("generate_token") def generate_token( self, batch: CausalLMBatch @@ -515,6 +572,9 @@ class CausalLM(Model): # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + #print("AAA", batch.max_input_length, 
batch.input_ids, batch.position_ids, batch.all_input_ids, batch.all_input_ids[0].shape) + #print("BBB", batch.max_input_length, batch.input_lengths, batch.padding_right_offset, batch.past_key_values is None) + logits, past = self.forward( batch.input_ids, attention_mask, diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py index 6d20001f..3bfd3669 100644 --- a/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py @@ -146,9 +146,10 @@ class GPTBigCodeAttention(nn.Module): hidden_states: torch.Tensor, layer_past: torch.Tensor, attention_mask: torch.Tensor, - batch_size: int, key_length: int, ) -> Tuple[torch.Tensor, Any]: + batch_size=hidden_states.size(0) + query, key_value = self.c_attn.forward(hidden_states).split( (self.embed_dim, 2 * self.head_dim), dim=-1 ) @@ -183,7 +184,6 @@ class GPTBigCodeAttention(nn.Module): unscale = self.layer_idx + 1 if upcast else 1 scale_factor = unscale**-1 / self.head_dim**0.5 - # TODO: No need to unsqueeze? hidden_states = torch.baddbmm( torch.empty( (batch_size, self.num_heads, padded_key_length), @@ -194,7 +194,7 @@ class GPTBigCodeAttention(nn.Module): key.transpose(-1, -2), beta=0, alpha=scale_factor, - ).unsqueeze_(1) + ) if upcast: hidden_states = upcast_masked_softmax( @@ -205,7 +205,7 @@ class GPTBigCodeAttention(nn.Module): hidden_states, attention_mask, self.mask_value ) - hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape) + hidden_states = torch.bmm(hidden_states, value).view(query.shape) hidden_states = self.c_proj.forward(hidden_states) @@ -265,7 +265,6 @@ class GPTBigCodeBlock(nn.Module): residual: Optional[torch.Tensor], layer_past: torch.Tensor, attention_mask: torch.Tensor, - batch_size: int, key_length: int, ) -> Tuple[torch.Tensor, torch.Tensor, Any]: hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual) @@ -273,7 +272,6 @@ class GPTBigCodeBlock(nn.Module): hidden_states, layer_past=layer_past, attention_mask=attention_mask, - batch_size=batch_size, key_length=key_length, ) hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual) @@ -370,12 +368,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta" ) - # Causal mask - self.causal_mask = torch.ones( - (config.max_position_embeddings, config.max_position_embeddings), - dtype=torch.bool, - device=device, - ).tril_() self.mask_value = torch.full( (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device ) @@ -459,9 +451,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): position_ids: torch.Tensor, key_length: int, ) -> Tuple: - batch_size, query_length = input_ids.shape - assert query_length == 1 - hidden_states = self.transformer.wte.forward( input_ids ) + self.transformer.wpe.forward(position_ids) @@ -469,13 +458,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): # Standardize shape to (batch_size, hidden_size) hidden_states.squeeze_(1) - # Self-attention mask (padding + causal). 
- # TODO: Avoid unsqueeze - attention_mask = self.causal_mask[ - None, key_length - 1 : key_length, : attention_mask.size(-1) - ] * attention_mask.unsqueeze(1) - attention_mask.unsqueeze_(2) - residual = None block: GPTBigCodeBlock for i, (block, layer_past) in enumerate( @@ -486,11 +468,10 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): residual=residual, layer_past=layer_past, attention_mask=attention_mask, - batch_size=batch_size, key_length=key_length, ) hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual) - hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1) + hidden_states = self.lm_head.forward(hidden_states) - return hidden_states, past_key_values + return hidden_states.unsqueeze_(1), past_key_values diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py index eb006fc7..4c224428 100644 --- a/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py @@ -13,12 +13,15 @@ # limitations under the License. """PyTorch GPTBigCode model.""" import math -from typing import Optional, Tuple, Any, Union, List -from enum import IntEnum +from typing import Optional, Tuple, Any, List -import torch -import torch.utils.checkpoint -from torch import nn +from torch import where, addmm, mm, float32, dtype, Tensor, baddbmm, empty, device, bmm, full, ones, finfo, jit +from torch.nn import Linear, Embedding, Module, LayerNorm, ModuleList +from torch.nn.functional import gelu, softmax, embedding + +from dropout_layer_norm import dropout_add_ln_fwd +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import flash_attn_unpadded_func from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging @@ -27,357 +30,307 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( ) -class InferenceRunnerType(IntEnum): - NO_RUNNER = 0 - # Use the inference runner without cuda graphs. - BASE_RUNNER = 1 - # Use cuda graphs in the inference runner. Leave out the attention which has a variable shape. - # This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models. - PARTIAL_GRAPH = 2 - # Turn the whole model into a cuda graph. One graph for each sequence length. - # Note: only useful for small batches and models, graphs take some time to generate, flaky. - # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. 
- FULL_GRAPH = 3 - - -try: - from flash_attn.bert_padding import pad_input, unpad_input - from flash_attn.flash_attn_interface import flash_attn_unpadded_func -except ImportError: - flash_attn_unpadded_func = None - pad_input = None - unpad_input = None - - logger = logging.get_logger(__name__) -@torch.jit.script +@jit.script def upcast_masked_softmax( - x: torch.Tensor, - mask: torch.Tensor, - mask_value: torch.Tensor, - scale: float, - softmax_dtype: torch.dtype, + x: Tensor, mask: Tensor, mask_value: Tensor, scale: float ): input_dtype = x.dtype - x = x.to(softmax_dtype) * scale - x = torch.where(mask, x, mask_value) - x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + x = x.to(float32) * scale + x = where(mask, x, mask_value) + x = softmax(x, dim=-1).to(input_dtype) return x -@torch.jit.script -def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): - input_dtype = x.dtype - x = x.to(softmax_dtype) * scale - x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) - return x +class GPTBigCodeAttention(Module): + mask_value: Tensor - -@torch.jit.script -def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): - x = torch.where(mask, x, mask_value) - x = torch.nn.functional.softmax(x, dim=-1) - return x - - -@torch.profiler.record_function("softmax_function") -def softmax_function( - x: torch.Tensor, - mask: torch.Tensor, - mask_value: torch.Tensor, - scale: float, - softmax_dtype: torch.dtype, - upcast: bool = True, -): - """ - This selects the appropriate (fused) (upcast) (masked) softmax method. Because of the way jit works, each case - needs to be handled through a separate method. The fused kernels remove most of the overhead from masking, casting - and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely - inefficient. - """ - # assert x.size(-1) % 8 == 0 - if upcast: - if mask is None: - return upcast_softmax(x, scale, softmax_dtype) - else: - return upcast_masked_softmax(x, mask, mask_value, scale, softmax_dtype) - else: - if mask is None: - return torch.nn.functional.softmax(x, dim=-1) - else: - return masked_softmax(x, mask, mask_value) - - -class GPTBigCodeAttention(nn.Module): - def __init__( - self, - config, - layer_idx=None, - dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu"), - ): + def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: dtype): super().__init__() - self.mask_value = None self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads self.layer_idx = layer_idx - # KV caching and padding - - self.c_attn = nn.Linear( + self.c_attn = Linear( self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, - device=device, + device="meta", ) - self.c_proj = nn.Linear( - self.embed_dim, self.embed_dim, dtype=dtype, device=device + self.c_proj = Linear( + self.embed_dim, self.embed_dim, dtype=dtype, device="meta" ) - @torch.profiler.record_function("GPTBigCodeAttention._get_mask_value") - def _get_mask_value(self, device, dtype): - # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
- if ( - self.mask_value is None - or self.mask_value.dtype != dtype - or self.mask_value.device != device - ): - self.mask_value = torch.full( - [], torch.finfo(dtype).min, dtype=dtype, device=device +class GPTBigCodeMLP(Module): + # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict) + def __init__(self, config: GPTBigCodeConfig, dtype: dtype): + super().__init__() + embed_dim = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim + self.c_fc = Linear(embed_dim, inner_dim, dtype=dtype, device="meta") + self.c_proj = Linear(inner_dim, embed_dim, dtype=dtype, device="meta") + + +class GPTBigCodeBlock(Module): + def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: dtype): + super().__init__() + self.ln_1 = LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device="meta", + ) + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype) + self.ln_2 = LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device="meta", + ) + self.mlp = GPTBigCodeMLP(config, dtype=dtype) + + def post_load_weights(self, mask_value): + self.attn.mask_value = mask_value + + self.mask_value = mask_value + self.hd=self.attn.head_dim + self.split0=(self.attn.embed_dim, 2*self.hd) + self.split1=(self.hd, self.hd) + + self.aaw=self.attn.c_attn.weight.t_() + self.aab=self.attn.c_attn.bias + self.apw=self.attn.c_proj.weight.t_() + self.apb=self.attn.c_proj.bias + + self.unscale=self.attn.layer_idx + 1 + self.ps=self.hd**-0.5 + self.ds=self.unscale**-1 * self.ps + + self.l1w=self.ln_1.weight + self.l1b=self.ln_1.bias + self.e1=self.ln_1.eps + self.l2w=self.ln_2.weight + self.l2b=self.ln_2.bias + self.e2=self.ln_2.eps + self.mfb=self.mlp.c_fc.bias + self.mfw=self.mlp.c_fc.weight.t_() + self.mfb=self.mlp.c_fc.bias + self.mpw=self.mlp.c_proj.weight.t_() + self.mpb=self.mlp.c_proj.bias + + + def prefill( + self, + hidden_states: Tensor, + residual: Optional[Tensor], + sequence_lengths, + key_length: int, + ) -> Tuple[Tensor, Tensor, Any]: + if residual is None: # First layer + residual = hidden_states + hidden_states, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l1w, + self.l1b, + None, + None, + None, + None, + 0.0, + self.e1, + 1.0, + 0, + None, + False, + False, ) - return self.mask_value - - @torch.profiler.record_function("GPTBigCodeAttention._attn") - def _attn(self, query, key, value, attention_mask): - softmax_dtype = torch.float32 - upcast = query.dtype != softmax_dtype - - unscale = self.layer_idx + 1 if upcast else 1 - scale_factor = unscale**-1 / self.head_dim**0.5 - - # (batch_size, query_length, num_heads * head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key_length = key.size(-2) - - key = key.transpose(-1, -2) - # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) - # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] - attn_shape = (batch_size, query_length, self.num_heads, key_length) - attn_view = (batch_size, query_length * self.num_heads, key_length) - # No copy needed for MQA 2, or when layer_past is provided. - query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) - - attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) - if query.device.type == "cpu": - # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. 
- # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, - # but the fix has not been released as of pytorch version 2.0.0. - attn_weights.zero_() - beta = 1 else: - beta = 0 - attn_weights = torch.baddbmm( - attn_weights, query, key, beta=beta, alpha=scale_factor - ).view(attn_shape) - - attn_weights = softmax_function( - attn_weights, - attention_mask, - None - if attention_mask is None - else self._get_mask_value(attn_weights.device, softmax_dtype), - unscale, - softmax_dtype, - upcast, + hidden_states, residual, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l1w, + self.l1b, + None, + None, + None, + None, + 0.0, + self.e1, + 1.0, + 0, + None, + False, + False, + ) + hidden_shape = hidden_states.shape + query, key_value = addmm(self.aab, hidden_states, self.aaw).split(self.split0, dim=-1) + query = query.view(hidden_shape[0], self.num_heads, self.head_dim) + key, value = ( + key_value.unsqueeze(1) + .expand(hidden_shape[0], self.num_heads, 2 * self.head_dim) + .split(self.split1, dim=-1) ) - attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) - - return attn_output - - @torch.profiler.record_function("GPTBigCodeAttention._attn_flash") - def _attn_flash(self, query, key, value, flash_params): - query_shape = query.shape - attn_shape = query_shape[0], self.num_heads, self.head_dim - query = query.view(attn_shape) - key = key.unsqueeze(1).expand(attn_shape) - value = value.unsqueeze(1).expand(attn_shape) - - sequence_lengths, padding_index, _, max_sequence_length = flash_params # attn_output: (sum_seq_len, num_heads * head_dim) - attn_output = flash_attn_unpadded_func( + hidden_states = flash_attn_unpadded_func( query, key, value, sequence_lengths, sequence_lengths, - max_sequence_length, - max_sequence_length, + key_length, + key_length, 0.0, - softmax_scale=self.head_dim**-0.5, + softmax_scale=self.ps, causal=True, - ).view(query_shape) + ).view(hidden_shape) + hidden_states = addmm(self.apb, hidden_states, self.apw, out=query) - return attn_output + hidden_states, residual, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l2w, + self.l2b, + None, + None, + None, + None, + 0.0, + self.e2, + 1.0, + 0, + None, + False, + False, + ) + # TODO: Find an inplace and/or fused (with addmm) gelu kernel? + hidden_states = addmm(self.mpb, gelu(addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states) + return hidden_states, residual, key_value - @torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches") - def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params): - # Convert to standard KV cache format. 
- if flash_params is not None: - _, padding_index, batch_size, max_sequence_length = flash_params - current_kv_cache = pad_input( - key_value, padding_index, batch_size, max_sequence_length + def decode( + self, + hidden_states: Tensor, + residual: Optional[Tensor], + layer_past: Tensor, + attention_mask: Tensor, + key_length: int, + ) -> Tuple[Tensor, Tensor, Any]: + + batch_size=hidden_states.size(0) + + if residual is None: # First layer + residual = hidden_states + hidden_states, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l1w, + self.l1b, + None, + None, + None, + None, + 0.0, + self.e1, + 1.0, + 0, + None, + False, + False, + ) + else: + hidden_states, residual, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l1w, + self.l1b, + None, + None, + None, + None, + 0.0, + self.e1, + 1.0, + 0, + None, + False, + False, ) - return key_value, (current_kv_cache, max_sequence_length) - current_kv_cache = key_value + query, key_value = addmm(self.aab, hidden_states, self.aaw).split(self.split0, dim=-1) + query_view = query.view(batch_size, self.num_heads, self.head_dim) # Calculate dimensions and recover layer_past - batch_size = current_kv_cache.size(0) - query_length = current_kv_cache.size(-2) - if layer_past is None: - allocated_kv_cache, last_key_length = None, 0 - last_kv_cache = None - key_length = query_length - allocated_key_length = key_length - else: - allocated_kv_cache, last_key_length = layer_past - last_kv_cache = allocated_kv_cache[:, :last_key_length] - key_length = query_length + last_key_length - allocated_key_length = allocated_kv_cache.size(-2) - padded_key_length = attention_mask.size(-1) + allocated_key_length = layer_past.size(-2) - # Re-allocate kv cache and copy last value + # TODO: Allow pre-allocation with size > padded_key_length if padded_key_length > allocated_key_length: - allocated_kv_cache = torch.empty( - [batch_size, padded_key_length, 2 * self.head_dim], - dtype=current_kv_cache.dtype, - device=current_kv_cache.device, + # Re-allocate kv cache and copy last value + allocated_kv_cache = empty( + [batch_size, padded_key_length, 2*self.hd], + dtype=key_value.dtype, + device=key_value.device, ) - if layer_past is not None: - allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache) - if padded_key_length > key_length: - # Nans in `value` can propagate through the matrix multiplication, - # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) - allocated_kv_cache[:, key_length:, self.head_dim :].zero_() + allocated_kv_cache[:, : key_length - 1].copy_( + layer_past[:, : key_length - 1] + ) + # Nans in `value` can propagate through the matrix multiplication, + # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) + allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_() + layer_past = allocated_kv_cache # Copy the new values. - if padded_key_length > allocated_key_length or layer_past is not None: - allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache) - padded_kv_cache = allocated_kv_cache[:, :padded_key_length] - # Use the merged KV cache. - # Not needed when layer_past is None but frees some memory. 
- key_value = padded_kv_cache + layer_past[:, key_length - 1].copy_(key_value) - if allocated_kv_cache is None: - allocated_kv_cache = current_kv_cache - present = allocated_kv_cache, key_length - return key_value, present + key, value = layer_past.split(self.split1, dim=-1) - @torch.profiler.record_function("GPTBigCodeAttention.forward") - def forward( - self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - flash_params: Optional[Tuple] = None, - ) -> Tuple[torch.Tensor, Any]: - query, key_value = self.c_attn(hidden_states).split( - (self.embed_dim, 2 * self.head_dim), dim=-1 + # Assume we always upcast (optimized for fp16/bf16) + # TODO: Upcasting needed for bf16? + hidden_states = baddbmm( + empty( + (batch_size, self.num_heads, padded_key_length), + device=query.device, + dtype=query.dtype, + ), + query_view, + key.transpose(-1, -2), + beta=0, + alpha=self.ds, + ) + hidden_states = upcast_masked_softmax( + hidden_states, attention_mask, self.mask_value, self.unscale ) - # present = (allocated_kv_cache, key_length) - key_value, present = self._merge_kv_caches( - key_value, layer_past, attention_mask, flash_params + # TODO: Write attn output directly into query, avoids both allocation and view. + bmm(hidden_states.squeeze_(1), value, out=query_view) + # TODO: Reuse attn weight tensor for c_proj output? + hidden_states = addmm(self.apb, query, self.apw) + + hidden_states, residual, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l2w, + self.l2b, + None, + None, + None, + None, + 0.0, + self.e2, + 1.0, + 0, + None, + False, + False, ) - - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - - if flash_params is None: - attn_output = self._attn(query, key, value, attention_mask) - else: - attn_output = self._attn_flash(query, key, value, flash_params) - - attn_output = self.c_proj(attn_output) - - return attn_output, present - - -class GPTBigCodeMLP(nn.Module): - def __init__(self, config): - super().__init__() - embed_dim = config.hidden_size - inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim - self.c_fc = nn.Linear(embed_dim, inner_dim) - self.c_proj = nn.Linear(inner_dim, embed_dim) - - @torch.profiler.record_function("GPTBigCodeMLP.forward") - # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward - def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor: - return self.c_proj( - nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh") - ) - - -class GPTBigCodeBlock(nn.Module): - def __init__( - self, - config, - layer_idx=None, - dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu"), - ): - super().__init__() - self.ln_1 = nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_epsilon, - dtype=dtype, - device=device, - ) - self.attn = GPTBigCodeAttention( - config, layer_idx=layer_idx, dtype=dtype, device=device - ) - self.ln_2 = nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_epsilon, - dtype=dtype, - device=device, - ) - self.mlp = GPTBigCodeMLP(config) - - @torch.profiler.record_function("GPTBigCodeBlock.forward") - def forward( - self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - flash_params: Optional[Tuple] = None, - ) -> Tuple[torch.Tensor, Any]: - with torch.profiler.record_function("GPTBigCodeAttention.ln"): - ai = self.ln_1(hidden_states) - attn_output, present = self.attn( 
- ai, - layer_past=layer_past, - attention_mask=attention_mask, - flash_params=flash_params, - ) - with torch.profiler.record_function("GPTBigCodeAttention.residual"): - hidden_states.add_(attn_output) - - with torch.profiler.record_function("GPTBigCodeAttention.dummy"): - pass - with torch.profiler.record_function("GPTBigCodeAttention.ln"): - ai = self.ln_2(hidden_states) - ai = self.mlp(ai) - with torch.profiler.record_function("GPTBigCodeAttention.residual"): - hidden_states.add_(ai) - return hidden_states, present + # TODO: Reuse attn weight tensor for c_fc output? (ok if padded_key_length>=4*head_dim, otherwise need to allocate a bigger one). + # TODO: Find an inplace and/or fused (with addmm) gelu kernel? + hidden_states = addmm(self.mpb, gelu(addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states) + return hidden_states, residual, layer_past class GPTBigCodePreTrainedModel(PreTrainedModel): @@ -396,9 +349,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, GPTBigCodeModel): - module.bias.fill_(True).tril_() - elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)): + if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. @@ -412,215 +363,188 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): ), ) module.c_proj._is_hf_initialized = True - elif isinstance(module, nn.Linear): + elif isinstance(module, Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): + elif isinstance(module, Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class GPTBigCodeModel(GPTBigCodePreTrainedModel): - def __init__( - self, - config, - dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu"), - ): + # TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict) + def __init__(self, config: GPTBigCodeConfig, dtype: dtype): super().__init__(config) - self.wte = nn.Embedding( - config.vocab_size, config.hidden_size, dtype=dtype, device=device + self.wte = Embedding( + config.vocab_size, config.hidden_size, dtype=dtype, device="meta" ) - self.wpe = nn.Embedding( + self.wpe = Embedding( config.max_position_embeddings, config.hidden_size, dtype=dtype, - device=device, + device="meta", ) - self.h = nn.ModuleList( + self.h = ModuleList( [ - GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) + GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) for i in range(config.num_hidden_layers) ] ) - self.ln_f = nn.LayerNorm( + self.ln_f = LayerNorm( config.hidden_size, dtype=dtype, - device=device, + device="meta", eps=config.layer_norm_epsilon, ) - self.inference_runner_type = ( - InferenceRunnerType.NO_RUNNER - ) # 
InferenceRunnerType(config.inference_runner) - - self.flash_attention = True # config.flash_attention - - if self.flash_attention: - if flash_attn_unpadded_func is None: - raise RuntimeError( - "Flash Attention requires `flash_attn` and `einops`. " - "To install, run `pip install flash-attn einops`." - ) - - if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: - self.inference_runner = None - else: - from .inference_runner import GPTBigCodeInferenceRunner - - self.inference_runner = GPTBigCodeInferenceRunner(config, self) - - # Causal mask - self.register_buffer( - "bias", - torch.empty( - (config.max_position_embeddings, config.max_position_embeddings), - dtype=torch.bool, - device=device, - ), - ) - - # @torch.profiler.record_function("GPTBigCodeModel._get_causal_mask") - def _get_causal_mask(self, padding_mask, query_length, key_length): - # Self-attention mask. - attention_mask = self.bias[ - None, key_length - query_length : key_length, :key_length - ] - - if padding_mask is not None: - attention_mask = attention_mask * padding_mask.unsqueeze(1).to( - dtype=torch.bool, device=attention_mask.device - ) - pad = -key_length % 8 - if pad > 0: - attention_mask = torch.nn.functional.pad( - attention_mask, (0, pad), mode="constant", value=False - ) - - # (batch_size, query_length, n_heads, key_length) - return attention_mask.unsqueeze(2) - - # @torch.profiler.record_function("GPTBigCodeModel.forward") - def forward( - self, - *, - input_ids: torch.Tensor, - past_key_values: Optional[Union[List[torch.Tensor], int]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: torch.Tensor, - ) -> Tuple: - if self.inference_runner is not None and past_key_values is not None: - if self.config.validate_runner_input: - assert past_key_values is not None - return self.inference_runner.forward( - input_ids, attention_mask, position_ids, past_key_values - ) - - batch_size, query_length = input_ids.shape - - flash_attention = self.flash_attention and past_key_values is None - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][1] - - hidden_states = self.wte(input_ids) + self.wpe(position_ids) - - # TODO: Unpad earlier (input ids), support unpadded input? - if flash_attention: - ( - hidden_states, - padding_index, - sequence_lengths, - max_sequence_length, - ) = unpad_input(hidden_states, attention_mask) - flash_params = ( - sequence_lengths, - padding_index, - batch_size, - max_sequence_length, - ) - attention_mask = None - else: - key_length = past_length + query_length - # Self-attention mask (padding + causal). 
- attention_mask = self._get_causal_mask( - attention_mask, query_length, key_length - ) - flash_params = None - - presents = [] - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - flash_params=flash_params, - ) - - hidden_states = outputs[0] - presents.append(outputs[1]) - - hidden_states = self.ln_f(hidden_states) - - if flash_attention: - hidden_states = pad_input( - hidden_states, padding_index, batch_size, query_length - ) - - return hidden_states, presents - class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): def __init__( - self, - config, - dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu"), + self, config, dtype: dtype, device: device = device("cuda") ): super().__init__(config) - meta = torch.device("meta") - self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta) - self.lm_head = nn.Linear( - config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta + if device.type != "cuda": + raise NotImplementedError(f"Device {device} not supported") + + self.transformer = GPTBigCodeModel(config, dtype=dtype) + self.lm_head = Linear( + config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta" + ) + self.mask_value = full( + (), finfo(float32).min, dtype=float32, device=device ) self.to_empty(device=device) # Initialize weights and apply final processing + # TODO: Skip? self.post_init() - # @torch.profiler.record_function("GPTBigCodeForCausalLM.forward") - def forward( + def post_load_weights(self): + layer: GPTBigCodeBlock + for layer in self.transformer.h: + layer.post_load_weights(self.mask_value) + self.tw=self.transformer.wte.weight + self.pw=self.transformer.wpe.weight + self.hw=self.lm_head.weight.t_() + self.lw=self.transformer.ln_f.weight + self.lb=self.transformer.ln_f.bias + self.le=self.transformer.ln_f.eps + + def prefill( self, *, - input_ids: torch.Tensor, - past_key_values: Optional[Union[List[torch.Tensor], int]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: torch.Tensor, + input_ids: Tensor, + attention_mask: Tensor = None, + position_ids: Tensor, predict_all_tokens: bool = True, ) -> Tuple: - hidden_states, presents = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, + batch_size, query_length = input_ids.shape + + hidden_states = embedding(input_ids, self.tw).add_(embedding(position_ids, self.pw)) + + # Prefill (flash attn) + # TODO: Unpad earlier (input ids)? + hidden_states, padding_index, sequence_lengths, key_length = unpad_input( + hidden_states, attention_mask ) - # with torch.profiler.record_function("GPTBigCodeForCausalLM.head"): - if not predict_all_tokens: - # We only care about the last token. 
- hidden_states = hidden_states[:, -1:] + assert key_length == query_length - lm_logits = self.lm_head(hidden_states) + residual = None + past_key_values = [] + block: GPTBigCodeBlock + for block in self.transformer.h: + hidden_states, residual, key_value = block.prefill( + hidden_states, + residual, + sequence_lengths, + key_length, + ) + past_key_values.append( + pad_input(key_value, padding_index, batch_size, query_length) + ) - return lm_logits, presents + hidden_states, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.lw, + self.lb, + None, + None, + None, + None, + 0.0, + self.le, + 1.0, + 0, + None, + False, + False, + ) + + # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible. + del residual + + if predict_all_tokens: + hidden_states = pad_input( + mm(hidden_states, self.hw), padding_index, batch_size, query_length + ) + else: + # TODO: Index directly using cu_seqlens instead + hidden_states = mm(pad_input( + hidden_states, padding_index, batch_size, query_length + )[:, -1], self.hw).unsqueeze_(1) + + return hidden_states, past_key_values + + def decode( + self, + *, + input_ids: Tensor, + past_key_values: List[Tensor], + attention_mask: [Tensor], + position_ids: Tensor, + key_length: int, + ) -> Tuple: + hidden_states = embedding(input_ids, self.tw).add_(embedding(position_ids, self.pw)) + + residual = None + block: GPTBigCodeBlock + for i, (block, layer_past) in enumerate( + zip(self.transformer.h, past_key_values) + ): + hidden_states, residual, past_key_values[i] = block.decode( + hidden_states, + residual, + layer_past, + attention_mask, + key_length, + ) + hidden_states, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.lw, + self.lb, + None, + None, + None, + None, + 0.0, + self.le, + 1.0, + 0, + None, + False, + False, + ) + # TODO: Reuse residual? + return mm(hidden_states, self.hw).unsqueeze_(1), past_key_values diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode4_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode4_modeling.py new file mode 100644 index 00000000..cd000846 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode4_modeling.py @@ -0,0 +1,565 @@ +# coding=utf-8 +# Copyright 2023 The Bigcode team and HuggingFace Inc. team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch GPTBigCode model.""" +import math +from typing import Optional, Tuple, Any, List + +import torch +import torch.utils.checkpoint +from torch import nn + +from dropout_layer_norm import dropout_add_ln_fwd +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( + GPTBigCodeConfig, +) + + +class FastLayerNorm(nn.LayerNorm): + # TODO: Validate dimension + def forward(self, hidden_states, residual=None): + out, residual, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + return out, residual + + +class FastLinear(nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.addmm(self.bias, input, self.weight) + + +class FastLinearNoBias(nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.mm(input, self.weight) + + +logger = logging.get_logger(__name__) + + +@torch.jit.script +def upcast_masked_softmax( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float +): + input_dtype = x.dtype + x = x.to(torch.float32) * scale + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1) + return x + + +class GPTBigCodeAttention(nn.Module): + mask_value: torch.Tensor + + def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.layer_idx = layer_idx + + self.c_attn = FastLinear( + self.embed_dim, + self.embed_dim + 2 * self.head_dim, + dtype=dtype, + device="meta", + ) + self.c_proj = FastLinear( + self.embed_dim, self.embed_dim, dtype=dtype, device="meta" + ) + + @torch.profiler.record_function("GPTBigCodeAttention.prefill") + def prefill( + self, + hidden_states: torch.Tensor, + sequence_lengths, + key_length: int, + ) -> Tuple[torch.Tensor, Any]: + hidden_shape = hidden_states.shape + query, key_value = self.c_attn.forward(hidden_states).split( + (self.embed_dim, 2 * self.head_dim), dim=-1 + ) + query = query.view(hidden_shape[0], self.num_heads, self.head_dim) + key, value = ( + key_value.unsqueeze(1) + .expand(hidden_shape[0], self.num_heads, 2 * self.head_dim) + .split((self.head_dim, self.head_dim), dim=-1) + ) + + # attn_output: (sum_seq_len, num_heads * head_dim) + hidden_states = flash_attn_unpadded_func( + query, + key, + value, + sequence_lengths, + sequence_lengths, + key_length, + key_length, + 0.0, + softmax_scale=self.head_dim**-0.5, + causal=True, + ).view(hidden_shape) + + hidden_states = self.c_proj.forward(hidden_states) + + return hidden_states, key_value + + @torch.profiler.record_function("GPTBigCodeAttention.decode") + def decode( + self, + hidden_states: torch.Tensor, + layer_past: torch.Tensor, + attention_mask: torch.Tensor, + batch_size: int, + key_length: int, + ) -> Tuple[torch.Tensor, Any]: + query, key_value = self.c_attn.forward(hidden_states).split( 
+ (self.embed_dim, 2 * self.head_dim), dim=-1 + ) + + # Calculate dimensions and recover layer_past + padded_key_length = attention_mask.size(-1) + allocated_key_length = layer_past.size(-2) + + # TODO: Allow pre-allocation with size > padded_key_length + if padded_key_length > allocated_key_length: + # Re-allocate kv cache and copy last value + allocated_kv_cache = torch.empty( + [batch_size, padded_key_length, 2 * self.head_dim], + dtype=key_value.dtype, + device=key_value.device, + ) + allocated_kv_cache[:, : key_length - 1].copy_( + layer_past[:, : key_length - 1] + ) + # Nans in `value` can propagate through the matrix multiplication, + # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) + allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_() + layer_past = allocated_kv_cache + + # Copy the new values. + layer_past[:, key_length - 1].copy_(key_value) + + key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1) + + # TODO: Upcasting needed for bf16? + upcast = query.dtype != torch.float32 + unscale = self.layer_idx + 1 if upcast else 1 + scale_factor = unscale**-1 / self.head_dim**0.5 + + # TODO: No need to unsqueeze? + hidden_states = torch.baddbmm( + torch.empty( + (batch_size, self.num_heads, padded_key_length), + device=query.device, + dtype=query.dtype, + ), + query.view(batch_size, self.num_heads, self.head_dim), + key.transpose(-1, -2), + beta=0, + alpha=scale_factor, + ).unsqueeze_(1) + + if upcast: + hidden_states = upcast_masked_softmax( + hidden_states, attention_mask, self.mask_value, unscale + ) + else: + hidden_states = masked_softmax( + hidden_states, attention_mask, self.mask_value + ) + + hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape) + + hidden_states = self.c_proj.forward(hidden_states) + + return hidden_states, layer_past + + +class GPTBigCodeMLP(nn.Module): + # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict) + def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype): + super().__init__() + embed_dim = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim + self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta") + self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta") + + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype): + super().__init__() + self.ln_1 = FastLayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device="meta", + ) + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype) + self.ln_2 = FastLayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device="meta", + ) + self.mlp = GPTBigCodeMLP(config, dtype=dtype) + + def post_load_weights(self, mask_value): + self.attn.mask_value = mask_value + self.attn.c_attn.weight.t_() + self.attn.c_proj.weight.t_() + + self.l1w=self.ln_1.weight + self.l1b=self.ln_1.bias + self.e1=self.ln_1.eps + self.l2w=self.ln_2.weight + self.l2b=self.ln_2.bias + self.e2=self.ln_2.eps + self.mfb=self.mlp.c_fc.bias + self.mfw=self.mlp.c_fc.weight.t_() + self.mfb=self.mlp.c_fc.bias + self.mpw=self.mlp.c_proj.weight.t_() + self.mpb=self.mlp.c_proj.bias + + + @torch.profiler.record_function("GPTBigCodeBlock.prefill") + def prefill( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + sequence_lengths, + key_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor, Any]: + 
hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual) + hidden_states, present = self.attn.prefill( + hidden_states, + sequence_lengths, + key_length, + ) + hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual) + hidden_states = self.mlp.c_proj.forward( + nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh") + ) + return hidden_states, residual, present + + @torch.profiler.record_function("GPTBigCodeBlock.decode") + def decode( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + layer_past: torch.Tensor, + attention_mask: torch.Tensor, + batch_size: int, + key_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor, Any]: + if residual is None: + residual = hidden_states + hidden_states, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l1w, + self.l1b, + None, + None, + None, + None, + 0.0, + self.e1, + 1.0, + 0, + None, + False, + False, + ) + else: + hidden_states, residual, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l1w, + self.l1b, + None, + None, + None, + None, + 0.0, + self.e1, + 1.0, + 0, + None, + False, + False, + ) + + hidden_states, present = self.attn.decode( + hidden_states, + layer_past, + attention_mask, + batch_size, + key_length, + ) + hidden_states, residual, *_ = dropout_add_ln_fwd( + hidden_states, + residual, + self.l2w, + self.l2b, + None, + None, + None, + None, + 0.0, + self.e2, + 1.0, + 0, + None, + False, + False, + ) + hidden_states = torch.addmm(self.mpb, nn.functional.gelu(torch.addmm(self.mfb, hidden_states, self.mfw), approximate="tanh"), self.mpw, out=hidden_states) + return hidden_states, residual, present + + +class GPTBigCodePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTBigCodeConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = False + _no_split_modules = ["GPTBigCodeBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + module.c_proj.weight.data.normal_( + mean=0.0, + std=( + self.config.initializer_range / math.sqrt(2 * self.config.n_layer) + ), + ) + module.c_proj._is_hf_initialized = True + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class GPTBigCodeModel(GPTBigCodePreTrainedModel): + # TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict) + def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype): + super().__init__(config) + + self.wte = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=dtype, device="meta" + ) + self.wpe = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + dtype=dtype, + device="meta", + ) + + self.h = nn.ModuleList( + [ + GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = FastLayerNorm( + config.hidden_size, + dtype=dtype, + device="meta", + eps=config.layer_norm_epsilon, + ) + + +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): + def __init__( + self, config, dtype: torch.dtype, device: torch.device = torch.device("cuda") + ): + super().__init__(config) + if device.type != "cuda": + raise NotImplementedError(f"Device {device} not supported") + + self.transformer = GPTBigCodeModel(config, dtype=dtype) + self.lm_head = FastLinearNoBias( + config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta" + ) + + # Causal mask + self.causal_mask = torch.ones( + (config.max_position_embeddings, config.max_position_embeddings), + dtype=torch.bool, + device=device, + ).tril_() + self.mask_value = torch.full( + (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device + ) + + self.to_empty(device=device) + + # Initialize weights and apply final processing + # TODO: Skip? + self.post_init() + + def prefill( + self, + *, + input_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + position_ids: torch.Tensor, + predict_all_tokens: bool = True, + ) -> Tuple: + batch_size, query_length = input_ids.shape + + hidden_states = self.transformer.wte.forward( + input_ids + ) + self.transformer.wpe.forward(position_ids) + + # Prefill (flash attn) + # TODO: Unpad earlier (input ids)? 
+ hidden_states, padding_index, sequence_lengths, key_length = unpad_input( + hidden_states, attention_mask + ) + assert key_length == query_length + + residual = None + past_key_values = [] + block: GPTBigCodeBlock + for block in self.transformer.h: + hidden_states, residual, key_value = block.prefill( + hidden_states, + residual=residual, + sequence_lengths=sequence_lengths, + key_length=query_length, + ) + past_key_values.append( + pad_input(key_value, padding_index, batch_size, query_length) + ) + + hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual) + + # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible. + del residual + + if predict_all_tokens: + hidden_states = self.lm_head.forward(hidden_states) + hidden_states = pad_input( + hidden_states, padding_index, batch_size, query_length + ) + else: + # TODO: Index directly instead + hidden_states = pad_input( + hidden_states, padding_index, batch_size, query_length + )[:, -1] + hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1) + + return hidden_states, past_key_values + + def post_load_weights(self): + layer: GPTBigCodeBlock + for layer in self.transformer.h: + layer.post_load_weights(self.mask_value) + self.lm_head.weight.t_() + + def decode( + self, + *, + input_ids: torch.Tensor, + past_key_values: List[torch.Tensor], + attention_mask: [torch.Tensor], + position_ids: torch.Tensor, + key_length: int, + ) -> Tuple: + batch_size, query_length = input_ids.shape + assert query_length == 1 + + hidden_states = self.transformer.wte.forward( + input_ids + ) + self.transformer.wpe.forward(position_ids) + + # Standardize shape to (batch_size, hidden_size) + hidden_states.squeeze_(1) + + # Self-attention mask (padding + causal). + # TODO: Avoid unsqueeze + attention_mask = self.causal_mask[ + None, key_length - 1 : key_length, : attention_mask.size(-1) + ] * attention_mask.unsqueeze(1) + attention_mask.unsqueeze_(2) + + residual = None + block: GPTBigCodeBlock + for i, (block, layer_past) in enumerate( + zip(self.transformer.h, past_key_values) + ): + hidden_states, residual, past_key_values[i] = block.decode( + hidden_states, + residual=residual, + layer_past=layer_past, + attention_mask=attention_mask, + batch_size=batch_size, + key_length=key_length, + ) + + hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual) + hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1) + + return hidden_states, past_key_values diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 68086b8c..b87d01e0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -451,11 +451,17 @@ class FlashCausalLM(Model): def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, cache_dtype: Optional[torch.dtype]): diff = max_input_length - max(batch.input_lengths) + fill_value=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id for i in range(len(batch)): batch.input_lengths[i] += diff - batch.prefix_offsets[i] = 0 - batch.read_offsets[i] = 0 - batch.all_input_ids[i] = batch.all_input_ids[i] + [self.tokenizer.pad_token_id] * diff if diff >= 0 else batch.all_input_ids[i][:diff] + for _ in range(diff): + batch.all_input_ids[i].append(fill_value) + # For some reason just resetting the offsets makes things way slower. 
+ _, batch.prefix_offsets[i], batch.read_offsets[i] = self.decode_token( + batch.all_input_ids[i], + batch.prefix_offsets[i], + batch.read_offsets[i], + ) # TODO: Bug!?! batch.stopping_criterias[i].current_tokens += diff @@ -473,19 +479,19 @@ class FlashCausalLM(Model): batch.past_key_values=None else: assert len(batch.all_input_ids_tensor)>0, "Must run prefill first" - batch.input_ids.fill_(self.tokenizer.pad_token_id) + batch.input_ids.fill_(fill_value) batch.position_ids += diff batch.cu_seqlens += diff * batch.cu_seqlens_q # TODO: Bug!?! - batch.max_seqlen += batch.max_seqlen + diff*len(batch) + batch.max_seqlen += diff*len(batch) for i in range(len(batch)): - batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(self.tokenizer.pad_token_id) + batch.all_input_ids_tensor[i][batch.input_lengths[i]-diff:batch.input_lengths[i]].fill_(fill_value) batch.past_key_values = batch.past_key_values = torch.randn( ( batch.past_key_values.shape[0], - batch.past_key_values.shape[1] + len(batch.requests), + batch.past_key_values.shape[1] + diff*len(batch.requests), *batch.past_key_values.shape[2:], ), device=batch.past_key_values.device, dtype= cache_dtype ) diff --git a/server/text_generation_server/models/gpt_bigcode2.py b/server/text_generation_server/models/gpt_bigcode2.py index 02670eac..ad50dcb8 100644 --- a/server/text_generation_server/models/gpt_bigcode2.py +++ b/server/text_generation_server/models/gpt_bigcode2.py @@ -13,8 +13,15 @@ from text_generation_server.models.vectorized_causal_lm import ( VectorizedCausalLMBatch, ) from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import ( - GPTBigCodeForCausalLM, + GPTBigCodeForCausalLM as GPTBigCode2ForCausalLM, ) +from text_generation_server.models.custom_modeling.gpt_bigcode3_modeling import ( + GPTBigCodeForCausalLM as GPTBigCode3ForCausalLM, +) +from text_generation_server.models.custom_modeling.gpt_bigcode4_modeling import ( + GPTBigCodeForCausalLM as GPTBigCode4ForCausalLM, +) +from transformers.modeling_utils import PreTrainedModel tracer = trace.get_tracer(__name__) @@ -96,8 +103,9 @@ class Bigcode2Batch(VectorizedCausalLMBatch): return len(self.requests) -class Bigcode2CausalLM(VectorizedCausalLM): - model: GPTBigCodeForCausalLM +class Bigcode2CausalLMBase(VectorizedCausalLM): + #model: GPTBigCode2ForCausalLM + _model_class:Type[PreTrainedModel] def __init__( self, @@ -118,7 +126,7 @@ class Bigcode2CausalLM(VectorizedCausalLM): tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) - model = GPTBigCodeForCausalLM.from_pretrained( + model = self._model_class.from_pretrained( model_id, revision=revision, torch_dtype=dtype, @@ -163,18 +171,18 @@ class Bigcode2CausalLM(VectorizedCausalLM): padded_key_length = ( key_length + -key_length % batch.pad_key_length_to_multiple ) - input_ids = batch.input_ids[:, key_length - 1 : key_length] + input_ids = batch.input_ids[:, key_length - 1] # Model Forward logits, batch.past_key_values = self.model.decode( input_ids=input_ids, - attention_mask=batch.attention_mask[:, :padded_key_length], - position_ids=batch.position_ids[:, key_length - 1 : key_length], + attention_mask=batch.attention_mask[:, None, :padded_key_length], + position_ids=batch.position_ids[:, key_length - 1], past_key_values=batch.past_key_values, key_length=key_length, ) next_token_ids, logprobs = batch.next_token_chooser( - input_ids, logits, batch.details + input_ids.unsqueeze(1), logits.unsqueeze(1), batch.details ) # 
Update batch # TODO: Why do we need all input ids? @@ -201,3 +209,14 @@ class Bigcode2CausalLM(VectorizedCausalLM): ) for _ in range(self.model.config.n_layer) ] + +class Bigcode2CausalLM(Bigcode2CausalLMBase): + _model_class=GPTBigCode2ForCausalLM + + +class Bigcode3CausalLM(Bigcode2CausalLMBase): + _model_class=GPTBigCode3ForCausalLM + + +class Bigcode4CausalLM(Bigcode2CausalLMBase): + _model_class = GPTBigCode4ForCausalLM
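
The three serving classes above share one call pattern against the patched models: a single `prefill` call over the prompt (flash attention on unpadded tokens, returning one KV-cache tensor per layer), then repeated `decode` calls with one new token per sequence and a key length padded up to a multiple of `pad_key_length_to_multiple` so the fused masked-softmax kernels see a key dimension that is a multiple of 8. The sketch below mirrors the call pattern in the patched `Bigcode2CausalLMBase.generate_token` (1-D `input_ids`/`position_ids`, `attention_mask[:, None, :padded_key_length]`), but the greedy loop, the `pad_to` argument and the mask/position bookkeeping are illustrative assumptions, not code from this patch.

import torch
import torch.nn.functional as F

@torch.inference_mode()
def greedy_generate(model, input_ids, attention_mask, max_new_tokens, pad_to=8):
    # Prefill: packed (unpadded) flash attention over the prompt; returns the
    # last-token logits (batch, 1, vocab) and one KV-cache tensor per layer.
    position_ids = (attention_mask.long().cumsum(-1) - 1).clamp_(min=0)
    logits, past_key_values = model.prefill(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        predict_all_tokens=False,
    )
    prompt_length = input_ids.size(1)
    key_length = prompt_length
    generated = []
    for _ in range(max_new_tokens):
        next_ids = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)
        generated.append(next_ids)
        key_length += 1
        attention_mask = F.pad(attention_mask, (0, 1), value=1)
        # Pad the visible key length up to a multiple of `pad_to`; the extra
        # positions are masked out (and zeroed in the reallocated KV cache).
        padded_key_length = key_length + -key_length % pad_to
        padded_mask = F.pad(attention_mask, (0, padded_key_length - key_length), value=0)
        logits, past_key_values = model.decode(
            input_ids=next_ids.squeeze(1),                  # (batch,)
            past_key_values=past_key_values,
            attention_mask=padded_mask[:, None, :].bool(),  # (batch, 1, padded_key_length)
            position_ids=position_ids[:, -1] + (key_length - prompt_length),
            key_length=key_length,
        )
    return torch.cat(generated, dim=1)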