Apply suggestions from code review

Simpler fix (which doesn't break VLMs).
Nicolas Patry 2025-01-17 15:43:34 +01:00 committed by GitHub
parent 0a48e5624c
commit 3e4ca5032b


@@ -634,18 +634,15 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         if config.tie_word_embeddings:
-            prefix = "model.embed_tokens"
+            suffix = "model.embed_tokens"
         else:
             suffix = "lm_head"
-            prefix = (
-                "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-            )
-
         # Used in Granite
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
 
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
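
What the hunk amounts to: the old code qualified the weight name with the wrapping prefix only in the untied branch, so with tie_word_embeddings the head always resolved to a bare "model.embed_tokens", even when the language model is nested inside a larger checkpoint (presumably the VLM breakage the commit message refers to). The new code picks a suffix per branch and qualifies it uniformly in one place. Below is a minimal sketch of that resolution, reduced to a standalone Python function so the branches can be checked in isolation; the function name and the "language_model" prefix in the examples are illustrative assumptions, not names taken from this diff.

# Sketch of the prefix resolution introduced by this commit (not the actual
# TGI module; resolve_lm_head_prefix is a hypothetical name).
def resolve_lm_head_prefix(prefix, name, tie_word_embeddings):
    # Tied embeddings reuse the token-embedding tensor for the LM head;
    # otherwise the head has its own "lm_head" tensor.
    if tie_word_embeddings:
        suffix = "model.embed_tokens"
    else:
        suffix = "lm_head"
    # No wrapping prefix (or a submodule other than "model"): use the bare
    # suffix. Otherwise qualify it with the prefix of the nested layout.
    return suffix if not prefix or name != "model" else f"{prefix}.{suffix}"

# Plain Llama: weights sit at the top level of the checkpoint.
assert resolve_lm_head_prefix("", "model", tie_word_embeddings=False) == "lm_head"
assert resolve_lm_head_prefix("", "model", tie_word_embeddings=True) == "model.embed_tokens"
# Assumed VLM-style layout nesting the language model under "language_model";
# the old code returned the unqualified "model.embed_tokens" here.
assert (
    resolve_lm_head_prefix("language_model", "model", tie_word_embeddings=True)
    == "language_model.model.embed_tokens"
)

Computing suffix first also keeps the incoming prefix argument intact until the single resolution line, which is what makes this version simpler than the parent commit's branch-local fix.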