mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
add medusa
This commit is contained in:
parent
a56bd736e6
commit
c84223590b
@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
|
|||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
TensorParallelHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
@ -486,13 +486,13 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
|||||||
|
|
||||||
self.model = Starcoder2Model(config, weights)
|
self.model = Starcoder2Model(config, weights)
|
||||||
try:
|
try:
|
||||||
self.lm_head = TensorParallelHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="lm_head",
|
prefix="lm_head",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.lm_head = TensorParallelHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="model.embed_tokens",
|
prefix="model.embed_tokens",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
@ -29,6 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -51,6 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
config.use_medusa = use_medusa
|
||||||
|
|
||||||
# Set context windows
|
# Set context windows
|
||||||
if config.sliding_window is not None:
|
if config.sliding_window is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user