diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index ccba1728..4cc2245d 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -479,7 +479,6 @@ class MedusaHeadV1(nn.Module): from safetensors import safe_open import json - lm_head = TensorParallelHead.load(config, prefix, weights) use_medusa = config.use_medusa medusa_config = str(Path(use_medusa) / "config.json") @@ -497,6 +496,7 @@ class MedusaHeadV1(nn.Module): routing[k] = filename medusa = MedusaModel(config, medusa_config, weights) + lm_head = TensorParallelHead.load(config, prefix, weights) return MedusaHeadV1(lm_head, medusa) def forward( @@ -603,9 +603,9 @@ class SpeculativeHead(nn.Module): if use_medusa: lm_head = None try: - medusa = MedusaHeadV2(config, prefix, weights) - except: medusa = MedusaHeadV1.load(config, prefix, weights) + except: + medusa = MedusaHeadV2(config, prefix, weights) else: lm_head = TensorParallelHead.load(config, prefix, weights) medusa = None