From 6e982f43a1056ed8a04bdbe9a8345990bce0fb13 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Fri, 17 Jan 2025 22:50:58 +0800
Subject: [PATCH] fix the crash of meta-llama/Llama-3.2-1B (#2918)

* fix the crash of meta-llama/Llama-3.2-1B

Signed-off-by: Wang, Yi A

* Apply suggestions from code review

Simpler fix (which doesn't break vlms).

---------

Signed-off-by: Wang, Yi A
Co-authored-by: Nicolas Patry
---
 .../models/custom_modeling/flash_llama_modeling.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 28db42fea..54ec7ac29 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -642,9 +642,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-
-        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
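
Note on the one-line change above: a minimal sketch of the prefix-selection logic the hunk touches, assuming (as in the surrounding constructor, which is not shown in this hunk) that suffix is derived from config.tie_word_embeddings and that SpeculativeHead.load looks the head weights up under the returned prefix. The function name and standalone form are hypothetical, for illustration only.

    # Hypothetical, simplified sketch -- only the patched expression comes from the hunk.
    def resolve_lm_head_prefix(prefix: str, name: str, tie_word_embeddings: bool) -> str:
        # Assumed from the surrounding constructor: tied checkpoints keep the head
        # weights under the embedding tensor, untied ones under "lm_head".
        suffix = "model.embed_tokens" if tie_word_embeddings else "lm_head"
        # Patched expression: fall back to the suffix itself instead of a
        # hard-coded "lm_head" when there is no enclosing prefix.
        return suffix if not prefix or name != "model" else f"{prefix}.{suffix}"

    # e.g. a model with tied embeddings such as Llama-3.2-1B, loaded with no prefix:
    #   old expression -> "lm_head"            (tensor absent from the checkpoint, crash)
    #   new expression -> "model.embed_tokens" (tensor present, load succeeds)

Reusing suffix keeps the tied-embedding path working while leaving the behavior for prefixed (e.g. VLM) loads unchanged, which is why the review settled on this simpler fix.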