Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Remove the old logic.
parent 21b3072288 · commit 7a9998d47c
@@ -3,7 +3,9 @@ import torch
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
+from huggingface_hub import hf_hub_download
 from typing import Optional
+from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -121,7 +123,7 @@ def get_model(

    use_medusa = None
    if "medusa_num_heads" in config_dict:
-        use_medusa = model_id
+        medusa_model_id = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
@@ -138,6 +140,14 @@ def get_model(
        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
+        medusa_config = hf_hub_download(
+            medusa_model_id, revision=revision, filename="config.json"
+        )
+        hf_hub_download(
+            medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
+        )
+        use_medusa = Path(medusa_config).parent
+
        method = "medusa"
    else:
        method = "n-gram"
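Taken together with the earlier hunks, get_model now eagerly downloads the Medusa config and head weights and rebinds use_medusa to the local snapshot directory, so downstream code never has to consult the Hub again. A minimal standalone sketch of that resolution flow (the helper name and the calling convention around it are hypothetical, not part of the commit):

    from pathlib import Path
    from huggingface_hub import hf_hub_download

    def resolve_medusa_dir(medusa_model_id: str, revision: str = "main") -> Path:
        # hf_hub_download returns the cached local path of the downloaded file.
        medusa_config = hf_hub_download(
            medusa_model_id, revision=revision, filename="config.json"
        )
        # Pre-fetch the head weights into the same snapshot directory.
        hf_hub_download(
            medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
        )
        # Both files land in one snapshot directory; hand that directory downstream.
        return Path(medusa_config).parent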
@@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            position_ids,
@@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
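The forward pass now returns a pair instead of a single tensor; the Tuple annotation assumes a typing.Tuple import elsewhere in the module, which this hunk does not show. As a hedged illustration of how a caller might consume the pair (simplified greedy decoding, not the project's actual sampling path):

    from typing import Optional, Tuple
    import torch

    def pick_tokens(
        logits: torch.Tensor,
        speculative_logits: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Greedy choice for the verified next token.
        next_ids = logits.argmax(dim=-1)
        # If the Medusa heads ran, also draft candidate tokens; these get
        # verified (accepted or rejected) on the following forward pass.
        speculative_ids = (
            speculative_logits.argmax(dim=-1)
            if speculative_logits is not None
            else None
        )
        return next_ids, speculative_ids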
@@ -68,37 +68,6 @@ class FlashLlama(FlashCausalLM):
            weights._set_gptq_params(model_id, revision)

        model = FlashLlamaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)
-
        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model=model,
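This is the "old logic" of the commit title: the whole local-versus-Hub branch leaves FlashLlama because get_model already resolves use_medusa to a local directory before the model class ever sees it. For contrast, a condensed sketch of the heuristic the deleted block relied on (refactored from the removed lines above; the helper name is hypothetical):

    import os
    from pathlib import Path

    def was_local_model(use_medusa: str) -> bool:
        # Old heuristic: an existing directory, or a weights-cache override,
        # meant the Medusa files were already on disk and the Hub download
        # could be skipped. Now get_model guarantees a local directory, so
        # this check is dead code.
        return (
            Path(use_medusa).exists() and Path(use_medusa).is_dir()
        ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None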
@@ -438,25 +438,10 @@ class SpeculativeHead(nn.Module):
        use_medusa = config.use_medusa
        if use_medusa:
            from pathlib import Path
-            from huggingface_hub import hf_hub_download
-            from text_generation_server.utils.weights import Weights
            from safetensors import safe_open
            import json
-            import os
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+            medusa_config = str(Path(use_medusa) / "config.json")
+            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

            with open(medusa_config, "r") as f:
                config = json.load(f)
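With use_medusa guaranteed to be a local directory, SpeculativeHead composes file paths directly instead of branching. The surviving safe_open import suggests the head weights are still read through safetensors after a .pt-to-.safetensors conversion this hunk does not show; a minimal sketch of that read pattern, assuming (as the removed FlashLlama code did) that the converted file sits next to the .pt file:

    from safetensors import safe_open

    # Hypothetical continuation: swap the .pt suffix for .safetensors, then
    # read every tensor from the file as PyTorch tensors on CPU.
    medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
    with safe_open(medusa_sf, framework="pt", device="cpu") as f:
        tensors = {name: f.get_tensor(name) for name in f.keys()}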