mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Backport changes in medusa.
This commit is contained in:
parent
de11fc064a
commit
027e1dabcd
@ -71,19 +71,22 @@ class MedusaHeadV1(nn.Module):
|
|||||||
|
|
||||||
speculator = config.speculator
|
speculator = config.speculator
|
||||||
|
|
||||||
medusa_config = str(Path(speculator) / "config.json")
|
path = speculator["path"]
|
||||||
filename = str(Path(speculator) / "medusa_lm_head.safetensors")
|
medusa_config = str(Path(path) / "config.json")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
for fname in speculator["model_paths"]:
|
||||||
medusa_config = json.load(f)
|
filename = str(Path(path) / fname)
|
||||||
routing = weights.routing
|
|
||||||
with safe_open(filename, framework="pytorch") as f:
|
with open(medusa_config, "r") as f:
|
||||||
for k in f.keys():
|
medusa_config = json.load(f)
|
||||||
if k in routing and routing[k] != filename:
|
routing = weights.routing
|
||||||
raise RuntimeError(
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
for k in f.keys():
|
||||||
)
|
if k in routing and routing[k] != filename:
|
||||||
routing[k] = filename
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
medusa = MedusaModel(config, medusa_config, weights)
|
medusa = MedusaModel(config, medusa_config, weights)
|
||||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
Loading…
Reference in New Issue
Block a user