From f41ab1278378530d3219fab2bb7358d1f8a06fc6 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 29 Mar 2023 22:25:23 +0200
Subject: [PATCH] wip

---
 .../text_generation_server/models/__init__.py |  18 +-
 .../models/custom_modeling/__init__.py        |   0
 .../flash_neox_modeling.py                    |   0
 .../flash_santacoder_modeling.py              | 651 ++++++++++++++++++
 .../models/flash_causal_lm.py                 | 458 ++++++++++++
 .../models/flash_neox.py                      | 452 +-----------
 .../models/santacoder.py                      |  18 -
 7 files changed, 1127 insertions(+), 470 deletions(-)
 create mode 100644 server/text_generation_server/models/custom_modeling/__init__.py
 rename server/text_generation_server/models/{ => custom_modeling}/flash_neox_modeling.py (100%)
 create mode 100644 server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
 create mode 100644 server/text_generation_server/models/flash_causal_lm.py

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 90c70cb5..b54b38fd 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -8,6 +8,7 @@ from typing import Optional
 
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
@@ -18,17 +19,20 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
 
-    FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
+    FLASH_ATTENTION = (
+        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
+    )
 except ImportError:
-    if int(os.environ.get("FLASH_NEOX", 0)) == 1:
-        logger.exception("Could not import FlashNeoX")
-    FLASH_NEOX = False
+    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
+        logger.exception("Could not import Flash Attention models")
+    FLASH_ATTENTION = False
 
 __all__ = [
     "Model",
     "BLOOM",
     "BLOOMSharded",
     "CausalLM",
+    "FlashCausalLM",
     "Galactica",
     "GalacticaSharded",
     "GPTNeoxSharded",
@@ -38,7 +42,7 @@ __all__ = [
     "get_model",
 ]
 
-if FLASH_NEOX:
+if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
 
@@ -76,10 +80,10 @@ def get_model(
 
     if model_type == "gpt_neox":
         if sharded:
-            neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
+            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
             return neox_cls(model_id, revision, quantize=quantize)
         else:
-            neox_cls = FlashNeoX if FLASH_NEOX else CausalLM
+            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
             return neox_cls(model_id, revision, quantize=quantize)
 
     if model_type == "t5":
diff --git a/server/text_generation_server/models/custom_modeling/__init__.py b/server/text_generation_server/models/custom_modeling/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
similarity index 100%
rename from server/text_generation_server/models/flash_neox_modeling.py
rename to server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
diff --git 
a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py new file mode 100644 index 00000000..f3e35c4c --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -0,0 +1,651 @@ +import torch +import torch.distributed + +from torch.nn import functional as F + +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.models.gpt_neox import GPTNeoXConfig + +# Flash attention imports +import rotary_emb +import flash_attn_cuda +import dropout_layer_norm + +from flash_attn.layers.rotary import RotaryEmbedding + + +class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 6144: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + + +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + + def transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) + + +class TensorParallelColumnLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert out_features % self.tp_world_size == 0 + out_features = out_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + +class TensorParallelRowLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + reduce=True, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + self.reduce = reduce + assert in_features % self.tp_world_size == 0 + in_features = in_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super(TensorParallelRowLinear, self).forward(input) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = 
process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + # Additional entry that will map to zero + # Used for masking + self.null_idx = block_size + + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def add_null_idx(self): + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = super().forward(input) + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class PositionRotaryEmbedding(RotaryEmbedding): + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + q1 = qkv[:, 0, :, :rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] + k1 = qkv[:, 1, :, :rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return qkv + + +class FlashNeoxAttention(torch.nn.Module): + def __init__( + self, + num_heads, + hidden_size, + rotary_pct, + rotary_emb_base, + process_group=None, + reduce=True, + ): + super().__init__() + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + rotary_ndims = int(self.head_size * rotary_pct) + self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) + self.softmax_scale = self.head_size ** (-0.5) + + if process_group is None: + self.query_key_value = FastLinear(hidden_size, 3 * hidden_size) + self.dense = FastLinear(hidden_size, hidden_size) + else: + self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear( + hidden_size, + 3 * hidden_size, + process_group=process_group, + ) + 
self.dense = TensorParallelRowLinear( + hidden_size, hidden_size, process_group=process_group, reduce=reduce + ) + + def shuffle_qkv_dims(self): + """Swap dims to avoid an additional permute""" + self.query_key_value.weight = torch.nn.Parameter( + self.query_key_value.weight.view( + self.num_heads, 3, self.head_size, self.hidden_size + ) + .permute(1, 0, 2, 3) + .reshape(-1, self.hidden_size) + ) + self.query_key_value.bias = torch.nn.Parameter( + self.query_key_value.bias.view(self.num_heads, 3, self.head_size) + .permute(1, 0, 2) + .reshape(-1) + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + qkv_rot = self.rotary_emb(qkv, cos, sin) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(qkv[:, 0]) + # flash attention + flash_attn_cuda.fwd( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + query = qkv_rot[:, 0] + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + layer_past[:, 0], + layer_past[:, 1], + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) + + +class FlashMLP(nn.Module): + def __init__( + self, act, hidden_size, intermediate_size, process_group=None, reduce=True + ): + super().__init__() + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + if process_group is None: + self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size) + self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size) + else: + self.dense_h_to_4h = TensorParallelColumnLinear( + hidden_size, + intermediate_size, + process_group=process_group, + ) + self.dense_4h_to_h = TensorParallelRowLinear( + intermediate_size, + hidden_size, + process_group=process_group, + reduce=reduce, + ) + self.process_group = process_group + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class FlashNeoXLayer(nn.Module): + def __init__( + self, + num_heads, + act, + hidden_size, + intermediate_size, + rotary_pct, + rotary_emb_base, + layer_norm_eps, + use_parallel_residual, + process_group=None, + ): + super().__init__() + self.use_parallel_residual = use_parallel_residual + self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.attention = FlashNeoxAttention( + num_heads, + hidden_size, + rotary_pct, + rotary_emb_base, + process_group, + reduce=not use_parallel_residual, + ) + self.mlp = FlashMLP( + act, + hidden_size, + intermediate_size, + process_group, + reduce=not use_parallel_residual, + ) + self.process_group = process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + 
layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + if self.use_parallel_residual: + ln1_hidden_states, _ = self.input_layernorm(hidden_states) + + attn_output = self.attention( + ln1_hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(ln2_hidden_states) + intermediate = mlp_output + attn_output + + # Only reduce once and after the addition instead of once per layer + if self.process_group is not None: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate + hidden_states, None + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.attention( + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashGPTNeoXPreTrainedModel(PreTrainedModel): + config_class = GPTNeoXConfig + base_model_prefix = "gpt_neox" + supports_gradient_checkpointing = False + _no_split_modules = None + + +class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): + def __init__(self, config, process_group=None): + super().__init__(config) + self.config = config + + self.tp_embeddings = False + if process_group is not None: + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + if config.vocab_size % self.tp_world_size == 0: + self.tp_embeddings = True + + if self.tp_embeddings: + self.embed_in = TensorParallelEmbedding( + config.vocab_size, config.hidden_size, process_group=process_group + ) + else: + self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layers = nn.ModuleList( + [ + FlashNeoXLayer( + config.num_attention_heads, + config.hidden_act, + config.hidden_size, + config.intermediate_size, + config.rotary_pct, + config.rotary_emb_base, + config.layer_norm_eps, + config.use_parallel_residual, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = FastLayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].attention.head_size + self.num_heads = self.layers[0].attention.num_heads + + def post_load_weights(self): + if isinstance(self.embed_in, TensorParallelEmbedding): + self.embed_in.add_null_idx() + for layer in self.layers: + layer: FlashNeoXLayer + layer.attention.shuffle_qkv_dims() + layer.attention.query_key_value.transpose_weight() + layer.attention.dense.transpose_weight() + layer.mlp.dense_h_to_4h.transpose_weight() + layer.mlp.dense_4h_to_h.transpose_weight() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + model = super(FlashGPTNeoXModel, cls).from_pretrained( + pretrained_model_name_or_path, *model_args, **kwargs + ) + model.post_load_weights() + return model + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states = self.embed_in(input_ids) + + # Prefill + if past_key_values is None: + # Create past tensor + past_key_values = hidden_states.new_empty( + ( + len(self.layers), + len(hidden_states), + 2, + self.num_heads, + self.head_size, + ) + ) + layer_past_present_indices = None + cu_seqlens_q = None + # Decode + else: + 
# Create indices from cumulative sequence lengths + layer_past_present_indices = cu_seqlens[1:] - 1 + cu_seqlens_q = torch.arange( + cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device + ) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + past_key_values[i], + layer_past_present_indices, + cu_seqlens_q, + ) + + hidden_states, _ = self.final_layer_norm(hidden_states, residual) + + return hidden_states, past_key_values + + +class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if config.tp_parallel: + process_group = torch.distributed.distributed_c10d._get_default_group() + else: + process_group = None + + self.gpt_neox = FlashGPTNeoXModel(config, process_group) + + if self.gpt_neox.tp_embeddings: + self.embed_out = FastLinear( + config.hidden_size, + config.vocab_size // process_group.size(), + bias=False, + ) + else: + self.embed_out = FastLinear( + config.hidden_size, config.vocab_size, bias=False + ) + + def post_load_weights(self): + self.gpt_neox.post_load_weights() + self.embed_out.transpose_weight() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained( + pretrained_model_name_or_path, *model_args, **kwargs + ) + model.post_load_weights() + return model + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states, present = self.gpt_neox( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) + return self.embed_out(hidden_states), present diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py new file mode 100644 index 00000000..e1a10cbf --- /dev/null +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -0,0 +1,458 @@ +import torch +import torch.distributed + +from torch.nn import functional as F + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel +from typing import Optional, Tuple, List, Type, Union + +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + PrefillTokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + NextTokenChooser, + StoppingCriteria, + Sampling, +) + +tracer = trace.get_tracer(__name__) + + +@dataclass +class FlashCausalLMBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + + # Decoder values + input_ids: torch.Tensor + position_ids: torch.Tensor + # cumulative sequence lengths + cu_seqlens: torch.Tensor + max_seqlen: int + past_key_values: Optional[torch.Tensor] + + # All tokens + all_input_ids: List[List[int]] + all_input_ids_tensor: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + def to_pb(self) -> generate_pb2.Batch: + return generate_pb2.Batch( + id=self.batch_id, requests=self.requests, 
size=len(self) + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ) -> "CausalLMBatch": + input_ids = [] + position_ids = [] + cu_seqlens = [0] + max_seqlen = 0 + + input_lengths = [] + all_input_ids = [] + all_input_ids_tensor = [] + + next_token_choosers = [] + stopping_criterias = [] + + # Cumulative length + cumulative_length = 0 + + # Parse batch + for r in pb.requests: + tokenized_input = tokenizer(r.inputs)["input_ids"] + input_length = len(tokenized_input) + max_seqlen = max(max_seqlen, input_length) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + + tokenized_input = torch.tensor(tokenized_input, device=device) + input_ids.append(tokenized_input) + + # Position ids + position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(cumulative_length + input_length) + + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + all_input_ids_tensor.append( + F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens)) + ) + + # Update + cumulative_length += input_length + + input_ids = torch.concat(input_ids) + position_ids = torch.concat(position_ids) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + + return cls( + batch_id=pb.id, + requests=pb.requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=None, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + # Batch attributes + requests = [] + input_lengths = [] + all_input_ids = [] + all_input_ids_tensor = [] + next_token_choosers = [] + stopping_criterias = [] + + # Batch tensors + input_ids = [] + position_ids = [] + cu_seqlens = [torch.tensor([0], dtype=torch.int32)] + max_seqlen = 0 + past_key_values = [] + + # Cumulative length + cumulative_length = torch.tensor(0) + + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + all_input_ids.extend(batch.all_input_ids) + all_input_ids_tensor.extend(batch.all_input_ids_tensor) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length) + + input_ids.append(batch.input_ids) + position_ids.append(batch.position_ids) + past_key_values.append(batch.past_key_values) + + max_seqlen = max(max_seqlen, batch.max_seqlen) + + # Update + cumulative_length += batch.cu_seqlens[-1] + + input_ids = torch.concat(input_ids) + position_ids = torch.concat(position_ids) + # Concat on dim=1 as first dim represents the model layers + past_key_values = torch.concat(past_key_values, dim=1) + cu_seqlens = torch.concat(cu_seqlens) + + return FlashCausalLMBatch( + batch_id=batches[0].batch_id, + requests=requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + input_lengths=input_lengths, + 
all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + + def __len__(self): + return len(self.requests) + + +class FlashCausalLM(Model): + def __init__( + self, + model_cls: Type[PreTrainedModel], + model_id: str, + revision: Optional[str] = None, + quantize=False, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashCausalLM is only available on GPU") + + if quantize: + raise NotImplementedError("FlashCausalLM does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + self.model = ( + model_cls.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + ) + .eval() + .cuda() + ) + + super(FlashCausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @property + def batch_type(self) -> Type[FlashCausalLMBatch]: + return FlashCausalLMBatch + + def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: int, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Model Forward + return self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_s=max_s, + past_key_values=past_key_values, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: FlashCausalLMBatch + ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: + # Better to send to device here to avoid device issues in concatenate + position_ids = batch.position_ids.to(self.device, non_blocking=True) + cu_seqlens = batch.cu_seqlens.to(self.device) + + out, present = self.forward( + batch.input_ids, + position_ids, + cu_seqlens, + batch.max_seqlen, + batch.past_key_values, + ) + + # List of indices to cache + next_batch_keep_indices = [] + + # New values for next forward + next_batch_input_ids = [] + next_batch_position_ids = [] + next_batch_cu_seqlens = [0] + next_batch_max_seqlen = 0 + next_batch_past_key_values = [] + next_batch_input_lengths = [] + next_batch_all_input_ids = [] + next_batch_all_input_ids_tensor = [] + + # Cumulative length + cumulative_length = 0 + + # Results + generations: List[Generation] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + batch.all_input_ids_tensor, + ) + + # For each member of the batch + for i, ( + request, + input_length, + next_token_chooser, + stopping_criteria, + all_input_ids, + all_input_ids_tensor, + ) in enumerate(iterator): + # Indexing metadata + start_index = cumulative_length + end_index = cumulative_length + input_length + + if batch.past_key_values is None: + # Prefill mode + # out is of shape [cumulative_sequence_lengths, vocab_size] + logits = out[start_index:end_index] + else: + # Decode mode + # out is of shape [batch_size, vocab_size] + logits = out[i].unsqueeze(0) + + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids_tensor[None, :input_length], logits + ) + next_token_id_squeezed = next_token_id.squeeze() + next_token_id_item = 
next_token_id_squeezed.item() + + # Append next token to all tokens + all_input_ids.append(next_token_id_item) + all_input_ids_tensor[input_length] = next_token_id_item + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id_item] + next_token_text = self.decode_token( + next_token_id_item, + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_item, + next_token_text, + ) + + if stop: + # Decode generated tokens + output_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :] + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + # Keep request in the batch + next_batch_keep_indices.append(i) + generated_text = None + + # Get sequence present + seq_present = present[:, start_index:end_index] + # Pad it for next iter attention + past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) + next_batch_past_key_values.append(past) + + next_batch_input_ids.append(next_token_id) + next_batch_position_ids.append(input_length) + # Cumulative sum + next_batch_cu_seqlens.append( + next_batch_cu_seqlens[-1] + new_input_length + ) + next_batch_input_lengths.append(new_input_length) + next_batch_all_input_ids.append(all_input_ids) + next_batch_all_input_ids_tensor.append(all_input_ids_tensor) + next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) + + # Prefill + if stopping_criteria.current_tokens == 1: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + logprobs.gather( + 1, all_input_ids_tensor[1:input_length].unsqueeze(1) + ).squeeze(1)[:-1].tolist() + prefill_token_ids = all_input_ids[:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_item, + next_token_logprob, + next_token_text, + next_token_id_item in self.all_special_ids, + generated_text, + ) + + generations.append(generation) + cumulative_length += input_length + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generations, None + + # If we finished at least one generation, we need to evict the indices of the generations that finished + # from the values of the next batch + if len(next_batch_keep_indices) != len(batch): + # Apply indices to requests, token_choosers and stopping_criterias that need to be cached + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Create final next batch tensors + next_batch_position_ids = torch.tensor( + next_batch_position_ids, dtype=torch.int32 + ) + next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32) + if len(next_batch_keep_indices) > 1: + 
next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1) + next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) + else: + next_batch_input_ids = next_batch_input_ids[0].view(1) + next_batch_past_key_values = next_batch_past_key_values[0] + + next_batch = FlashCausalLMBatch( + batch_id=batch.batch_id, + requests=next_batch_requests, + input_ids=next_batch_input_ids, + position_ids=next_batch_position_ids, + cu_seqlens=next_batch_cu_seqlens, + max_seqlen=next_batch_max_seqlen, + past_key_values=next_batch_past_key_values, + input_lengths=next_batch_input_lengths, + all_input_ids=next_batch_all_input_ids, + all_input_ids_tensor=next_batch_all_input_ids_tensor, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + ) + return generations, next_batch diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index b97f342a..e415a725 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -1,33 +1,20 @@ import torch import torch.distributed -from torch.nn import functional as F - from accelerate import init_empty_weights -from dataclasses import dataclass from opentelemetry import trace from safetensors import safe_open -from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig -from typing import Optional, Tuple, List, Type, Union +from transformers import AutoTokenizer, AutoConfig +from typing import Optional, Tuple, List -from text_generation_server.models import Model -from text_generation_server.models.flash_neox_modeling import ( +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, TensorParallelEmbedding, TensorParallelRowLinear, TensorParallelColumnLinear, ) -from text_generation_server.models.types import ( - Batch, - PrefillTokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 from text_generation_server.utils import ( - NextTokenChooser, - StoppingCriteria, - Sampling, initialize_torch_distributed, weight_files, ) @@ -35,437 +22,12 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -@dataclass -class FlashNeoXBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - - # Decoder values - input_ids: torch.Tensor - position_ids: torch.Tensor - # cumulative sequence lengths - cu_seqlens: torch.Tensor - max_seqlen: int - past_key_values: Optional[torch.Tensor] - - # All tokens - all_input_ids: List[List[int]] - all_input_ids_tensor: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - def to_pb(self) -> generate_pb2.Batch: - return generate_pb2.Batch( - id=self.batch_id, requests=self.requests, size=len(self) - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - device: torch.device, - ) -> "CausalLMBatch": - input_ids = [] - position_ids = [] - cu_seqlens = [0] - max_seqlen = 0 - - input_lengths = [] - all_input_ids = [] - all_input_ids_tensor = [] - - next_token_choosers = [] - stopping_criterias = [] - - # Cumulative length - cumulative_length = 0 - - # Parse batch - for r in pb.requests: - tokenized_input = tokenizer(r.inputs)["input_ids"] - 
input_length = len(tokenized_input) - max_seqlen = max(max_seqlen, input_length) - input_lengths.append(input_length) - all_input_ids.append(tokenized_input) - - tokenized_input = torch.tensor(tokenized_input, device=device) - input_ids.append(tokenized_input) - - # Position ids - position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) - - # Add cumulative lengths of all previous inputs - cu_seqlens.append(cumulative_length + input_length) - - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - all_input_ids_tensor.append( - F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens)) - ) - - # Update - cumulative_length += input_length - - input_ids = torch.concat(input_ids) - position_ids = torch.concat(position_ids) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) - - return cls( - batch_id=pb.id, - requests=pb.requests, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=None, - input_lengths=input_lengths, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - ) - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": - # Batch attributes - requests = [] - input_lengths = [] - all_input_ids = [] - all_input_ids_tensor = [] - next_token_choosers = [] - stopping_criterias = [] - - # Batch tensors - input_ids = [] - position_ids = [] - cu_seqlens = [torch.tensor([0], dtype=torch.int32)] - max_seqlen = 0 - past_key_values = [] - - # Cumulative length - cumulative_length = torch.tensor(0) - - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - all_input_ids.extend(batch.all_input_ids) - all_input_ids_tensor.extend(batch.all_input_ids_tensor) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - # Add cumulative lengths of all previous inputs - cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length) - - input_ids.append(batch.input_ids) - position_ids.append(batch.position_ids) - past_key_values.append(batch.past_key_values) - - max_seqlen = max(max_seqlen, batch.max_seqlen) - - # Update - cumulative_length += batch.cu_seqlens[-1] - - input_ids = torch.concat(input_ids) - position_ids = torch.concat(position_ids) - # Concat on dim=1 as first dim represents the model layers - past_key_values = torch.concat(past_key_values, dim=1) - cu_seqlens = torch.concat(cu_seqlens) - - return FlashNeoXBatch( - batch_id=batches[0].batch_id, - requests=requests, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - input_lengths=input_lengths, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - ) - - def __len__(self): - return len(self.requests) - - -class FlashNeoX(Model): +class FlashNeoX(FlashCausalLM): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - else: - raise 
NotImplementedError("FlashNeoX is only available on GPU") - - if quantize: - raise NotImplementedError("FlashNeoX does not support quantization") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, revision=revision, padding_side="left" - ) - self.model = ( - FlashGPTNeoXForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - ) - .eval() - .cuda() - ) - tokenizer.pad_token_id = ( - self.model.config.pad_token_id - if self.model.config.pad_token_id is not None - else self.model.config.eos_token_id - ) - super(FlashNeoX, self).__init__( - tokenizer=tokenizer, - device=device, + FlashGPTNeoXForCausalLM, model_id, revision, quantize ) - @property - def batch_type(self) -> Type[FlashNeoXBatch]: - return FlashNeoXBatch - - def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False - ) - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlens: torch.Tensor, - max_s: int, - past_key_values: Optional = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Model Forward - return self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_s=max_s, - past_key_values=past_key_values, - ) - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batch: FlashNeoXBatch - ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]: - # Better to send to device here to avoid device issues in concatenate - position_ids = batch.position_ids.to(self.device, non_blocking=True) - cu_seqlens = batch.cu_seqlens.to(self.device) - - out, present = self.forward( - batch.input_ids, - position_ids, - cu_seqlens, - batch.max_seqlen, - batch.past_key_values, - ) - - # List of indices to cache - next_batch_keep_indices = [] - - # New values for next forward - next_batch_input_ids = [] - next_batch_position_ids = [] - next_batch_cu_seqlens = [0] - next_batch_max_seqlen = 0 - next_batch_past_key_values = [] - next_batch_input_lengths = [] - next_batch_all_input_ids = [] - next_batch_all_input_ids_tensor = [] - - # Cumulative length - cumulative_length = 0 - - # Results - generations: List[Generation] = [] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - batch.all_input_ids_tensor, - ) - - # For each member of the batch - for i, ( - request, - input_length, - next_token_chooser, - stopping_criteria, - all_input_ids, - all_input_ids_tensor, - ) in enumerate(iterator): - # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - if batch.past_key_values is None: - # Prefill mode - # out is of shape [cumulative_sequence_lengths, vocab_size] - logits = out[start_index:end_index] - else: - # Decode mode - # out is of shape [batch_size, vocab_size] - logits = out[i].unsqueeze(0) - - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids_tensor[None, :input_length], logits - ) - next_token_id_squeezed = next_token_id.squeeze() - next_token_id_item = next_token_id_squeezed.item() - - # Append next token to all tokens - all_input_ids.append(next_token_id_item) - all_input_ids_tensor[input_length] = next_token_id_item - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id_item] - next_token_text = self.decode_token( - next_token_id_item, - ) - - 
# Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_item, - next_token_text, - ) - - if stop: - # Decode generated tokens - output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :] - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - # Keep request in the batch - next_batch_keep_indices.append(i) - generated_text = None - - # Get sequence present - seq_present = present[:, start_index:end_index] - # Pad it for next iter attention - past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) - next_batch_past_key_values.append(past) - - next_batch_input_ids.append(next_token_id) - next_batch_position_ids.append(input_length) - # Cumulative sum - next_batch_cu_seqlens.append( - next_batch_cu_seqlens[-1] + new_input_length - ) - next_batch_input_lengths.append(new_input_length) - next_batch_all_input_ids.append(all_input_ids) - next_batch_all_input_ids_tensor.append(all_input_ids_tensor) - next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) - - # Prefill - if stopping_criteria.current_tokens == 1: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + logprobs.gather( - 1, all_input_ids_tensor[1:input_length].unsqueeze(1) - ).squeeze(1)[:-1].tolist() - prefill_token_ids = all_input_ids[:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = PrefillTokens( - prefill_token_ids, prefill_logprobs, prefill_texts - ) - else: - prefill_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - next_token_id_item, - next_token_logprob, - next_token_text, - next_token_id_item in self.all_special_ids, - generated_text, - ) - - generations.append(generation) - cumulative_length += input_length - - # We finished all generations in the batch; there is no next batch - if not next_batch_keep_indices: - return generations, None - - # If we finished at least one generation, we need to evict the indices of the generations that finished - # from the values of the next batch - if len(next_batch_keep_indices) != len(batch): - # Apply indices to requests, token_choosers and stopping_criterias that need to be cached - next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] - next_batch_next_token_choosers = [ - batch.next_token_choosers[i] for i in next_batch_keep_indices - ] - next_batch_stopping_criterias = [ - batch.stopping_criterias[i] for i in next_batch_keep_indices - ] - else: - next_batch_requests = batch.requests - next_batch_next_token_choosers = batch.next_token_choosers - next_batch_stopping_criterias = batch.stopping_criterias - - # Create final next batch tensors - next_batch_position_ids = torch.tensor( - next_batch_position_ids, dtype=torch.int32 - ) - next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32) - if len(next_batch_keep_indices) > 1: - next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1) - next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) - else: - next_batch_input_ids = next_batch_input_ids[0].view(1) - next_batch_past_key_values = next_batch_past_key_values[0] - - next_batch = FlashNeoXBatch( - batch_id=batch.batch_id, - 
requests=next_batch_requests,
-            input_ids=next_batch_input_ids,
-            position_ids=next_batch_position_ids,
-            cu_seqlens=next_batch_cu_seqlens,
-            max_seqlen=next_batch_max_seqlen,
-            past_key_values=next_batch_past_key_values,
-            input_lengths=next_batch_input_lengths,
-            all_input_ids=next_batch_all_input_ids,
-            all_input_ids_tensor=next_batch_all_input_ids_tensor,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-        )
-        return generations, next_batch
-
 
 class FlashNeoXSharded(FlashNeoX):
     def __init__(
@@ -508,7 +70,7 @@ class FlashNeoXSharded(FlashNeoX):
         model.post_load_weights()
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
-        super(FlashNeoX, self).__init__(
+        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index fe15cde0..b5190b6d 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -6,12 +6,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from text_generation_server.models import CausalLM
 
-FIM_PREFIX = "<fim-prefix>"
-FIM_MIDDLE = "<fim-middle>"
-FIM_SUFFIX = "<fim-suffix>"
-FIM_PAD = "<fim-pad>"
-EOD = "<|endoftext|>"
-
 
 class SantaCoder(CausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
@@ -28,18 +22,6 @@ class SantaCoder(CausalLM):
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left"
         )
-        tokenizer.add_special_tokens(
-            {
-                "additional_special_tokens": [
-                    EOD,
-                    FIM_PREFIX,
-                    FIM_MIDDLE,
-                    FIM_SUFFIX,
-                    FIM_PAD,
-                ],
-                "pad_token": EOD,
-            }
-        )
         self.model = (
             AutoModelForCausalLM.from_pretrained(
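
Note appended after the patch: the new FlashCausalLMBatch packs every request of a batch into a single flattened token tensor and tracks request boundaries with cumulative sequence lengths (cu_seqlens), mirroring what FlashCausalLMBatch.from_pb does in the diff above. The following is a minimal standalone sketch of that packing with made-up token ids and a toy two-request batch; the variable names and values are illustrative only, not part of the patch.

import torch

# Two toy requests with different prompt lengths (token ids are made up)
request_token_ids = [[101, 17, 42], [101, 9]]

input_ids = []
position_ids = []
cu_seqlens = [0]  # one boundary per request, starting at 0
max_seqlen = 0
cumulative_length = 0

for tokens in request_token_ids:
    length = len(tokens)
    max_seqlen = max(max_seqlen, length)
    input_ids.append(torch.tensor(tokens, dtype=torch.int64))
    # positions restart at 0 for every request
    position_ids.append(torch.arange(0, length, dtype=torch.int32))
    cumulative_length += length
    cu_seqlens.append(cumulative_length)

input_ids = torch.concat(input_ids)        # tensor([101, 17, 42, 101, 9])
position_ids = torch.concat(position_ids)  # tensor([0, 1, 2, 0, 1])
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)  # tensor([0, 3, 5])

# During decode, each request contributes exactly one new token, so the query
# boundaries become [0, 1, 2, ...] and the last position of request i in the
# packed key/value cache is cu_seqlens[i + 1] - 1 (layer_past_present_indices).
print(input_ids, position_ids, cu_seqlens, max_seqlen)

Because the prompts are concatenated back to back instead of padded to a common length, no attention mask or padding tokens are needed; the flash attention kernels take cu_seqlens and max_seqlen directly, which is what the FlashCausalLM.forward signature in the diff passes through.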