diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py index b7f2aaf6..2e9a010f 100644 --- a/server/text_generation_server/layers/medusa.py +++ b/server/text_generation_server/layers/medusa.py @@ -71,19 +71,22 @@ class MedusaHeadV1(nn.Module): speculator = config.speculator - medusa_config = str(Path(speculator) / "config.json") - filename = str(Path(speculator) / "medusa_lm_head.safetensors") + path = speculator["path"] + medusa_config = str(Path(path) / "config.json") - with open(medusa_config, "r") as f: - medusa_config = json.load(f) - routing = weights.routing - with safe_open(filename, framework="pytorch") as f: - for k in f.keys(): - if k in routing and routing[k] != filename: - raise RuntimeError( - f"Key {k} was found in multiple files: {filename} and {routing[k]}" - ) - routing[k] = filename + for fname in speculator["model_paths"]: + filename = str(Path(path) / fname) + + with open(medusa_config, "r") as f: + medusa_config = json.load(f) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename medusa = MedusaModel(config, medusa_config, weights) lm_head = TensorParallelHead.load(config, prefix, weights)