Remove the old logic.

Nicolas Patry 2024-02-22 12:32:46 +00:00
parent 21b3072288
commit 7a9998d47c
4 changed files with 16 additions and 52 deletions

View File

@@ -3,7 +3,9 @@ import torch
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
+from huggingface_hub import hf_hub_download
 from typing import Optional
+from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -121,7 +123,7 @@ def get_model(
     use_medusa = None
     if "medusa_num_heads" in config_dict:
-        use_medusa = model_id
+        medusa_model_id = model_id
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
         speculate_medusa = config_dict["medusa_num_heads"]
@@ -138,6 +140,14 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        medusa_config = hf_hub_download(
+            medusa_model_id, revision=revision, filename="config.json"
+        )
+        hf_hub_download(
+            medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
+        )
+        use_medusa = Path(medusa_config).parent
         method = "medusa"
     else:
         method = "n-gram"

View File

@@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
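
forward now always returns a pair, with None in the second slot when no speculative head is attached. A hedged sketch of the head-side contract this implies (toy class, hypothetical names):

    from typing import Optional, Tuple
    import torch

    class PlainHead(torch.nn.Module):
        """Toy stand-in for an lm_head with no Medusa heads attached."""

        def __init__(self, hidden_size: int, vocab_size: int):
            super().__init__()
            self.linear = torch.nn.Linear(hidden_size, vocab_size)

        def forward(
            self, hidden_states: torch.Tensor
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            # New contract: (logits, speculative_logits); None when not speculating.
            return self.linear(hidden_states), None

    # Callers unpack unconditionally and branch on the second element:
    logits, speculative_logits = PlainHead(16, 32)(torch.randn(4, 16))
    assert speculative_logits is None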

View File

@@ -68,37 +68,6 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashLlamaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)
-
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
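
This deletes the post-construction monkey-patch that each flash model had to repeat: build the model, then swap model.lm_head for a MedusaModel wrapping the original head. With SpeculativeHead (last file below) the choice happens when the head itself is loaded, so the model-level code shrinks to construction plus the barrier. A toy illustration of the two shapes (hypothetical names):

    import torch

    class ToyModel(torch.nn.Module):
        def __init__(self, head: torch.nn.Module):
            super().__init__()
            # New shape: the (possibly speculative) head is chosen at build time.
            self.lm_head = head

    # Old shape, removed above: mutate the model after construction.
    #     lm_head = model.lm_head
    #     model.lm_head = MedusaModel(config, weights, lm_head)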

View File

@@ -438,23 +438,8 @@ class SpeculativeHead(nn.Module):
         use_medusa = config.use_medusa
         if use_medusa:
             from pathlib import Path
-            from huggingface_hub import hf_hub_download
-            from text_generation_server.utils.weights import Weights
             from safetensors import safe_open
             import json
-            import os
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+            medusa_config = str(Path(use_medusa) / "config.json")
+            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
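
Because get_model now guarantees use_medusa is a local directory (a user-supplied path or the hub snapshot), the remote/local split and the WEIGHTS_CACHE_OVERRIDE escape hatch are dead code, and both branches collapse into plain path joins. A hedged sketch of the surviving load path, assuming the rest of SpeculativeHead.load follows the same pattern as the block removed from flash_llama.py above:

    import json
    from pathlib import Path
    from safetensors import safe_open

    use_medusa = Path("/path/to/medusa-snapshot")  # hypothetical local dir

    medusa_config = str(use_medusa / "config.json")
    medusa_head = str(use_medusa / "medusa_lm_head.pt")

    with open(medusa_config, "r") as f:
        config = json.load(f)

    # The .pt name is rewritten to the converted .safetensors artifact.
    medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
    with safe_open(medusa_sf, framework="pt") as f:
        head_tensors = {name: f.get_tensor(name) for name in f.keys()}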