diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 28db42fe..54ec7ac2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -642,9 +642,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-
-        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
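
For readers skimming the hunk: the only behavioral change is in the first branch of the ternary, which now returns the earlier-computed `suffix` instead of the hardcoded string `"lm_head"`, so the lm_head weights are looked up under whatever suffix the constructor decided on. Below is a minimal sketch of the selection logic, assuming `suffix` is a string computed earlier in the constructor (not shown in this hunk) and that `prefix` and `name` are constructor arguments; the helper name `lm_head_prefix` and the placeholder suffix values are hypothetical, introduced only to illustrate the inline expression.

```python
def lm_head_prefix(prefix: str, name: str, suffix: str) -> str:
    # Use the bare suffix when there is no prefix or the module is not
    # mounted under a "model" attribute; otherwise qualify it with the prefix.
    return suffix if not prefix or name != "model" else f"{prefix}.{suffix}"


# With suffix == "lm_head", old and new code agree:
assert lm_head_prefix("", "model", "lm_head") == "lm_head"
assert lm_head_prefix("core", "model", "lm_head") == "core.lm_head"

# With any other suffix, the old code's hardcoded "lm_head" ignored it in
# the first branch; the new code respects it ("some.other.suffix" is a
# made-up value for illustration):
assert lm_head_prefix("", "model", "some.other.suffix") == "some.other.suffix"
```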