diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py index 98beec92..6d20001f 100644 --- a/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode2_modeling.py @@ -13,14 +13,13 @@ # limitations under the License. """PyTorch GPTBigCode model.""" import math -from typing import Optional, Tuple, Any, Union, List -from enum import IntEnum +from typing import Optional, Tuple, Any, List import torch import torch.utils.checkpoint from torch import nn -from dropout_layer_norm import dropout_layer_norm +from dropout_layer_norm import dropout_add_ln_fwd from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import flash_attn_unpadded_func @@ -30,10 +29,11 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( GPTBigCodeConfig, ) + class FastLayerNorm(nn.LayerNorm): # TODO: Validate dimension def forward(self, hidden_states, residual=None): - return dropout_layer_norm.dropout_add_ln_fwd( + out, residual, *_ = dropout_add_ln_fwd( hidden_states, residual, self.weight, @@ -50,18 +50,24 @@ class FastLayerNorm(nn.LayerNorm): False, False, ) + if residual is None: + residual = hidden_states + return out, residual class FastLinear(nn.Linear): def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.addmm(self.bias, input, self.weight) + class FastLinearNoBias(nn.Linear): def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.mm(input, self.weight) + logger = logging.get_logger(__name__) + @torch.jit.script def upcast_masked_softmax( x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float @@ -72,38 +78,53 @@ def upcast_masked_softmax( x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) return x + @torch.jit.script def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): x = torch.where(mask, x, mask_value) x = torch.nn.functional.softmax(x, dim=-1) return x + class GPTBigCodeAttention(nn.Module): - def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype): + mask_value: torch.Tensor + + def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype): super().__init__() self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads self.layer_idx = layer_idx - # Note: Does not support module dtype conversion. 
- self.register_buffer("mask_value", torch.empty((), dtype=torch.float32, device="meta")) - self.c_attn = FastLinear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device="meta") - self.c_proj = FastLinear(self.embed_dim, self.embed_dim, dtype=dtype, device="meta") + self.c_attn = FastLinear( + self.embed_dim, + self.embed_dim + 2 * self.head_dim, + dtype=dtype, + device="meta", + ) + self.c_proj = FastLinear( + self.embed_dim, self.embed_dim, dtype=dtype, device="meta" + ) def prefill( self, hidden_states: torch.Tensor, sequence_lengths, - key_length:int, + key_length: int, ) -> Tuple[torch.Tensor, Any]: hidden_shape = hidden_states.shape - query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) + query, key_value = self.c_attn.forward(hidden_states).split( + (self.embed_dim, 2 * self.head_dim), dim=-1 + ) query = query.view(hidden_shape[0], self.num_heads, self.head_dim) - key, value = key_value.unsqueeze(1).expand(hidden_shape[0], self.num_heads, 2*self.head_dim).split((self.head_dim, self.head_dim), dim=-1) + key, value = ( + key_value.unsqueeze(1) + .expand(hidden_shape[0], self.num_heads, 2 * self.head_dim) + .split((self.head_dim, self.head_dim), dim=-1) + ) # attn_output: (sum_seq_len, num_heads * head_dim) - attn_output = flash_attn_unpadded_func( + hidden_states = flash_attn_unpadded_func( query, key, value, @@ -116,23 +137,25 @@ class GPTBigCodeAttention(nn.Module): causal=True, ).view(hidden_shape) - attn_output = self.c_proj.forward(attn_output) + hidden_states = self.c_proj.forward(hidden_states) - return attn_output, key_value + return hidden_states, key_value def decode( self, hidden_states: torch.Tensor, layer_past: torch.Tensor, attention_mask: torch.Tensor, - batch_size:int, - key_length:int, + batch_size: int, + key_length: int, ) -> Tuple[torch.Tensor, Any]: - query, key_value = self.c_attn.forward(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) + query, key_value = self.c_attn.forward(hidden_states).split( + (self.embed_dim, 2 * self.head_dim), dim=-1 + ) # Calculate dimensions and recover layer_past padded_key_length = attention_mask.size(-1) - allocated_key_length=layer_past.size(-2) + allocated_key_length = layer_past.size(-2) # TODO: Allow pre-allocation with size > padded_key_length if padded_key_length > allocated_key_length: @@ -142,37 +165,45 @@ class GPTBigCodeAttention(nn.Module): dtype=key_value.dtype, device=key_value.device, ) - allocated_kv_cache[:, :key_length-1].copy_(layer_past[:, :key_length-1]) + allocated_kv_cache[:, : key_length - 1].copy_( + layer_past[:, : key_length - 1] + ) # Nans in `value` can propagate through the matrix multiplication, # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) allocated_kv_cache[:, allocated_key_length:, self.head_dim :].zero_() - layer_past=allocated_kv_cache + layer_past = allocated_kv_cache # Copy the new values. - layer_past[:, key_length-1:key_length].copy_(key_value) + layer_past[:, key_length - 1].copy_(key_value) key, value = layer_past.split((self.head_dim, self.head_dim), dim=-1) # TODO: Upcasting needed for bf16? upcast = query.dtype != torch.float32 unscale = self.layer_idx + 1 if upcast else 1 - scale_factor = unscale ** -1 / self.head_dim ** 0.5 + scale_factor = unscale**-1 / self.head_dim**0.5 + # TODO: No need to unsqueeze? 
hidden_states = torch.baddbmm( - torch.empty((batch_size, self.num_heads, padded_key_length), device=query.device, dtype=query.dtype), + torch.empty( + (batch_size, self.num_heads, padded_key_length), + device=query.device, + dtype=query.dtype, + ), query.view(batch_size, self.num_heads, self.head_dim), key.transpose(-1, -2), beta=0, - alpha=scale_factor + alpha=scale_factor, ).unsqueeze_(1) - if self.mask_value is None or self.mask_value.device != hidden_states.device: - self.mask_value = torch.full([], torch.finfo(torch.float32).min, dtype=torch.float32, device=hidden_states.device) - if upcast: - hidden_states = upcast_masked_softmax(hidden_states, attention_mask, self.mask_value, unscale) + hidden_states = upcast_masked_softmax( + hidden_states, attention_mask, self.mask_value, unscale + ) else: - hidden_states = masked_softmax(hidden_states, attention_mask, self.mask_value) + hidden_states = masked_softmax( + hidden_states, attention_mask, self.mask_value + ) hidden_states = torch.bmm(hidden_states.squeeze_(1), value).view(query.shape) @@ -180,21 +211,33 @@ class GPTBigCodeAttention(nn.Module): return hidden_states, layer_past + class GPTBigCodeMLP(nn.Module): # TODO: Merge into GPTBigCodeBlock (needs renaming in state dict) - def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype): + def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype): super().__init__() embed_dim = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * embed_dim self.c_fc = FastLinear(embed_dim, inner_dim, dtype=dtype, device="meta") self.c_proj = FastLinear(inner_dim, embed_dim, dtype=dtype, device="meta") + class GPTBigCodeBlock(nn.Module): - def __init__(self, config:GPTBigCodeConfig, layer_idx:int, dtype:torch.dtype): + def __init__(self, config: GPTBigCodeConfig, layer_idx: int, dtype: torch.dtype): super().__init__() - self.ln_1 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta") + self.ln_1 = FastLayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device="meta", + ) self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype) - self.ln_2 = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device="meta") + self.ln_2 = FastLayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device="meta", + ) self.mlp = GPTBigCodeMLP(config, dtype=dtype) def prefill( @@ -202,7 +245,7 @@ class GPTBigCodeBlock(nn.Module): hidden_states: torch.Tensor, residual: Optional[torch.Tensor], sequence_lengths, - key_length:int, + key_length: int, ) -> Tuple[torch.Tensor, torch.Tensor, Any]: hidden_states, residual, *_ = self.ln_1.forward(hidden_states, residual) hidden_states, present = self.attn.prefill( @@ -211,7 +254,9 @@ class GPTBigCodeBlock(nn.Module): key_length=key_length, ) hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual) - hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")) + hidden_states = self.mlp.c_proj.forward( + nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh") + ) return hidden_states, residual, present def decode( @@ -220,8 +265,8 @@ class GPTBigCodeBlock(nn.Module): residual: Optional[torch.Tensor], layer_past: torch.Tensor, attention_mask: torch.Tensor, - batch_size:int, - key_length:int, + batch_size: int, + key_length: int, ) -> Tuple[torch.Tensor, torch.Tensor, Any]: hidden_states, residual, *_ = 
self.ln_1.forward(hidden_states, residual) hidden_states, present = self.attn.decode( @@ -232,7 +277,9 @@ class GPTBigCodeBlock(nn.Module): key_length=key_length, ) hidden_states, residual, *_ = self.ln_2.forward(hidden_states, residual) - hidden_states = self.mlp.c_proj.forward(nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh")) + hidden_states = self.mlp.c_proj.forward( + nn.functional.gelu(self.mlp.c_fc.forward(hidden_states), approximate="tanh") + ) return hidden_states, residual, present @@ -252,11 +299,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, GPTBigCodeModel): - module.bias.fill_(True).tril_() - elif isinstance(module, (GPTBigCodeBlock, GPTBigCodeAttention)): - if isinstance(module, GPTBigCodeAttention): - module.mask_value.fill_(torch.finfo(torch.float32).min) + if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. @@ -264,7 +307,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py module.c_proj.weight.data.normal_( - mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + mean=0.0, + std=( + self.config.initializer_range / math.sqrt(2 * self.config.n_layer) + ), ) module.c_proj._is_hf_initialized = True elif isinstance(module, nn.Linear): @@ -284,62 +330,86 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): class GPTBigCodeModel(GPTBigCodePreTrainedModel): # TODO: Merge into GPTBigCodeForCausalLM (needs renaming in state dict) - def __init__(self, config:GPTBigCodeConfig, dtype:torch.dtype): + def __init__(self, config: GPTBigCodeConfig, dtype: torch.dtype): super().__init__(config) - self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device="meta") - self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device="meta") - - self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) for i in range(config.num_hidden_layers)]) - self.ln_f = FastLayerNorm(config.hidden_size, dtype=dtype, device="meta", eps=config.layer_norm_epsilon) - - # Causal mask - self.register_buffer( - "causal_mask", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device="meta") + self.wte = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=dtype, device="meta" + ) + self.wpe = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + dtype=dtype, + device="meta", ) -class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): - pad_key_length_to_multiple=8 + self.h = nn.ModuleList( + [ + GPTBigCodeBlock(config, layer_idx=i, dtype=dtype) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = FastLayerNorm( + config.hidden_size, + dtype=dtype, + device="meta", + eps=config.layer_norm_epsilon, + ) - def __init__(self, config, dtype:torch.dtype, device:torch.device=torch.device("cuda")): + +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): + def __init__( + self, config, dtype: torch.dtype, device: torch.device = torch.device("cuda") + ): super().__init__(config) - if device.type!="cuda": + if 
device.type != "cuda": raise NotImplementedError(f"Device {device} not supported") self.transformer = GPTBigCodeModel(config, dtype=dtype) - self.lm_head = FastLinearNoBias(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta") + self.lm_head = FastLinearNoBias( + config.n_embd, config.vocab_size, bias=False, dtype=dtype, device="meta" + ) + + # Causal mask + self.causal_mask = torch.ones( + (config.max_position_embeddings, config.max_position_embeddings), + dtype=torch.bool, + device=device, + ).tril_() + self.mask_value = torch.full( + (), torch.finfo(torch.float32).min, dtype=torch.float32, device=device + ) self.to_empty(device=device) - self._apply=self._apply_not_allowed - # Initialize weights and apply final processing + # TODO: Skip? self.post_init() - def _apply_not_allowed(self): - # Dtype or device conversion would break the model. - raise NotImplementedError("Device or dtype conversion not supported!") - def prefill( self, *, input_ids: torch.Tensor, attention_mask: torch.Tensor = None, position_ids: torch.Tensor, - predict_all_tokens: bool=True, + predict_all_tokens: bool = True, ) -> Tuple: batch_size, query_length = input_ids.shape - hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids) + hidden_states = self.transformer.wte.forward( + input_ids + ) + self.transformer.wpe.forward(position_ids) # Prefill (flash attn) # TODO: Unpad earlier (input ids)? - hidden_states, padding_index, sequence_lengths, key_length = unpad_input(hidden_states, attention_mask) - assert key_length==query_length + hidden_states, padding_index, sequence_lengths, key_length = unpad_input( + hidden_states, attention_mask + ) + assert key_length == query_length residual = None past_key_values = [] - block:GPTBigCodeBlock + block: GPTBigCodeBlock for block in self.transformer.h: hidden_states, residual, key_value = block.prefill( hidden_states, @@ -347,23 +417,39 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): sequence_lengths=sequence_lengths, key_length=query_length, ) - past_key_values.append(pad_input(key_value, padding_index, batch_size, query_length)) + past_key_values.append( + pad_input(key_value, padding_index, batch_size, query_length) + ) - hidden_states = self.transformer.ln_f.forward(hidden_states, residual) + hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual) # Next bit is the memory bottleneck with predict_all_tokens so we free as much memory as possible. 
del residual if predict_all_tokens: hidden_states = self.lm_head.forward(hidden_states) - hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) + hidden_states = pad_input( + hidden_states, padding_index, batch_size, query_length + ) else: # TODO: Index directly instead - hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)[:, -1] + hidden_states = pad_input( + hidden_states, padding_index, batch_size, query_length + )[:, -1] hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1) return hidden_states, past_key_values + def post_load_weights(self): + layer: GPTBigCodeBlock + for layer in self.transformer.h: + layer.attn.mask_value = self.mask_value + layer.attn.c_attn.weight.t_() + layer.attn.c_proj.weight.t_() + layer.mlp.c_fc.weight.t_() + layer.mlp.c_proj.weight.t_() + self.lm_head.weight.t_() + def decode( self, *, @@ -371,26 +457,30 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): past_key_values: List[torch.Tensor], attention_mask: [torch.Tensor], position_ids: torch.Tensor, - key_length:int, + key_length: int, ) -> Tuple: - batch_size, query_length = input_ids.shape assert query_length == 1 - hidden_states = self.transformer.wte.forward(input_ids) + self.transformer.wpe.forward(position_ids) + hidden_states = self.transformer.wte.forward( + input_ids + ) + self.transformer.wpe.forward(position_ids) # Standardize shape to (batch_size, hidden_size) hidden_states.squeeze_(1) # Self-attention mask (padding + causal). # TODO: Avoid unsqueeze - attention_mask = self.transformer.causal_mask[None, key_length - 1: key_length, - :key_length] * attention_mask.unsqueeze(1) + attention_mask = self.causal_mask[ + None, key_length - 1 : key_length, : attention_mask.size(-1) + ] * attention_mask.unsqueeze(1) attention_mask.unsqueeze_(2) residual = None - block:GPTBigCodeBlock - for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)): + block: GPTBigCodeBlock + for i, (block, layer_past) in enumerate( + zip(self.transformer.h, past_key_values) + ): hidden_states, residual, past_key_values[i] = block.decode( hidden_states, residual=residual, @@ -400,7 +490,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): key_length=key_length, ) - hidden_states = self.transformer.ln_f.forward(hidden_states, residual) + hidden_states, *_ = self.transformer.ln_f.forward(hidden_states, residual) hidden_states = self.lm_head.forward(hidden_states).unsqueeze_(1) return hidden_states, past_key_values diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py index 63350843..eb006fc7 100644 --- a/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode3_modeling.py @@ -26,6 +26,7 @@ from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( GPTBigCodeConfig, ) + class InferenceRunnerType(IntEnum): NO_RUNNER = 0 # Use the inference runner without cuda graphs. @@ -38,6 +39,7 @@ class InferenceRunnerType(IntEnum): # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. 
FULL_GRAPH = 3 + try: from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import flash_attn_unpadded_func @@ -52,7 +54,11 @@ logger = logging.get_logger(__name__) @torch.jit.script def upcast_masked_softmax( - x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype + x: torch.Tensor, + mask: torch.Tensor, + mask_value: torch.Tensor, + scale: float, + softmax_dtype: torch.dtype, ): input_dtype = x.dtype x = x.to(softmax_dtype) * scale @@ -91,7 +97,7 @@ def softmax_function( and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely inefficient. """ - #assert x.size(-1) % 8 == 0 + # assert x.size(-1) % 8 == 0 if upcast: if mask is None: return upcast_softmax(x, scale, softmax_dtype) @@ -105,7 +111,13 @@ def softmax_function( class GPTBigCodeAttention(nn.Module): - def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + layer_idx=None, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__() self.mask_value = None self.embed_dim = config.hidden_size @@ -115,14 +127,27 @@ class GPTBigCodeAttention(nn.Module): # KV caching and padding - self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device) - self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device) + self.c_attn = nn.Linear( + self.embed_dim, + self.embed_dim + 2 * self.head_dim, + dtype=dtype, + device=device, + ) + self.c_proj = nn.Linear( + self.embed_dim, self.embed_dim, dtype=dtype, device=device + ) @torch.profiler.record_function("GPTBigCodeAttention._get_mask_value") def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + if ( + self.mask_value is None + or self.mask_value.dtype != dtype + or self.mask_value.device != device + ): + self.mask_value = torch.full( + [], torch.finfo(dtype).min, dtype=dtype, device=device + ) return self.mask_value @torch.profiler.record_function("GPTBigCodeAttention._attn") @@ -156,12 +181,16 @@ class GPTBigCodeAttention(nn.Module): beta = 1 else: beta = 0 - attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + attn_weights = torch.baddbmm( + attn_weights, query, key, beta=beta, alpha=scale_factor + ).view(attn_shape) attn_weights = softmax_function( attn_weights, attention_mask, - None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype), + None + if attention_mask is None + else self._get_mask_value(attn_weights.device, softmax_dtype), unscale, softmax_dtype, upcast, @@ -172,7 +201,6 @@ class GPTBigCodeAttention(nn.Module): @torch.profiler.record_function("GPTBigCodeAttention._attn_flash") def _attn_flash(self, query, key, value, flash_params): - query_shape = query.shape attn_shape = query_shape[0], self.num_heads, self.head_dim query = query.view(attn_shape) @@ -199,11 +227,12 @@ class GPTBigCodeAttention(nn.Module): @torch.profiler.record_function("GPTBigCodeAttention._merge_kv_caches") def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params): - # Convert to standard KV cache format. if flash_params is not None: _, padding_index, batch_size, max_sequence_length = flash_params - current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length) + current_kv_cache = pad_input( + key_value, padding_index, batch_size, max_sequence_length + ) return key_value, (current_kv_cache, max_sequence_length) current_kv_cache = key_value @@ -257,19 +286,23 @@ class GPTBigCodeAttention(nn.Module): hidden_states: torch.Tensor, layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - flash_params: Optional[Tuple] = None + flash_params: Optional[Tuple] = None, ) -> Tuple[torch.Tensor, Any]: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) + query, key_value = self.c_attn(hidden_states).split( + (self.embed_dim, 2 * self.head_dim), dim=-1 + ) # present = (allocated_kv_cache, key_length) - key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params) + key_value, present = self._merge_kv_caches( + key_value, layer_past, attention_mask, flash_params + ) key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) if flash_params is None: - attn_output=self._attn(query, key, value, attention_mask) + attn_output = self._attn(query, key, value, attention_mask) else: - attn_output=self._attn_flash(query, key, value, flash_params) + attn_output = self._attn_flash(query, key, value, flash_params) attn_output = self.c_proj(attn_output) @@ -287,15 +320,35 @@ class GPTBigCodeMLP(nn.Module): @torch.profiler.record_function("GPTBigCodeMLP.forward") # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor: - return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")) + return self.c_proj( + nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh") 
+ ) class GPTBigCodeBlock(nn.Module): - def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + layer_idx=None, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__() - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) - self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device) - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) + self.ln_1 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device, + ) + self.attn = GPTBigCodeAttention( + config, layer_idx=layer_idx, dtype=dtype, device=device + ) + self.ln_2 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device, + ) self.mlp = GPTBigCodeMLP(config) @torch.profiler.record_function("GPTBigCodeBlock.forward") @@ -304,10 +357,10 @@ class GPTBigCodeBlock(nn.Module): hidden_states: torch.Tensor, layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - flash_params: Optional[Tuple] = None + flash_params: Optional[Tuple] = None, ) -> Tuple[torch.Tensor, Any]: with torch.profiler.record_function("GPTBigCodeAttention.ln"): - ai=self.ln_1(hidden_states) + ai = self.ln_1(hidden_states) attn_output, present = self.attn( ai, layer_past=layer_past, @@ -320,14 +373,13 @@ class GPTBigCodeBlock(nn.Module): with torch.profiler.record_function("GPTBigCodeAttention.dummy"): pass with torch.profiler.record_function("GPTBigCodeAttention.ln"): - ai=self.ln_2(hidden_states) - ai=self.mlp(ai) + ai = self.ln_2(hidden_states) + ai = self.mlp(ai) with torch.profiler.record_function("GPTBigCodeAttention.residual"): hidden_states.add_(ai) return hidden_states, present - class GPTBigCodePreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -354,7 +406,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py module.c_proj.weight.data.normal_( - mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + mean=0.0, + std=( + self.config.initializer_range / math.sqrt(2 * self.config.n_layer) + ), ) module.c_proj._is_hf_initialized = True elif isinstance(module, nn.Linear): @@ -373,18 +428,42 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): class GPTBigCodeModel(GPTBigCodePreTrainedModel): - def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__(config) - self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device) - self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device) + self.wte = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=dtype, device=device + ) + self.wpe = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + dtype=dtype, + device=device, + ) - self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)]) - self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, 
device=device, eps=config.layer_norm_epsilon) + self.h = nn.ModuleList( + [ + GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = nn.LayerNorm( + config.hidden_size, + dtype=dtype, + device=device, + eps=config.layer_norm_epsilon, + ) - self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner) + self.inference_runner_type = ( + InferenceRunnerType.NO_RUNNER + ) # InferenceRunnerType(config.inference_runner) - self.flash_attention = True #config.flash_attention + self.flash_attention = True # config.flash_attention if self.flash_attention: if flash_attn_unpadded_func is None: @@ -402,13 +481,20 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): # Causal mask self.register_buffer( - "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device) + "bias", + torch.empty( + (config.max_position_embeddings, config.max_position_embeddings), + dtype=torch.bool, + device=device, + ), ) - #@torch.profiler.record_function("GPTBigCodeModel._get_causal_mask") + # @torch.profiler.record_function("GPTBigCodeModel._get_causal_mask") def _get_causal_mask(self, padding_mask, query_length, key_length): # Self-attention mask. - attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + attention_mask = self.bias[ + None, key_length - query_length : key_length, :key_length + ] if padding_mask is not None: attention_mask = attention_mask * padding_mask.unsqueeze(1).to( @@ -416,12 +502,14 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): ) pad = -key_length % 8 if pad > 0: - attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad), mode="constant", value=False + ) # (batch_size, query_length, n_heads, key_length) return attention_mask.unsqueeze(2) - #@torch.profiler.record_function("GPTBigCodeModel.forward") + # @torch.profiler.record_function("GPTBigCodeModel.forward") def forward( self, *, @@ -433,7 +521,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): if self.inference_runner is not None and past_key_values is not None: if self.config.validate_runner_input: assert past_key_values is not None - return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) + return self.inference_runner.forward( + input_ids, attention_mask, position_ids, past_key_values + ) batch_size, query_length = input_ids.shape @@ -444,30 +534,38 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): else: past_length = past_key_values[0][1] - hidden_states = self.wte(input_ids) + self.wpe(position_ids) # TODO: Unpad earlier (input ids), support unpadded input? if flash_attention: - hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input( - hidden_states, attention_mask + ( + hidden_states, + padding_index, + sequence_lengths, + max_sequence_length, + ) = unpad_input(hidden_states, attention_mask) + flash_params = ( + sequence_lengths, + padding_index, + batch_size, + max_sequence_length, ) - flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length) - attention_mask=None + attention_mask = None else: key_length = past_length + query_length # Self-attention mask (padding + causal). 
- attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) - flash_params=None + attention_mask = self._get_causal_mask( + attention_mask, query_length, key_length + ) + flash_params = None presents = [] for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - outputs = block( hidden_states, layer_past=layer_past, attention_mask=attention_mask, - flash_params=flash_params + flash_params=flash_params, ) hidden_states = outputs[0] @@ -476,24 +574,33 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): hidden_states = self.ln_f(hidden_states) if flash_attention: - hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) + hidden_states = pad_input( + hidden_states, padding_index, batch_size, query_length + ) return hidden_states, presents class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): - def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__(config) - meta=torch.device("meta") + meta = torch.device("meta") self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta) + self.lm_head = nn.Linear( + config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta + ) self.to_empty(device=device) # Initialize weights and apply final processing self.post_init() - #@torch.profiler.record_function("GPTBigCodeForCausalLM.forward") + # @torch.profiler.record_function("GPTBigCodeForCausalLM.forward") def forward( self, *, @@ -501,16 +608,15 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): past_key_values: Optional[Union[List[torch.Tensor], int]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: torch.Tensor, - predict_all_tokens: bool=True, + predict_all_tokens: bool = True, ) -> Tuple: - - hidden_states, presents=self.transformer( + hidden_states, presents = self.transformer( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, - position_ids=position_ids + position_ids=position_ids, ) - #with torch.profiler.record_function("GPTBigCodeForCausalLM.head"): + # with torch.profiler.record_function("GPTBigCodeForCausalLM.head"): if not predict_all_tokens: # We only care about the last token. hidden_states = hidden_states[:, -1:] diff --git a/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py b/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py index 7ea367d7..d72afe9c 100644 --- a/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py +++ b/server/text_generation_server/models/custom_modeling/gpt_bigcode_modeling.py @@ -20,14 +20,13 @@ import torch import torch.utils.checkpoint from torch import nn -import dropout_layer_norm - from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( GPTBigCodeConfig, ) + class InferenceRunnerType(IntEnum): NO_RUNNER = 0 # Use the inference runner without cuda graphs. 
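
The mask construction above rounds the key length up to a multiple of 8 (`pad = -key_length % 8`), mirroring the `allocate_seq_len += -allocate_seq_len % 8` padding in the batch code, because the fused upcast/masked softmax paths noted in the docstring are only efficient when the key dimension is a multiple of 8. A minimal sketch of that rounding arithmetic, with hypothetical values chosen only for illustration:

# Round a key length up to the next multiple of 8 using Python's modulo.
key_length = 13
pad = -key_length % 8                  # (-13) % 8 == 3
padded_key_length = key_length + pad   # 16
assert padded_key_length % 8 == 0
# The extra attention-mask columns are filled with False, so the padded key
# positions are masked out (replaced by mask_value before the softmax).
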
@@ -41,7 +40,6 @@ class InferenceRunnerType(IntEnum): FULL_GRAPH = 3 - try: from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import flash_attn_unpadded_func @@ -56,7 +54,11 @@ logger = logging.get_logger(__name__) @torch.jit.script def upcast_masked_softmax( - x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype + x: torch.Tensor, + mask: torch.Tensor, + mask_value: torch.Tensor, + scale: float, + softmax_dtype: torch.dtype, ): input_dtype = x.dtype x = x.to(softmax_dtype) * scale @@ -94,7 +96,7 @@ def softmax_function( and scaling, but only work well when the key length is a multiple of 8. For other key lengths, it is extremely inefficient. """ - #assert x.size(-1) % 8 == 0 + # assert x.size(-1) % 8 == 0 if upcast: if mask is None: return upcast_softmax(x, scale, softmax_dtype) @@ -108,7 +110,13 @@ def softmax_function( class GPTBigCodeAttention(nn.Module): - def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + layer_idx=None, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__() self.mask_value = None self.embed_dim = config.hidden_size @@ -118,13 +126,26 @@ class GPTBigCodeAttention(nn.Module): # KV caching and padding - self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim, dtype=dtype, device=device) - self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=dtype, device=device) + self.c_attn = nn.Linear( + self.embed_dim, + self.embed_dim + 2 * self.head_dim, + dtype=dtype, + device=device, + ) + self.c_proj = nn.Linear( + self.embed_dim, self.embed_dim, dtype=dtype, device=device + ) def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. - if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + if ( + self.mask_value is None + or self.mask_value.dtype != dtype + or self.mask_value.device != device + ): + self.mask_value = torch.full( + [], torch.finfo(dtype).min, dtype=dtype, device=device + ) return self.mask_value def _attn(self, query, key, value, attention_mask): @@ -157,12 +178,16 @@ class GPTBigCodeAttention(nn.Module): beta = 1 else: beta = 0 - attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + attn_weights = torch.baddbmm( + attn_weights, query, key, beta=beta, alpha=scale_factor + ).view(attn_shape) attn_weights = softmax_function( attn_weights, attention_mask, - None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype), + None + if attention_mask is None + else self._get_mask_value(attn_weights.device, softmax_dtype), unscale, softmax_dtype, upcast, @@ -172,7 +197,6 @@ class GPTBigCodeAttention(nn.Module): return attn_output def _attn_flash(self, query, key, value, flash_params): - query_shape = query.shape attn_shape = query_shape[0], self.num_heads, self.head_dim query = query.view(attn_shape) @@ -198,11 +222,12 @@ class GPTBigCodeAttention(nn.Module): return attn_output def _merge_kv_caches(self, key_value, layer_past, attention_mask, flash_params): - # Convert to standard KV cache format. 
if flash_params is not None: _, padding_index, batch_size, max_sequence_length = flash_params - current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length) + current_kv_cache = pad_input( + key_value, padding_index, batch_size, max_sequence_length + ) return key_value, (current_kv_cache, max_sequence_length) current_kv_cache = key_value @@ -255,19 +280,23 @@ class GPTBigCodeAttention(nn.Module): hidden_states: torch.Tensor, layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - flash_params: Optional[Tuple] = None + flash_params: Optional[Tuple] = None, ) -> Tuple[torch.Tensor, Any]: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.head_dim), dim=-1) + query, key_value = self.c_attn(hidden_states).split( + (self.embed_dim, 2 * self.head_dim), dim=-1 + ) # present = (allocated_kv_cache, key_length) - key_value, present = self._merge_kv_caches(key_value, layer_past, attention_mask, flash_params) + key_value, present = self._merge_kv_caches( + key_value, layer_past, attention_mask, flash_params + ) key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) if flash_params is None: - attn_output=self._attn(query, key, value, attention_mask) + attn_output = self._attn(query, key, value, attention_mask) else: - attn_output=self._attn_flash(query, key, value, flash_params) + attn_output = self._attn_flash(query, key, value, flash_params) attn_output = self.c_proj(attn_output) @@ -284,15 +313,35 @@ class GPTBigCodeMLP(nn.Module): # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor: - return self.c_proj(nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh")) + return self.c_proj( + nn.functional.gelu(self.c_fc(hidden_states), approximate="tanh") + ) class GPTBigCodeBlock(nn.Module): - def __init__(self, config, layer_idx=None, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + layer_idx=None, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__() - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) - self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx, dtype=dtype, device=device) - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=dtype, device=device) + self.ln_1 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device, + ) + self.attn = GPTBigCodeAttention( + config, layer_idx=layer_idx, dtype=dtype, device=device + ) + self.ln_2 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device, + ) self.mlp = GPTBigCodeMLP(config) def forward( @@ -300,7 +349,7 @@ class GPTBigCodeBlock(nn.Module): hidden_states: torch.Tensor, layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - flash_params: Optional[Tuple] = None + flash_params: Optional[Tuple] = None, ) -> Tuple[torch.Tensor, Any]: attn_output, present = self.attn( self.ln_1(hidden_states), @@ -313,7 +362,6 @@ class GPTBigCodeBlock(nn.Module): return hidden_states, present - class GPTBigCodePreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -340,7 +388,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): # # 
Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py module.c_proj.weight.data.normal_( - mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + mean=0.0, + std=( + self.config.initializer_range / math.sqrt(2 * self.config.n_layer) + ), ) module.c_proj._is_hf_initialized = True elif isinstance(module, nn.Linear): @@ -359,18 +410,42 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): class GPTBigCodeModel(GPTBigCodePreTrainedModel): - def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__(config) - self.wte = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype, device=device) - self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size, dtype=dtype, device=device) + self.wte = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=dtype, device=device + ) + self.wpe = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + dtype=dtype, + device=device, + ) - self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)]) - self.ln_f = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device, eps=config.layer_norm_epsilon) + self.h = nn.ModuleList( + [ + GPTBigCodeBlock(config, layer_idx=i, dtype=dtype, device=device) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = nn.LayerNorm( + config.hidden_size, + dtype=dtype, + device=device, + eps=config.layer_norm_epsilon, + ) - self.inference_runner_type = InferenceRunnerType.NO_RUNNER #InferenceRunnerType(config.inference_runner) + self.inference_runner_type = ( + InferenceRunnerType.NO_RUNNER + ) # InferenceRunnerType(config.inference_runner) - self.flash_attention = True #config.flash_attention + self.flash_attention = True # config.flash_attention if self.flash_attention: if flash_attn_unpadded_func is None: @@ -388,12 +463,19 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): # Causal mask self.register_buffer( - "bias", torch.empty((config.max_position_embeddings, config.max_position_embeddings), dtype=torch.bool, device=device) + "bias", + torch.empty( + (config.max_position_embeddings, config.max_position_embeddings), + dtype=torch.bool, + device=device, + ), ) def _get_causal_mask(self, padding_mask, query_length, key_length): # Self-attention mask. 
- attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + attention_mask = self.bias[ + None, key_length - query_length : key_length, :key_length + ] if padding_mask is not None: attention_mask = attention_mask * padding_mask.unsqueeze(1).to( @@ -401,7 +483,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): ) pad = -key_length % 8 if pad > 0: - attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad), mode="constant", value=False + ) # (batch_size, query_length, n_heads, key_length) return attention_mask.unsqueeze(2) @@ -417,7 +501,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): if self.inference_runner is not None and past_key_values is not None: if self.config.validate_runner_input: assert past_key_values is not None - return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) + return self.inference_runner.forward( + input_ids, attention_mask, position_ids, past_key_values + ) batch_size, query_length = input_ids.shape @@ -428,30 +514,38 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): else: past_length = past_key_values[0][1] - hidden_states = self.wte(input_ids) + self.wpe(position_ids) # TODO: Unpad earlier (input ids), support unpadded input? if flash_attention: - hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input( - hidden_states, attention_mask + ( + hidden_states, + padding_index, + sequence_lengths, + max_sequence_length, + ) = unpad_input(hidden_states, attention_mask) + flash_params = ( + sequence_lengths, + padding_index, + batch_size, + max_sequence_length, ) - flash_params = (sequence_lengths, padding_index, batch_size, max_sequence_length) - attention_mask=None + attention_mask = None else: key_length = past_length + query_length # Self-attention mask (padding + causal). 
- attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) - flash_params=None + attention_mask = self._get_causal_mask( + attention_mask, query_length, key_length + ) + flash_params = None presents = [] for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - outputs = block( hidden_states, layer_past=layer_past, attention_mask=attention_mask, - flash_params=flash_params + flash_params=flash_params, ) hidden_states = outputs[0] @@ -460,17 +554,26 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): hidden_states = self.ln_f(hidden_states) if flash_attention: - hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) + hidden_states = pad_input( + hidden_states, padding_index, batch_size, query_length + ) return hidden_states, presents class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): - def __init__(self, config, dtype:torch.dtype=torch.float32, device:torch.device=torch.device("cpu")): + def __init__( + self, + config, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): super().__init__(config) - meta=torch.device("meta") + meta = torch.device("meta") self.transformer = GPTBigCodeModel(config, dtype=dtype, device=meta) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta) + self.lm_head = nn.Linear( + config.n_embd, config.vocab_size, bias=False, dtype=dtype, device=meta + ) self.to_empty(device=device) @@ -484,14 +587,13 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): past_key_values: Optional[Union[List[torch.Tensor], int]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: torch.Tensor, - predict_all_tokens: bool=True, + predict_all_tokens: bool = True, ) -> Tuple: - - hidden_states, presents=self.transformer( + hidden_states, presents = self.transformer( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, - position_ids=position_ids + position_ids=position_ids, ) if not predict_all_tokens: diff --git a/server/text_generation_server/models/gpt_bigcode.py b/server/text_generation_server/models/gpt_bigcode.py index ac321786..24d4df26 100644 --- a/server/text_generation_server/models/gpt_bigcode.py +++ b/server/text_generation_server/models/gpt_bigcode.py @@ -6,8 +6,13 @@ from opentelemetry import trace from transformers import AutoTokenizer from typing import Optional, Type -from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch -from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import GPTBigCodeForCausalLM +from text_generation_server.models.vectorized_causal_lm import ( + VectorizedCausalLM, + VectorizedCausalLMBatch, +) +from text_generation_server.models.custom_modeling.gpt_bigcode_modeling import ( + GPTBigCodeForCausalLM, +) tracer = trace.get_tracer(__name__) @@ -24,7 +29,9 @@ class BigcodeBatch(VectorizedCausalLMBatch): layer_kv.data = layer_kv[keep_indices, sequence_slice] @classmethod - def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): + def _concatenate_key_values( + cls, batches, start_indices, end_indices, left_indices, max_input_length + ): device = batches[0].input_ids.device batch_size = sum([len(batch.requests) for batch in batches]) @@ -35,13 +42,21 @@ class BigcodeBatch(VectorizedCausalLMBatch): past_key_values = [] for kv_caches in zip(*(batch.past_key_values for batch in batches)): key_values, seq_lengths = zip(*kv_caches) - assert 
all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths)) + assert all( + left_index + seq_length == max_input_length + for left_index, seq_length in zip(left_indices, seq_lengths) + ) - allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values)) - allocate_seq_len += - allocate_seq_len % 8 + allocate_seq_len = max( + left_index + key_value.size(1) + for left_index, key_value in zip(left_indices, key_values) + ) + allocate_seq_len += -allocate_seq_len % 8 kv_cache = torch.empty( - (batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device + (batch_size, allocate_seq_len, *key_values[0].shape[2:]), + dtype=key_values[0].dtype, + device=device, ) for key_value, start_index, end_index, left_index in zip( key_values, @@ -49,7 +64,9 @@ class BigcodeBatch(VectorizedCausalLMBatch): end_indices, left_indices, ): - kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value) + kv_cache[start_index:end_index, left_index:max_input_length].copy_( + key_value + ) # Set padding to zero to avoid propagating nans. kv_cache[start_index:end_index, :left_index].fill_(0) kv_cache[start_index:end_index, max_input_length:].fill_(0) @@ -58,6 +75,7 @@ class BigcodeBatch(VectorizedCausalLMBatch): def __len__(self): return len(self.requests) + class BigcodeCausalLM(VectorizedCausalLM): def __init__( self, @@ -103,7 +121,7 @@ class BigcodeCausalLM(VectorizedCausalLM): def batch_type(self) -> Type[BigcodeBatch]: return BigcodeBatch - def forward(self, batch:BigcodeBatch): + def forward(self, batch: BigcodeBatch): key_length = batch.max_input_length query_length = key_length if batch.past_key_values is None else 1 input_ids = batch.input_ids[:, key_length - query_length : key_length] @@ -113,7 +131,7 @@ class BigcodeCausalLM(VectorizedCausalLM): attention_mask=batch.attention_mask[:, :key_length], position_ids=batch.position_ids[:, key_length - query_length : key_length], past_key_values=batch.past_key_values, - use_cache=True, + predict_all_tokens=batch.details, ) next_token_ids, logprobs = batch.next_token_chooser( input_ids, logits, batch.details @@ -126,10 +144,20 @@ class BigcodeCausalLM(VectorizedCausalLM): return next_token_ids, logprobs - def mock_kv_cache(self, batch: BigcodeBatch, dtype:Optional[torch.dtype]): - allocate_length=batch.max_input_length+-batch.max_input_length%8 - return [(torch.empty( - [len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], - dtype=dtype, - device=batch.input_ids.device, - ),batch.max_input_length-1) for _ in range(self.model.config.n_layer)] + def mock_kv_cache(self, batch: BigcodeBatch, dtype: Optional[torch.dtype]): + allocate_length = batch.max_input_length + -batch.max_input_length % 8 + return [ + ( + torch.empty( + [ + len(batch), + allocate_length - 1, + 2 * self.model.config.n_embd // self.model.config.n_head, + ], + dtype=dtype, + device=batch.input_ids.device, + ), + batch.max_input_length - 1, + ) + for _ in range(self.model.config.n_layer) + ] diff --git a/server/text_generation_server/models/gpt_bigcode2.py b/server/text_generation_server/models/gpt_bigcode2.py index 24875c3b..02670eac 100644 --- a/server/text_generation_server/models/gpt_bigcode2.py +++ b/server/text_generation_server/models/gpt_bigcode2.py @@ -4,10 +4,17 @@ from dataclasses import dataclass from opentelemetry import trace from transformers import AutoTokenizer -from typing import Optional, Type +from 
typing import Optional, Type, List +from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase -from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM,VectorizedCausalLMBatch -from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import GPTBigCodeForCausalLM +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.vectorized_causal_lm import ( + VectorizedCausalLM, + VectorizedCausalLMBatch, +) +from text_generation_server.models.custom_modeling.gpt_bigcode2_modeling import ( + GPTBigCodeForCausalLM, +) tracer = trace.get_tracer(__name__) @@ -15,19 +22,40 @@ tracer = trace.get_tracer(__name__) @dataclass class Bigcode2Batch(VectorizedCausalLMBatch): kv_cache_seq_dim: int = 1 - pad_key_length_to_multiple:int=8 + pad_key_length_to_multiple: int = 8 # Prefill the attention mask for padded key length. - attention_mask_fill_value=False + attention_mask_fill_value = False + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ) -> "Bigcode2Batch": + batch = super().from_pb(pb, tokenizer, device) + batch.attention_mask[:, batch.max_input_length :].fill_(False) + return batch def _filter_kv_caches(self, keep_indices, sequence_slice): if self.past_key_values is not None: - for layer_kv, _ in self.past_key_values: + for layer_kv in self.past_key_values: # Update tensors in-place to allow incremental garbage collection layer_kv.data = layer_kv[keep_indices, sequence_slice] @classmethod - def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): + def concatenate(cls, batches: List["Bigcode2Batch"]) -> "Bigcode2Batch": + batch = super().concatenate(batches) + # Replace the attention mask with zeros to support padded key length. + # They are already filled with ones in super, but duplication is needed to generate the position ids. 
+        batch.attention_mask[:, batch.max_input_length :].fill_(False)
+        return batch
+
+    @classmethod
+    def _concatenate_key_values(
+        cls, batches, start_indices, end_indices, left_indices, max_input_length
+    ):
         device = batches[0].input_ids.device
         batch_size = sum([len(batch.requests) for batch in batches])
@@ -36,15 +64,19 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
             raise ValueError("Only concatenate prefilled batches")
 
         past_key_values = []
-        for kv_caches in zip(*(batch.past_key_values for batch in batches)):
-            key_values, seq_lengths = zip(*kv_caches)
-            assert all(left_index + seq_length == max_input_length for left_index, seq_length in zip(left_indices, seq_lengths))
-
-            allocate_seq_len=max(left_index + key_value.size(1) for left_index, key_value in zip(left_indices, key_values))
-            allocate_seq_len += -allocate_seq_len % batches[0].pad_key_length_to_multiple
+        for key_values in zip(*(batch.past_key_values for batch in batches)):
+            allocate_seq_len = max(
+                left_index + key_value.size(1)
+                for left_index, key_value in zip(left_indices, key_values)
+            )
+            allocate_seq_len += (
+                -allocate_seq_len % batches[0].pad_key_length_to_multiple
+            )
 
             kv_cache = torch.empty(
-                (batch_size, allocate_seq_len, *key_values[0].shape[2:]), dtype=key_values[0].dtype, device=device
+                (batch_size, allocate_seq_len, *key_values[0].shape[2:]),
+                dtype=key_values[0].dtype,
+                device=device,
             )
             for key_value, start_index, end_index, left_index in zip(
                 key_values,
@@ -52,16 +84,21 @@ class Bigcode2Batch(VectorizedCausalLMBatch):
                 end_indices,
                 left_indices,
             ):
-                kv_cache[start_index:end_index,left_index:max_input_length].copy_(key_value)
+                kv_cache[start_index:end_index, left_index:max_input_length].copy_(
+                    key_value
+                )
                 # Set padding to zero to avoid propagating nans.
                 kv_cache[start_index:end_index, :left_index].fill_(0)
                 kv_cache[start_index:end_index, max_input_length:].fill_(0)
-            past_key_values.append((kv_cache, max_input_length))
+            past_key_values.append(kv_cache)
 
     def __len__(self):
         return len(self.requests)
 
+
 class Bigcode2CausalLM(VectorizedCausalLM):
+    model: GPTBigCodeForCausalLM
+
     def __init__(
         self,
         model_id: str,
@@ -85,9 +122,12 @@ class Bigcode2CausalLM(VectorizedCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
+            dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize == "bitsandbytes",
         )
+        model.post_load_weights()
+
         tokenizer.pad_token_id = (
             model.config.pad_token_id
             if model.config.pad_token_id is not None
@@ -106,29 +146,33 @@ class Bigcode2CausalLM(VectorizedCausalLM):
     def batch_type(self) -> Type[Bigcode2Batch]:
         return Bigcode2Batch
 
-    def forward(self, batch:Bigcode2Batch):
+    def forward(self, batch: Bigcode2Batch):
         key_length = batch.max_input_length
         if batch.past_key_values is None:
             # Prefill (flash attn, unpadded key length)
-            batch.pad_key_length_to_multiple=self.model.pad_key_length_to_multiple
-            padded_key_length=key_length
-            query_length=key_length
+            input_ids = batch.input_ids[:, :key_length]
+            logits, batch.past_key_values = self.model.prefill(
+                input_ids=input_ids,
+                attention_mask=batch.attention_mask[:, :key_length],
+                position_ids=batch.position_ids[:, :key_length],
+                predict_all_tokens=batch.details,
+            )
         else:
             # Decode (fused attn, padded key length)
-            batch.attention_mask[:, key_length-1].fill_(True)
-            padded_key_length=key_length+-key_length%batch.pad_key_length_to_multiple
-            query_length=1
+            batch.attention_mask[:, key_length - 1].fill_(True)
+            padded_key_length = (
+                key_length + -key_length % batch.pad_key_length_to_multiple
+            )
+            input_ids = batch.input_ids[:, key_length - 1 : key_length]
+            # Model Forward
+            logits, batch.past_key_values = self.model.decode(
+                input_ids=input_ids,
+                attention_mask=batch.attention_mask[:, :padded_key_length],
+                position_ids=batch.position_ids[:, key_length - 1 : key_length],
+                past_key_values=batch.past_key_values,
+                key_length=key_length,
+            )
 
-        input_ids = batch.input_ids[:, key_length - query_length : key_length]
-        # Model Forward
-        logits, batch.past_key_values = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=batch.attention_mask[:, :padded_key_length],
-            position_ids=batch.position_ids[:, key_length - query_length : key_length],
-            past_key_values=batch.past_key_values,
-            key_length=key_length,
-            predict_all_tokens=batch.details
-        )
         next_token_ids, logprobs = batch.next_token_chooser(
             input_ids, logits, batch.details
         )
@@ -140,10 +184,20 @@ class Bigcode2CausalLM(VectorizedCausalLM):
 
         return next_token_ids, logprobs
 
-    def mock_kv_cache(self, batch: Bigcode2Batch, dtype:Optional[torch.dtype]):
-        allocate_length=batch.max_input_length+-batch.max_input_length%batch.pad_key_length_to_multiple
-        return [(torch.empty(
-            [len(batch), allocate_length-1, 2 * self.model.config.n_embd // self.model.config.n_head],
-            dtype=dtype,
-            device=batch.input_ids.device,
-        ),batch.max_input_length-1) for _ in range(self.model.config.n_layer)]
+    def mock_kv_cache(self, batch: Bigcode2Batch, dtype: Optional[torch.dtype]):
+        allocate_length = (
+            batch.max_input_length
+            + -batch.max_input_length % batch.pad_key_length_to_multiple
+        )
+        return [
+            torch.randn(
+                [
+                    len(batch),
+                    allocate_length - 1,
+                    2 * self.model.config.n_embd // self.model.config.n_head,
+                ],
+                dtype=dtype,
+                device=batch.input_ids.device,
+            )
+            for _ in range(self.model.config.n_layer)
+        ]
diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index e3a5b61b..c128c0b2 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -58,9 +58,6 @@ class VectorizedCausalLMBatch(Batch):
 
     kv_cache_seq_dim: int = 2
 
-    # Prefill the attention mask for the generated tokens
-    attention_mask_fill_value=True
-
     # TODO: Get from requests (should these be lists?)
details: bool = os.environ.get("RETURN_DETAILS") is not None generate_stream: bool = os.environ.get("GENERATE_STREAM") is not None @@ -116,7 +113,7 @@ class VectorizedCausalLMBatch(Batch): attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) # Copy tokenizer attention_mask into fully allocated attention_mask attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"]) - attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value) + attention_mask[:, max_input_length:].fill_(True) position_ids = attention_mask.cumsum(-1).sub_(1) position_ids[:, :max_input_length].relu_() @@ -271,7 +268,10 @@ class VectorizedCausalLMBatch(Batch): # Allocate maximum attention_mask attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) attention_mask[:, :max_input_length].fill_(0) - attention_mask[:, max_input_length:].fill_(cls.attention_mask_fill_value) + attention_mask[:, max_input_length:].fill_(True) + + position_ids = attention_mask.cumsum(-1).sub_(1) + position_ids[:, :max_input_length].relu_() input_ids = torch.empty(input_shape, dtype=torch.int64, device=device) # TODO : only needed for prefill @@ -287,16 +287,15 @@ class VectorizedCausalLMBatch(Batch): batch.input_ids[:, : batch.max_input_length] ) - position_ids = attention_mask.cumsum(-1).sub_(1) - position_ids[:, :max_input_length].relu_() - max_tokens = sum( batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) for batch in batches ) kv_cache_seq_dim = batches[0].kv_cache_seq_dim - past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices, max_input_length) + past_key_values = cls._concatenate_key_values( + batches, start_indices, end_indices, left_indices, max_input_length + ) return cls( batch_id=batches[0].batch_id, @@ -317,7 +316,9 @@ class VectorizedCausalLMBatch(Batch): ) @classmethod - def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices, max_input_length): + def _concatenate_key_values( + cls, batches, start_indices, end_indices, left_indices, max_input_length + ): device = batches[0].input_ids.device batch_size = sum([len(batch.requests) for batch in batches]) @@ -386,10 +387,10 @@ class VectorizedCausalLMBatch(Batch): return - def __len__(self): return len(self.requests) + class VectorizedCausalLM(Model): def __init__( self, @@ -441,7 +442,7 @@ class VectorizedCausalLM(Model): generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False ) - def forward(self, batch:VectorizedCausalLMBatch): + def forward(self, batch: VectorizedCausalLMBatch): key_length = batch.max_input_length query_length = key_length if batch.past_key_values is None else 1 input_ids = batch.input_ids[:, key_length - query_length : key_length] @@ -499,7 +500,7 @@ class VectorizedCausalLM(Model): prefill_token_ids, prefill_logprobs, batch.input_lengths ): # Input length has already been incremented so we subtract 1. 
- prefill_token_ids_ = prefill_token_ids_[-(input_length-1):] + prefill_token_ids_ = prefill_token_ids_[-(input_length - 1) :] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids_, clean_up_tokenization_spaces=False, @@ -558,23 +559,42 @@ class VectorizedCausalLM(Model): return generations, next_batch - def mock_kv_cache(self, batch: VectorizedCausalLMBatch, dtype:Optional[torch.dtype]): - from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM + def mock_kv_cache( + self, batch: VectorizedCausalLMBatch, dtype: Optional[torch.dtype] + ): + from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( + GPTBigCodeForCausalLM, + ) + if not isinstance(self.model, GPTBigCodeForCausalLM): raise NotImplementedError() - return [torch.empty( - [len(batch), batch.max_input_length-1, 2 * self.model.config.n_embd // self.model.config.n_head], - dtype=dtype, - device=batch.input_ids.device, - ) for _ in range(self.model.config.n_layer)] + return [ + torch.empty( + [ + len(batch), + batch.max_input_length - 1, + 2 * self.model.config.n_embd // self.model.config.n_head, + ], + dtype=dtype, + device=batch.input_ids.device, + ) + for _ in range(self.model.config.n_layer) + ] - def fast_forward(self, batch: VectorizedCausalLMBatch, max_input_length: int, cache_dtype:Optional[torch.dtype]): - diff=max_input_length-batch.max_input_length - batch.input_ids[:, batch.max_input_length:max_input_length].fill_(self.tokenizer.pad_token_id) + def fast_forward( + self, + batch: VectorizedCausalLMBatch, + max_input_length: int, + cache_dtype: Optional[torch.dtype], + ): + diff = max_input_length - batch.max_input_length + batch.input_ids[:, batch.max_input_length : max_input_length].fill_( + self.tokenizer.pad_token_id + ) batch.input_lengths = [length + diff for length in batch.input_lengths] batch.max_input_length += diff for stopping_criteria in batch.stopping_criterias: - stopping_criteria.current_tokens+=diff - batch.past_key_values = None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype) - - + stopping_criteria.current_tokens += diff + batch.past_key_values = ( + None if cache_dtype is None else self.mock_kv_cache(batch, cache_dtype) + )
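Reviewer note (not part of the patch): both the decode path and `mock_kv_cache` round a length up to the next multiple of `pad_key_length_to_multiple` with the `x + -x % m` idiom, and the KV-cache concatenation zero-fills the unused slots so NaNs in uninitialized memory cannot leak through the attention matmul. A minimal, self-contained sketch of the two ideas follows; the names `round_up_to_multiple` and `build_padded_kv_cache` are illustrative only and do not exist in the codebase.

import torch

def round_up_to_multiple(length: int, multiple: int) -> int:
    # -length % multiple is the gap to the next multiple (0 when already aligned),
    # e.g. round_up_to_multiple(13, 8) == 16 and round_up_to_multiple(16, 8) == 16.
    return length + -length % multiple

def build_padded_kv_cache(key_values, left_indices, max_input_length, multiple):
    # key_values[i]: (seq_len_i, 2 * head_dim) cache slice of one request, assumed to be
    # right-aligned at max_input_length (left_indices[i] + seq_len_i == max_input_length),
    # mirroring the layout used by _concatenate_key_values above.
    allocate_seq_len = round_up_to_multiple(
        max(left + kv.size(0) for left, kv in zip(left_indices, key_values)), multiple
    )
    # Zero-initialised cache: padding positions never contain NaNs.
    cache = torch.zeros(len(key_values), allocate_seq_len, key_values[0].size(-1))
    for i, (kv, left) in enumerate(zip(key_values, left_indices)):
        # Copy the real entries; everything outside [left, max_input_length) stays zero.
        cache[i, left:max_input_length].copy_(kv)
    return cache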