add medusa

OlivierDehaene 2024-02-28 11:02:39 +01:00
parent a56bd736e6
commit c84223590b
2 changed files with 5 additions and 3 deletions


@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
     FastLayerNorm,
@@ -486,13 +486,13 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         self.model = Starcoder2Model(config, weights)
         try:
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config,
                 prefix="lm_head",
                 weights=weights,
             )
         except RuntimeError:
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config,
                 prefix="model.embed_tokens",
                 weights=weights,
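
The modeling change above swaps the plain tensor-parallel output projection for SpeculativeHead, so the model can optionally carry Medusa heads next to the base lm_head (falling back to the tied model.embed_tokens weights when no dedicated lm_head exists). As a rough illustration of the Medusa idea only, not TGI's SpeculativeHead implementation (the names MedusaResBlock, MedusaSpeculativeHead and n_heads are made up), each extra head is a small residual MLP over the final hidden state that predicts a token a few positions further ahead:

import torch
from torch import nn


class MedusaResBlock(nn.Module):
    # One residual block of a Medusa head: x + SiLU(W x).
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


class MedusaSpeculativeHead(nn.Module):
    # Base lm_head plus n_heads extra heads, each predicting one more
    # future token from the same final hidden state.
    def __init__(self, hidden_size: int, vocab_size: int, n_heads: int = 3):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.medusa_heads = nn.ModuleList(
            nn.Sequential(
                MedusaResBlock(hidden_size),
                nn.Linear(hidden_size, vocab_size, bias=False),
            )
            for _ in range(n_heads)
        )

    def forward(self, hidden: torch.Tensor):
        logits = self.lm_head(hidden)  # logits for position t+1
        # speculative logits for positions t+2 .. t+1+n_heads
        speculative_logits = torch.stack(
            [head(hidden) for head in self.medusa_heads], dim=1
        )
        return logits, speculative_logits

During decoding, the candidate tokens proposed by the extra heads are verified in a single forward pass of the base model, which is where the speedup of Medusa-style speculative decoding comes from.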


@@ -29,6 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         # Set context windows
         if config.sliding_window is not None:
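
At the model level the new use_medusa argument is simply threaded into the config, where the SpeculativeHead loader can pick it up. A hypothetical direct instantiation (in practice this class is constructed through TGI's model dispatch; the model id and the medusa path below are placeholders):

import torch

from text_generation_server.models.flash_starcoder2 import FlashStarcoder2

# Placeholder values: use_medusa would point at trained Medusa heads,
# or stay None to disable speculative decoding.
model = FlashStarcoder2(
    model_id="bigcode/starcoder2-15b",
    revision=None,
    quantize=None,
    use_medusa="path/to/medusa-heads",
    dtype=torch.bfloat16,
    trust_remote_code=False,
)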