From a0095b5b8dbcb335116417b7f68318f3421e26d8 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 23 Feb 2024 15:10:08 +0000
Subject: [PATCH] Fixing.

---
 .../text_generation_server/models/__init__.py | 18 +++++++++++-------
 .../custom_modeling/flash_gemma_modeling.py   |  6 +++---
 .../models/flash_causal_lm.py                 |  2 --
 server/text_generation_server/utils/layers.py |  3 +--
 4 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index d9bc59f7..8edf0677 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -140,13 +140,17 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
-        medusa_config = hf_hub_download(
-            medusa_model_id, revision=revision, filename="config.json"
-        )
-        hf_hub_download(
-            medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
-        )
-        use_medusa = Path(medusa_config).parent
+        is_local = Path(medusa_model_id).exists()
+        if not is_local:
+            medusa_config = hf_hub_download(
+                medusa_model_id, revision=revision, filename="config.json"
+            )
+            hf_hub_download(
+                medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
+            )
+            use_medusa = Path(medusa_config).parent
+        else:
+            use_medusa = Path(medusa_model_id)
         method = "medusa"
     else:
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index d7bedf72..e91927df 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 1276fefa..988637d4 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -968,8 +968,6 @@ class FlashCausalLM(Model):
             speculative_logits,
         )
 
-        logger.info(f"Accepted ids {accepted_ids}")
-
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index d923ebfc..209f1c8a 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -444,11 +444,10 @@ class SpeculativeHead(nn.Module):
             import json
 
             medusa_config = str(Path(use_medusa) / "config.json")
-            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
             with open(medusa_config, "r") as f:
                 config = json.load(f)
-            filename = medusa_head[: -len(".pt")] + ".safetensors"
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
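
Note (appended after the patch, not part of it): the models/__init__.py hunk makes the Medusa
artifacts resolvable from a local directory as well as from the Hugging Face Hub. Below is a
minimal standalone sketch of that resolution logic; resolve_medusa_dir is a hypothetical helper
name, and only the hf_hub_download calls mirror the patch itself.

from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download


def resolve_medusa_dir(medusa_model_id: str, revision: Optional[str] = None) -> Path:
    """Return the directory holding config.json and the Medusa head weights."""
    local_path = Path(medusa_model_id)
    if local_path.exists():
        # Local checkout: the files are assumed to already be in place.
        return local_path
    # Otherwise pull both files from the Hub and point at their shared cache folder.
    medusa_config = hf_hub_download(
        medusa_model_id, revision=revision, filename="config.json"
    )
    hf_hub_download(
        medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
    )
    return Path(medusa_config).parent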
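
A second sketch, also not part of the patch: the flash_gemma_modeling.py hunks change the head
call to return a (logits, speculative_logits) pair, where speculative_logits is None when no
Medusa head is attached. DummySpeculativeHead below is an illustrative stand-in, not the real
TGI SpeculativeHead.

from typing import Optional, Tuple

import torch


class DummySpeculativeHead(torch.nn.Module):
    # Stand-in for a speculative head: a plain lm_head plus an optional extra head.
    def __init__(self, hidden_size: int, vocab_size: int, with_medusa: bool):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.medusa = (
            torch.nn.Linear(hidden_size, vocab_size, bias=False) if with_medusa else None
        )

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        speculative_logits = (
            self.medusa(hidden_states) if self.medusa is not None else None
        )
        return logits, speculative_logits


# Callers now unpack a pair, matching the modified FlashGemmaForCausalLM.forward.
head = DummySpeculativeHead(hidden_size=16, vocab_size=32, with_medusa=False)
logits, speculative_logits = head(torch.randn(2, 16))
assert speculative_logits is None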