Backport changes in medusa.

This commit is contained in:
Nicolas Patry 2024-05-13 13:18:29 +00:00
parent de11fc064a
commit 027e1dabcd

View File

@@ -71,19 +71,22 @@ class MedusaHeadV1(nn.Module):
speculator = config.speculator
medusa_config = str(Path(speculator) / "config.json")
filename = str(Path(speculator) / "medusa_lm_head.safetensors")
path = speculator["path"]
medusa_config = str(Path(path) / "config.json")
with open(medusa_config, "r") as f:
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
for fname in speculator["model_paths"]:
filename = str(Path(path) / fname)
with open(medusa_config, "r") as f:
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
medusa = MedusaModel(config, medusa_config, weights)
lm_head = TensorParallelHead.load(config, prefix, weights)