From 7a9998d47c881233ee700bf0d9201c357ad35add Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 22 Feb 2024 12:32:46 +0000
Subject: [PATCH] Remove the old logic.

---
 .../text_generation_server/models/__init__.py | 12 ++++++-
 .../custom_modeling/flash_llama_modeling.py   |  6 ++--
 .../models/flash_llama.py                     | 31 -------------------
 server/text_generation_server/utils/layers.py | 19 ++----------
 4 files changed, 16 insertions(+), 52 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index dedbb7e2..d9bc59f7 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -3,7 +3,9 @@ import torch
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
+from huggingface_hub import hf_hub_download
 from typing import Optional
+from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -121,7 +123,7 @@ def get_model(
 
     use_medusa = None
     if "medusa_num_heads" in config_dict:
-        use_medusa = model_id
+        medusa_model_id = model_id
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
         speculate_medusa = config_dict["medusa_num_heads"]
@@ -138,6 +140,14 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        medusa_config = hf_hub_download(
+            medusa_model_id, revision=revision, filename="config.json"
+        )
+        hf_hub_download(
+            medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
+        )
+        use_medusa = Path(medusa_config).parent
+        method = "medusa"
     else:
         method = "n-gram"
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 88b3d9d2..3a269fc0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 8c2c1086..a2ac759a 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -68,37 +68,6 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashLlamaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)
-
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 35bfbcba..c707b06d 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -438,25 +438,10 @@ class SpeculativeHead(nn.Module):
         use_medusa = config.use_medusa
         if use_medusa:
             from pathlib import Path
-            from huggingface_hub import hf_hub_download
-            from text_generation_server.utils.weights import Weights
             from safetensors import safe_open
             import json
-            import os
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+            medusa_config = str(Path(use_medusa) / "config.json")
+            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
 
             with open(medusa_config, "r") as f:
                 config = json.load(f)
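Note (outside the patch): the Hub-download logic removed from flash_llama.py and layers.py now lives once in get_model(). Below is a minimal sketch of that flow, assuming a Medusa repo that ships config.json and medusa_lm_head.pt at its root; resolve_medusa_dir is a hypothetical helper name used only for illustration, not something the patch adds.

    # Illustrative sketch only -- mirrors the new code in models/__init__.py.
    from pathlib import Path
    from huggingface_hub import hf_hub_download

    def resolve_medusa_dir(medusa_model_id: str, revision: str = "main") -> Path:
        # Fetch the Medusa config and head weights; hf_hub_download caches both
        # files in the same snapshot directory for this repo/revision.
        medusa_config = hf_hub_download(
            medusa_model_id, revision=revision, filename="config.json"
        )
        hf_hub_download(
            medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
        )
        # SpeculativeHead.load only needs the directory: it re-derives
        # config.json and medusa_lm_head.pt from this path (see layers.py above).
        return Path(medusa_config).parent

get_model() stores this directory in use_medusa, so SpeculativeHead.load can stay purely local and no longer needs hf_hub_download or the WEIGHTS_CACHE_OVERRIDE check.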