mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Backport changes in medusa.
This commit is contained in:
parent
de11fc064a
commit
027e1dabcd
@ -71,19 +71,22 @@ class MedusaHeadV1(nn.Module):
|
||||
|
||||
speculator = config.speculator
|
||||
|
||||
medusa_config = str(Path(speculator) / "config.json")
|
||||
filename = str(Path(speculator) / "medusa_lm_head.safetensors")
|
||||
path = speculator["path"]
|
||||
medusa_config = str(Path(path) / "config.json")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
medusa_config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing and routing[k] != filename:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
for fname in speculator["model_paths"]:
|
||||
filename = str(Path(path) / fname)
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
medusa_config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing and routing[k] != filename:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
|
||||
medusa = MedusaModel(config, medusa_config, weights)
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
|
Loading…
Reference in New Issue
Block a user