Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
swap load
This commit is contained in:
parent 308d7bcb3d
commit 68717f8716
@@ -479,7 +479,6 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json
 
-        lm_head = TensorParallelHead.load(config, prefix, weights)
         use_medusa = config.use_medusa
 
         medusa_config = str(Path(use_medusa) / "config.json")
@@ -497,6 +496,7 @@ class MedusaHeadV1(nn.Module):
                 routing[k] = filename
 
         medusa = MedusaModel(config, medusa_config, weights)
+        lm_head = TensorParallelHead.load(config, prefix, weights)
         return MedusaHeadV1(lm_head, medusa)
 
     def forward(
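Taken together, the two hunks above move the TensorParallelHead load from the top of MedusaHeadV1.load to just before the return, so the Medusa weights are wired up first. A minimal sketch of the resulting tail of the method; the decorator, signature, and elided body are assumed from the surrounding context rather than shown in this diff:

    @staticmethod
    def load(config, prefix: str, weights):
        # ... medusa config parsing and safetensors routing elided
        # (see the context lines in the hunks above) ...
        medusa = MedusaModel(config, medusa_config, weights)
        # The lm_head load now happens after the Medusa model is built.
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)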
@@ -603,9 +603,9 @@ class SpeculativeHead(nn.Module):
         if use_medusa:
             lm_head = None
             try:
-                medusa = MedusaHeadV2(config, prefix, weights)
-            except:
                 medusa = MedusaHeadV1.load(config, prefix, weights)
+            except:
+                medusa = MedusaHeadV2(config, prefix, weights)
         else:
             lm_head = TensorParallelHead.load(config, prefix, weights)
             medusa = None
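This hunk swaps the order of the two Medusa loaders: MedusaHeadV1.load is now tried first, and MedusaHeadV2 becomes the fallback in the bare except. A sketch of the resulting SpeculativeHead.load, assuming the decorator, signature, use_medusa lookup, and final return that are not part of this diff:

    @staticmethod
    def load(config, prefix: str, weights):
        use_medusa = config.use_medusa  # assumed; not shown in this hunk
        if use_medusa:
            lm_head = None
            try:
                # The V1 loader is attempted first after this commit ...
                medusa = MedusaHeadV1.load(config, prefix, weights)
            except:
                # ... with V2 as the fallback when V1 raises.
                medusa = MedusaHeadV2(config, prefix, weights)
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            medusa = None
        return SpeculativeHead(lm_head, medusa)  # assumed; not shown in this hunk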