From 3e4ca5032b72d8996174e4a454b75d7b58cd4f09 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 17 Jan 2025 15:43:34 +0100
Subject: [PATCH] Apply suggestions from code review

Simpler fix (which doesn't break vlms).
---
 .../models/custom_modeling/flash_llama_modeling.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index eaef3781..54ec7ac2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -634,18 +634,15 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         if config.tie_word_embeddings:
-            prefix = "model.embed_tokens"
+            suffix = "model.embed_tokens"
         else:
             suffix = "lm_head"
 
-        prefix = (
-            "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-        )
         # Used in Granite
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
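
For readers following the change, below is a minimal standalone sketch of the prefix-resolution logic the patch converges on. It is not TGI code: the function name resolve_lm_head_prefix and the example prefix "language_model" are hypothetical; prefix and name stand in for the FlashLlamaForCausalLM constructor arguments, and tie_word_embeddings for the config flag.

# Minimal sketch (hypothetical helper, not part of the patched file) of the
# lm_head weight-prefix resolution as it reads after this patch.

def resolve_lm_head_prefix(prefix: str, name: str, tie_word_embeddings: bool) -> str:
    # With tied embeddings the lm_head is loaded from the embedding tensor;
    # otherwise from the dedicated lm_head tensor.
    suffix = "model.embed_tokens" if tie_word_embeddings else "lm_head"
    # A bare Llama (empty prefix, or a non-default name) loads the suffix
    # directly; a wrapping model keeps its own prefix in front, which per the
    # commit message is the path that keeps VLM checkpoints working.
    return suffix if not prefix or name != "model" else f"{prefix}.{suffix}"


# Hypothetical examples of both branches:
assert resolve_lm_head_prefix("", "model", tie_word_embeddings=True) == "model.embed_tokens"
assert resolve_lm_head_prefix("language_model", "model", tie_word_embeddings=False) == "language_model.lm_head"

Compared with the removed code, the single conditional assignment keeps the tied-embedding suffix instead of hard-coding "lm_head", which is what makes the fix simpler while preserving the prefixed (VLM) case.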