diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 28db42fe..eaef3781 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -634,17 +634,18 @@ class FlashLlamaForCausalLM(torch.nn.Module): weights=weights, ) if config.tie_word_embeddings: - suffix = "model.embed_tokens" + prefix = "model.embed_tokens" else: suffix = "lm_head" + prefix = ( + "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}" + ) # Used in Granite embedding_multiplier = getattr(config, "embedding_multiplier", None) if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier - prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}" - with no_fp8(weights): self.lm_head = SpeculativeHead.load( config,