From 0a48e5624c014ad6b44b8ff18c201fe502c1a31c Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 16 Jan 2025 04:23:16 -0800
Subject: [PATCH] fix the crash of meta-llama/Llama-3.2-1B

Signed-off-by: Wang, Yi A
---
 .../models/custom_modeling/flash_llama_modeling.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 28db42fe..eaef3781 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -634,17 +634,18 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         if config.tie_word_embeddings:
-            suffix = "model.embed_tokens"
+            prefix = "model.embed_tokens"
         else:
             suffix = "lm_head"
+            prefix = (
+                "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
+            )
 
         # Used in Granite
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
 
-        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
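
Note (not part of the patch): below is a minimal standalone sketch of the prefix resolution this change produces, for reviewers. The helper name resolve_lm_head_prefix and the assert-based usage are illustrative assumptions and do not exist in text-generation-inference. The point is that with tie_word_embeddings set, the speculative head is loaded from "model.embed_tokens", so checkpoints such as meta-llama/Llama-3.2-1B that ship no separate lm_head tensor no longer crash.

    def resolve_lm_head_prefix(tie_word_embeddings: bool, prefix: str, name: str) -> str:
        # Illustrative sketch only; mirrors the logic after this patch.
        # With tied word embeddings the head reuses the embedding weights, so the
        # prefix must point at "model.embed_tokens".
        if tie_word_embeddings:
            return "model.embed_tokens"
        # Untied case: behaviour unchanged from before the patch.
        suffix = "lm_head"
        if not prefix or name != "model":
            return "lm_head"
        return f"{prefix}.{suffix}"

    # Llama-3.2-1B ties its embeddings, so the head loads from the embedding tensor
    # instead of a non-existent lm_head tensor, which is what used to crash.
    assert resolve_lm_head_prefix(True, "", "model") == "model.embed_tokens"
    assert resolve_lm_head_prefix(False, "foo", "model") == "foo.lm_head"
    assert resolve_lm_head_prefix(False, "", "model") == "lm_head"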