From 64d38afa9f8b9f944f783a5e6e0b1cce75e8573b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 22 Feb 2024 13:01:43 +0000
Subject: [PATCH] Black.

---
 .../models/causal_lm.py                        |  4 ++-
 .../models/custom_modeling/bloom_modeling.py   | 17 +++++++------
 .../custom_modeling/flash_rw_modeling.py       |  4 +--
 .../custom_modeling/idefics_modeling.py        |  4 +--
 .../models/custom_modeling/mamba_modeling.py   |  4 +--
 .../models/custom_modeling/t5_modeling.py      | 25 +++++++++++--------
 .../models/flash_causal_lm.py                  | 11 +++++---
 .../models/flash_mistral.py                    | 11 +++++---
 server/text_generation_server/models/mamba.py  | 10 ++++++--
 server/text_generation_server/utils/layers.py  |  8 +++++-
 10 files changed, 61 insertions(+), 37 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index d4666229..bbcef210 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -551,7 +551,9 @@ class CausalLM(Model):
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[
+        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
+    ]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
index 9381d164..10b40483 100644
--- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -911,10 +911,13 @@ class BloomForCausalLM(BloomPreTrainedModel):
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
 
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        ), speculative_logits
+        return (
+            CausalLMOutputWithCrossAttentions(
+                loss=loss,
+                logits=logits,
+                past_key_values=transformer_outputs.past_key_values,
+                hidden_states=transformer_outputs.hidden_states,
+                attentions=transformer_outputs.attentions,
+            ),
+            speculative_logits,
+        )
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 0d8e74b1..a9127d1f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 
         self.transformer = FlashRWModel(config, weights)
 
-        self.lm_head = SpeculativeHead.load(
-            config, prefix="lm_head", weights=weights
-        )
+        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index dcdc213e..7d7bf23d 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = SpeculativeHead.load(
-            config=config, prefix="lm_head", weights=weights
-        )
+        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
         self.additional_fc = FastLinear.load(
             config=config,
             prefix="lm_head.additional_fc",
diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index 08bbb3a2..c58a617f 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -206,9 +206,7 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = SpeculativeHead.load(
-            config, f"{prefix}.embedding", weights
-        )
+        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config
 
     def forward(
diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index be956008..2773fb15 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -1140,17 +1140,20 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output
 
-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        ), speculative_logits
+        return (
+            Seq2SeqLMOutput(
+                loss=loss,
+                logits=logits,
+                past_key_values=decoder_outputs.past_key_values,
+                decoder_hidden_states=decoder_outputs.hidden_states,
+                decoder_attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+                encoder_hidden_states=encoder_outputs.hidden_states,
+                encoder_attentions=encoder_outputs.attentions,
+            ),
+            speculative_logits,
+        )
 
     def prepare_inputs_for_generation(
         self,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3d3520c7..1276fefa 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -807,7 +807,9 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -903,7 +905,11 @@ class FlashCausalLM(Model):
         # Replay the graph
         cuda_graph["graph"].replay()
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 
@@ -961,7 +967,6 @@ class FlashCausalLM(Model):
             batch.speculative_ids,
             speculative_logits,
         )
-        logger.info(f"Accepted ids {accepted_ids}")
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 2d67ec4d..d3c0da9c 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -412,7 +412,9 @@ class BaseFlashMistral(FlashCausalLM):
             self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()
 
-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashMistralBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -515,12 +517,15 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()
 
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 
 
-
 class FlashMistral(BaseFlashMistral):
     def __init__(
         self,
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index c60d5ab0..2500d454 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -560,7 +560,11 @@ class Mamba(Model):
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 
@@ -593,7 +597,9 @@ class Mamba(Model):
             batch.inference_params = inference_params
 
         # Forward pass
-        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(
+            input_ids, inference_params=batch.inference_params
+        )
 
         # batch.inference_params = new_inference_params
         # Results
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index c707b06d..d923ebfc 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -379,6 +379,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)
 
+
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -426,6 +427,7 @@ class MedusaHead(torch.nn.Module):
         x = self.out(x)
         return x
 
+
 class SpeculativeHead(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
@@ -440,6 +442,7 @@ class SpeculativeHead(nn.Module):
             from pathlib import Path
             from safetensors import safe_open
             import json
+
             medusa_config = str(Path(use_medusa) / "config.json")
             medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
 
@@ -460,11 +463,14 @@ class SpeculativeHead(nn.Module):
             medusa = None
         return SpeculativeHead(lm_head, medusa)
 
-    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         speculative_logits = self.medusa(input) if self.medusa is not None else None
         return logits, speculative_logits
 
+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)