mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Remove the old logic.
This commit is contained in:
parent
21b3072288
commit
7a9998d47c
@ -3,7 +3,9 @@ import torch
|
||||
from loguru import logger
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.auto import modeling_auto
|
||||
from huggingface_hub import hf_hub_download
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
from text_generation_server.models.model import Model
|
||||
@ -121,7 +123,7 @@ def get_model(
|
||||
|
||||
use_medusa = None
|
||||
if "medusa_num_heads" in config_dict:
|
||||
use_medusa = model_id
|
||||
medusa_model_id = model_id
|
||||
model_id = config_dict["base_model_name_or_path"]
|
||||
revision = "main"
|
||||
speculate_medusa = config_dict["medusa_num_heads"]
|
||||
@ -138,6 +140,14 @@ def get_model(
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
medusa_config = hf_hub_download(
|
||||
medusa_model_id, revision=revision, filename="config.json"
|
||||
)
|
||||
hf_hub_download(
|
||||
medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
use_medusa = Path(medusa_config).parent
|
||||
|
||||
method = "medusa"
|
||||
else:
|
||||
method = "n-gram"
|
||||
|
@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -68,37 +68,6 @@ class FlashLlama(FlashCausalLM):
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashLlamaForCausalLM(config, weights)
|
||||
if use_medusa:
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
)
|
||||
lm_head = model.lm_head
|
||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
|
@ -438,23 +438,8 @@ class SpeculativeHead(nn.Module):
|
||||
use_medusa = config.use_medusa
|
||||
if use_medusa:
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
import os
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user