add medusa

This commit is contained in:
OlivierDehaene 2024-02-28 11:02:39 +01:00
parent a56bd736e6
commit c84223590b
2 changed files with 5 additions and 3 deletions

View File

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
FastLayerNorm, FastLayerNorm,
@@ -486,13 +486,13 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
self.model = Starcoder2Model(config, weights) self.model = Starcoder2Model(config, weights)
try: try:
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
) )
except RuntimeError: except RuntimeError:
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens", prefix="model.embed_tokens",
weights=weights, weights=weights,

View File

@@ -29,6 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@@ -51,6 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
# Set context windows # Set context windows
if config.sliding_window is not None: if config.sliding_window is not None: