swap load

This commit is contained in:
OlivierDehaene 2024-04-12 16:11:47 +02:00
parent 308d7bcb3d
commit 68717f8716

View File

@ -479,7 +479,6 @@ class MedusaHeadV1(nn.Module):
from safetensors import safe_open
import json
lm_head = TensorParallelHead.load(config, prefix, weights)
use_medusa = config.use_medusa
medusa_config = str(Path(use_medusa) / "config.json")
@ -497,6 +496,7 @@ class MedusaHeadV1(nn.Module):
routing[k] = filename
medusa = MedusaModel(config, medusa_config, weights)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MedusaHeadV1(lm_head, medusa)
def forward(
@ -603,9 +603,9 @@ class SpeculativeHead(nn.Module):
if use_medusa:
lm_head = None
try:
medusa = MedusaHeadV2(config, prefix, weights)
except:
medusa = MedusaHeadV1.load(config, prefix, weights)
except:
medusa = MedusaHeadV2(config, prefix, weights)
else:
lm_head = TensorParallelHead.load(config, prefix, weights)
medusa = None