diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index eaef3781..54ec7ac2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -634,18 +634,15 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         if config.tie_word_embeddings:
-            prefix = "model.embed_tokens"
+            suffix = "model.embed_tokens"
         else:
             suffix = "lm_head"
-            prefix = (
-                "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-            )

         # Used in Granite
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
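
The patch moves the `prefix` computation out of the `else` branch so it runs for both tied and untied embeddings: with `tie_word_embeddings` the lm_head now resolves to the `model.embed_tokens` weights instead of leaving `suffix` undefined. Below is a minimal sketch of the resolved prefix under the patched logic; `resolve_lm_head_prefix` is a hypothetical helper name used only for illustration, as in the actual file the logic sits inline in `FlashLlamaForCausalLM.__init__` where `prefix` and `name` are constructor arguments.

```python
# Sketch only, not the TGI source: prefix resolution as it behaves after this patch.
def resolve_lm_head_prefix(prefix: str, name: str, tie_word_embeddings: bool) -> str:
    # Tied embeddings reuse the embedding weights for the lm_head;
    # untied models load a dedicated lm_head tensor.
    suffix = "model.embed_tokens" if tie_word_embeddings else "lm_head"
    # Nest under `prefix` only when loaded as a plain "model" submodule
    # with a non-empty prefix (e.g. inside a multimodal wrapper).
    return suffix if not prefix or name != "model" else f"{prefix}.{suffix}"


if __name__ == "__main__":
    # Untied, top-level Llama -> "lm_head"
    print(resolve_lm_head_prefix("", "model", tie_word_embeddings=False))
    # Tied embeddings, top-level -> "model.embed_tokens"
    print(resolve_lm_head_prefix("", "model", tie_word_embeddings=True))
    # Tied embeddings behind a wrapper prefix -> "language_model.model.embed_tokens"
    print(resolve_lm_head_prefix("language_model", "model", tie_word_embeddings=True))
```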