Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 12:24:53 +00:00)
Commit message: swap load
This commit is contained in:
parent 308d7bcb3d
commit 68717f8716
@@ -479,7 +479,6 @@ class MedusaHeadV1(nn.Module):
 from safetensors import safe_open
 import json
 
-lm_head = TensorParallelHead.load(config, prefix, weights)
 use_medusa = config.use_medusa
 
 medusa_config = str(Path(use_medusa) / "config.json")
@@ -497,6 +496,7 @@ class MedusaHeadV1(nn.Module):
 routing[k] = filename
 
 medusa = MedusaModel(config, medusa_config, weights)
+lm_head = TensorParallelHead.load(config, prefix, weights)
 return MedusaHeadV1(lm_head, medusa)
 
 def forward(
@@ -603,9 +603,9 @@ class SpeculativeHead(nn.Module):
 if use_medusa:
 lm_head = None
 try:
-medusa = MedusaHeadV2(config, prefix, weights)
-except:
 medusa = MedusaHeadV1.load(config, prefix, weights)
+except:
+medusa = MedusaHeadV2(config, prefix, weights)
 else:
 lm_head = TensorParallelHead.load(config, prefix, weights)
 medusa = None
 
Loading…
Reference in New Issue
Block a user