From c84223590b0602f3c1b567764e81a577813f3f50 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 28 Feb 2024 11:02:39 +0100 Subject: [PATCH] add medusa --- .../models/custom_modeling/flash_starcoder2_modeling.py | 6 +++--- server/text_generation_server/models/flash_starcoder2.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 5ee09dd0..ed77af78 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, FastLayerNorm, @@ -486,13 +486,13 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): self.model = Starcoder2Model(config, weights) try: - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, ) except RuntimeError: - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="model.embed_tokens", weights=weights, diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 9b4ece11..2f6ae757 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -29,6 +29,7 @@ class FlashStarcoder2(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,6 +52,7 @@ class FlashStarcoder2(BaseFlashMistral): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa # Set context windows if config.sliding_window is not None: