From ac419f5e467b25eb9b532cfdc2d7cd23d2b28019 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 22 Feb 2024 11:37:05 +0000
Subject: [PATCH] Upgrade ALL the code.

---
 server/text_generation_server/models/__init__.py | 16 ++++------------
 server/text_generation_server/models/bloom.py | 2 ++
 .../text_generation_server/models/causal_lm.py | 1 +
 .../models/custom_modeling/bloom_modeling.py | 4 ++--
 .../custom_modeling/flash_gemma_modeling.py | 4 ++--
 .../custom_modeling/flash_llama_modeling.py | 4 ++--
 .../custom_modeling/flash_mixtral_modeling.py | 4 ++--
 .../custom_modeling/flash_neox_modeling.py | 4 ++--
 .../models/custom_modeling/flash_phi_modeling.py | 4 ++--
 .../models/custom_modeling/flash_rw_modeling.py | 4 ++--
 .../custom_modeling/flash_santacoder_modeling.py | 4 ++--
 .../models/custom_modeling/idefics_modeling.py | 4 ++--
 .../models/custom_modeling/mamba_modeling.py | 11 ++++++-----
 .../models/custom_modeling/mpt_modeling.py | 4 ++--
 .../models/custom_modeling/neox_modeling.py | 4 ++--
 .../models/custom_modeling/opt_modeling.py | 4 ++--
 .../models/custom_modeling/phi_modeling.py | 4 ++--
 .../models/custom_modeling/t5_modeling.py | 12 ++++++------
 .../models/flash_causal_lm.py | 11 +++++++----
 .../text_generation_server/models/flash_gemma.py | 3 ++-
 .../text_generation_server/models/flash_llama.py | 3 ++-
 .../models/flash_mistral.py | 5 ++---
 .../models/flash_mixtral.py | 2 ++
 .../text_generation_server/models/flash_neox.py | 2 ++
 .../text_generation_server/models/flash_phi.py | 3 ++-
 server/text_generation_server/models/flash_rw.py | 2 ++
 .../models/flash_santacoder.py | 2 ++
 server/text_generation_server/models/idefics.py | 2 ++
 server/text_generation_server/models/mamba.py | 12 ++++++++----
 server/text_generation_server/models/mpt.py | 2 ++
 server/text_generation_server/models/opt.py | 2 ++
 server/text_generation_server/models/phi.py | 2 ++
 .../text_generation_server/models/santacoder.py | 1 +
 .../text_generation_server/models/seq2seq_lm.py | 7 +++++--
 server/text_generation_server/models/t5.py | 5 ++++-
 35 files changed, 94 insertions(+), 66 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fcdc25f1..dedbb7e2 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -115,16 +115,6 @@ def get_model(
     else:
         set_speculate(0)
 
-    if "facebook/galactica" in model_id:
-        return GalacticaSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
@@ -177,7 +167,7 @@
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type == "gpt_bigcode":
+    if model_type in {"gpt_bigcode", "gpt2"}:
         if FLASH_ATTENTION:
             return FlashSantacoderSharded(
                 model_id,
@@ -311,9 +301,9 @@
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(
@@ -324,6 +314,7 @@
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -448,6 +439,7 @@
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
diff --git a/server/text_generation_server/models/bloom.py
b/server/text_generation_server/models/bloom.py index fed5e6f3..590c0d57 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM): ) config.pad_token_id = 3 config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index a0f0c9e8..b7f5e9db 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -482,6 +482,7 @@ class CausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 5423d75a..8e3be63f 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -36,7 +36,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) CUSTOM_KERNELS_ENABLED = False @@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel): super().__init__(config) self.transformer = BloomModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="word_embeddings", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4a08bc2a..d7bedf72 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -37,7 +37,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): super().__init__() self.model = FlashGemmaModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 1626eb4d..88b3d9d2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): 
super().__init__() self.model = FlashLlamaModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 3d3caba3..17d4f708 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,7 +37,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, ) @@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): super().__init__() self.model = MixtralModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 780861c2..ee062d3d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -33,7 +33,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLayerNorm, PositionRotaryEmbedding, get_linear, @@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): super().__init__(config) self.gpt_neox = FlashGPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( + self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9a929e9..cfe447a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -12,7 +12,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastLayerNorm, ) @@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module): super().__init__() self.model = FlashPhiModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 6a530f3c..0d8e74b1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,7 +12,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLayerNorm, PositionRotaryEmbedding, get_linear, @@ -613,7 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): self.transformer = FlashRWModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", 
weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index d3fe95d0..bbb603a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, - TensorParallelHead, + SpeculativeHead, TensorParallelEmbedding, FastLayerNorm, get_linear, @@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module): def __init__(self, config, weights): super().__init__() self.transformer = FlashSantacoderModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 4f7dfb95..dcdc213e 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -51,7 +51,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, PositionRotaryEmbedding, FastLinear, ) @@ -272,7 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module): weights, ) -> None: super().__init__() - self.fc = TensorParallelHead.load( + self.fc = SpeculativeHead.load( config=config, prefix="lm_head", weights=weights ) self.additional_fc = FastLinear.load( diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index baf1fb85..08bbb3a2 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig import torch.nn.functional as F from text_generation_server.utils.layers import ( + SpeculativeHead, TensorParallelEmbedding, FastRMSNorm, FastLinear, @@ -205,14 +206,14 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = FastLinear.load( - config, f"{prefix}.embedding", weights, bias=False + self.lm_head = SpeculativeHead.load( + config, f"{prefix}.embedding", weights ) self.config = config def forward( self, input_ids: torch.Tensor, inference_params=None, residual=None - ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.embed_tokens(input_ids) for i, block in enumerate(self.blocks): hidden_states, residual, conv_state, ssm_state = block( @@ -226,8 +227,8 @@ class MambaModel(nn.Module): ) hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) hidden_states = hidden_states.view(residual.shape) - logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) # update the offset for the next inference using these params inference_params.seqlen_offset += input_ids.size(1) - return logits + return logits, speculative_logits diff --git 
a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 2e2e423e..0e755e47 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -21,7 +21,7 @@ from text_generation_server.utils.layers import ( TensorParallelEmbedding, TensorParallelColumnLinear, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, get_linear, ) @@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel): if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = MPTModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) self.logit_scale = None diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index dbcefbae..2550d2d1 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -44,7 +44,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) @@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): def __init__(self, config, weights): super().__init__(config) self.gpt_neox = GPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( + self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index ce3f5e21..de5e95af 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) EPS = 1e-5 @@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel): self.model = OPTModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="model.decoder.embed_tokens", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index e5c09728..1571f9fd 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -13,7 +13,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLinear, ) @@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module): weights=weights, eps=config.layer_norm_epsilon, ) - self.linear = TensorParallelHead.load( + self.linear = SpeculativeHead.load( config=config, prefix="lm_head.linear", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index d3e4f53a..be956008 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -42,7 +42,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) @@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ) try: - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights ) except RuntimeError: # Some models like t5-small were saved with shared weights unlike flan # Since they are declared as the same arch we have no choice but hope # that this is OK instead of using a proper flag. - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="shared", weights=weights ) @@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) - lm_logits = self.lm_head(sequence_output) + logits, speculative_logits = self.lm_head(sequence_output) loss = None if labels is not None: @@ -1142,7 +1142,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): return Seq2SeqLMOutput( loss=loss, - logits=lm_logits, + logits=logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, @@ -1150,7 +1150,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, - ) + ), speculative_logits def prepare_inputs_for_generation( self, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a63c6641..3d3520c7 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -723,7 +723,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - self.cuda_graphs[bs]["logits"] = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, @@ -734,6 +734,8 @@ class FlashCausalLM(Model): max_s=max_s, lm_head_indices=None, ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def warmup(self, batch: FlashCausalLMBatch): @@ -805,7 +807,7 @@ class FlashCausalLM(Model): return int(num_blocks * BLOCK_SIZE) - def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor: + def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: input_ids = batch.input_ids @@ -900,9 +902,10 @@ class FlashCausalLM(Model): # Replay the graph cuda_graph["graph"].replay() - # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits @tracer.start_as_current_span("generate_token") def generate_token( diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 220b3992..beb12371 
100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 94bd58f4..8c2c1086 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 69d7493a..2d67ec4d 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -409,8 +409,7 @@ class BaseFlashMistral(FlashCausalLM): lm_head_indices=None, ) self.cuda_graphs[bs]["logits"] = logits - if speculative_logits is not None: - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -516,7 +515,7 @@ class BaseFlashMistral(FlashCausalLM): cuda_graph["graph"].replay() # Slice output to the correct shape - speculative_logits = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None + speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None logits = cuda_graph["logits"][:bs] return logits, speculative_logits diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py index 6f77a658..2ee35e82 100644 --- a/server/text_generation_server/models/flash_mixtral.py +++ b/server/text_generation_server/models/flash_mixtral.py @@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral): model_id=model_id, revision=revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 
80f8804d..5a351bd7 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 061b9740..cb55f9e6 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index dfab8888..fc1e26bd 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM): ) config.quantize = quantize + config.use_medusa = use_medusa if config.quantize == "gptq": weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 22171ec0..034949f9 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM): trust_remote_code=True, ) config.quantize = quantize + config.use_medusa = use_medusa config.transpose = config.architectures[0].startswith("GPT2") torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index fa23d1f9..baa1945b 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ 
-51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa config.vision_config.quantize = quantize tokenizer = LlamaTokenizerFast.from_pretrained( diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 9d59f424..c60d5ab0 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -408,6 +408,7 @@ class Mamba(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -444,6 +445,7 @@ class Mamba(Model): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) @@ -505,7 +507,7 @@ class Mamba(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - logits = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, inference_params=inference_params ) torch.cuda.synchronize() @@ -514,6 +516,7 @@ class Mamba(Model): "inference_params": inference_params, "graph": graph, "logits": logits, + "speculative_logits": speculative_logits, } self.cuda_graphs[batch_size] = graph_dict @@ -556,9 +559,10 @@ class Mamba(Model): inference_params.ssm_states.copy_( cuda_graph["inference_params"].ssm_states[:, :bs] ) - # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: start = time.time_ns() @@ -589,7 +593,7 @@ class Mamba(Model): batch.inference_params = inference_params # Forward pass - logits = self.forward(input_ids, inference_params=batch.inference_params) + logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params) # batch.inference_params = new_inference_params # Results diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index e419467f..6b3f29a6 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -43,6 +43,7 @@ class MPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -75,6 +76,7 @@ class MPTSharded(CausalLM): config = json.load(f) config = PretrainedConfig(**config) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 58fb212f..703e5b58 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -22,6 +22,7 @@ class OPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -47,6 +48,7 @@ class 
OPTSharded(CausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py index 79aa3fb9..cc4e2505 100644 --- a/server/text_generation_server/models/phi.py +++ b/server/text_generation_server/models/phi.py @@ -22,6 +22,7 @@ class Phi(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -52,6 +53,7 @@ class Phi(CausalLM): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index 7b269d8e..73c21cce 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -19,6 +19,7 @@ class SantaCoder(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 777a55ba..cae5525c 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -532,6 +532,7 @@ class Seq2SeqLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -596,11 +597,12 @@ class Seq2SeqLM(Model): past_key_values: Optional = None, ) -> Tuple[ torch.Tensor, + Optional[torch.Tensor], torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], ]: # Model Forward - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -611,6 +613,7 @@ class Seq2SeqLM(Model): ) return ( outputs.logits, + speculative_logits, outputs.encoder_last_hidden_state, outputs.past_key_values, ) @@ -635,7 +638,7 @@ class Seq2SeqLM(Model): else: encoder_last_hidden_state = None - logits, encoder_last_hidden_state, past = self.forward( + logits, speculative_logits, encoder_last_hidden_state, past = self.forward( batch.input_ids, batch.attention_mask, batch.decoder_input_ids, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 161e69ba..3f3cb965 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa tokenizer = AutoTokenizer.from_pretrained( model_id, @@ -94,7 +96,7 @@ class 
T5Sharded(Seq2SeqLM): List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], ]: # Model Forward - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM): return ( outputs.logits, + speculative_logits, outputs.encoder_last_hidden_state, outputs.past_key_values, )
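
Editor's note: the recurring change across the custom_modeling files above is that every lm_head is now loaded through SpeculativeHead.load(config, prefix=..., weights=...) instead of TensorParallelHead.load(...), and every call site unpacks a pair instead of a single tensor (see the mamba_modeling.py and t5_modeling.py hunks). The code below is not the TGI implementation, only a minimal sketch of the contract those call sites rely on: a head that returns (logits, speculative_logits), with speculative_logits left as None when no Medusa speculator is configured. The class name and the `speculator` argument are hypothetical.

from typing import Optional, Tuple

import torch
import torch.nn as nn


class SpeculativeHeadSketch(nn.Module):
    # Hypothetical stand-in: a base LM head plus an optional speculator
    # (e.g. Medusa heads) producing extra logits for speculative decoding.
    def __init__(self, base_head: nn.Module, speculator: Optional[nn.Module] = None):
        super().__init__()
        self.base_head = base_head
        self.speculator = speculator

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.base_head(hidden_states)
        # None when no speculator is attached; callers must handle both cases.
        speculative_logits = (
            self.speculator(hidden_states) if self.speculator is not None else None
        )
        return logits, speculative_logits

Call sites then read `logits, speculative_logits = self.lm_head(hidden_states)`, exactly as the Mamba and T5 hunks above do.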
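
Editor's note: flash_causal_lm.py, flash_mistral.py and mamba.py all apply the same CUDA-graph bookkeeping change: capture stores both outputs in the per-batch-size cache (storing the key even when speculative_logits is None), and replay slices both back to the live batch size with an explicit None guard. The helpers below only illustrate that pattern with placeholder names; they are not the actual FlashCausalLM code.

import torch


def capture(model, static_inputs: dict, cuda_graphs: dict, bs: int, mem_pool=None):
    # Record one forward pass into a CUDA graph and cache its output tensors.
    graph = torch.cuda.CUDAGraph()
    torch.cuda.synchronize()
    with torch.cuda.graph(graph, pool=mem_pool):
        logits, speculative_logits = model.forward(**static_inputs)
    torch.cuda.synchronize()
    cuda_graphs[bs] = {
        "graph": graph,
        "logits": logits,
        # Stored unconditionally so replay can use a plain None check
        # instead of testing whether the key exists.
        "speculative_logits": speculative_logits,
    }


def replay(cuda_graphs: dict, bs: int):
    cuda_graph = cuda_graphs[bs]
    cuda_graph["graph"].replay()
    # Slice output to the correct shape; speculative logits may be absent.
    speculative_logits = (
        cuda_graph["speculative_logits"][:bs]
        if cuda_graph["speculative_logits"] is not None
        else None
    )
    return cuda_graph["logits"][:bs], speculative_logits

Storing the key unconditionally is what lets flash_mistral.py replace the earlier `"speculative_logits" in cuda_graph` test with the `is not None` check shown in its hunk above.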
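
Editor's note: every model constructor in the patch gains the same two lines of plumbing: a `use_medusa: Optional[str] = None` parameter (a Medusa speculator identifier, typically a repo id or local path) placed between `quantize` and `dtype`, and a `config.use_medusa = use_medusa` assignment made before the weights are loaded so that the head construction can see it; get_model now forwards the value as a keyword argument. A stripped-down sketch of that pattern, with a made-up class name and only the relevant lines kept:

from typing import Optional

import torch
from transformers import AutoConfig


class SomeShardedModel:
    # Illustrative only: the real classes (BLOOMSharded, FlashLlama, Mamba, ...)
    # keep all of their existing loading logic around these two additions.
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,  # new parameter, disabled by default
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa  # read later when the lm_head is built
        # ... tokenizer, weights and model construction continue as before ...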