Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Fix gemma + medusa.
parent a0095b5b8d
commit ed95f1982d
@@ -60,36 +60,6 @@ class FlashGemma(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashGemmaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)
-
         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
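For context, the removed branch resolved the Medusa artifacts either from a local directory or from the Hugging Face Hub. A minimal standalone sketch of that resolution logic follows; the helper name resolve_medusa_head is hypothetical, not part of TGI:

import json
import os
from pathlib import Path

from huggingface_hub import hf_hub_download


def resolve_medusa_head(use_medusa: str, revision=None):
    # Hypothetical helper mirroring the removed branch: treat `use_medusa`
    # as a local directory when it exists (or when WEIGHTS_CACHE_OVERRIDE
    # is set); otherwise fetch the files from the Hugging Face Hub.
    is_local_model = (
        Path(use_medusa).exists() and Path(use_medusa).is_dir()
    ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

    if not is_local_model:
        medusa_config = hf_hub_download(
            use_medusa, revision=revision, filename="config.json"
        )
        medusa_head = hf_hub_download(
            use_medusa, revision=revision, filename="medusa_lm_head.pt"
        )
    else:
        medusa_config = str(Path(use_medusa) / "config.json")
        medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

    with open(medusa_config, "r") as f:
        config = json.load(f)
    # The loader reads a .safetensors file assumed to sit next to the
    # original medusa_lm_head.pt checkpoint, produced ahead of time.
    medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
    return config, medusa_sf

In the removed code, the returned safetensors path was then handed to Weights and the parsed config dict to MedusaModel.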
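The last two removed lines show the wrapping pattern: the base lm_head is captured and replaced by a Medusa module that also produces speculative logits. TGI's actual MedusaModel lives in text_generation_server.utils.medusa and is not reproduced here; the toy module below only illustrates the idea under the assumption of plain linear heads (real Medusa heads use residual MLP blocks, and the head count comes from the loaded config):

import torch
from torch import nn


class ToyMedusaHead(nn.Module):
    """Illustrative stand-in for MedusaModel: keep the original lm_head
    and add extra linear heads that each predict one future token."""

    def __init__(self, lm_head: nn.Module, hidden_size: int, vocab_size: int, n_heads: int = 4):
        super().__init__()
        self.lm_head = lm_head
        self.heads = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(n_heads)
        )

    def forward(self, hidden_states: torch.Tensor):
        # Normal next-token logits from the unmodified lm_head.
        logits = self.lm_head(hidden_states)
        # One extra logits tensor per Medusa head, stacked for speculation.
        speculative_logits = torch.stack([h(hidden_states) for h in self.heads], dim=1)
        return logits, speculative_logits

Swapping it in, model.lm_head = ToyMedusaHead(model.lm_head, hidden_size, vocab_size), mirrors the lm_head = model.lm_head; model.lm_head = MedusaModel(config, weights, lm_head) exchange in the removed code.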