fix the crash of meta-llama/Llama-3.2-1B (#2918)

* fix the crash of meta-llama/Llama-3.2-1B

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review

Simpler fix (which doesn't break VLMs).

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Author: Wang, Yi <yi.a.wang@intel.com>
Date: 2025-01-17 22:50:58 +08:00 (committed by GitHub)
Commit: 6e982f43a1 (parent: c20025dbf7)

@@ -642,9 +642,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
+        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
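
For context, here is a minimal standalone sketch of the prefix resolution this one-line change installs. resolve_lm_head_prefix is a hypothetical helper written for illustration, not TGI's API; suffix stands in for the variable computed earlier in FlashLlamaForCausalLM.__init__, which (assuming the surrounding code is unchanged) points at the tied embedding weights for models such as Llama-3.2-1B. Per the diff, the removed expression handed that suffix to SpeculativeHead.load for standalone models, while the new one always uses "lm_head" there; wrappers that pass their own prefix (the VLM case) keep the fully qualified path either way.

# Illustrative sketch only: resolve_lm_head_prefix is a hypothetical
# helper mirroring the one-line expression added by this commit.
def resolve_lm_head_prefix(prefix: str, name: str, suffix: str) -> str:
    # Empty prefix (standalone model) or a name other than "model":
    # load the speculative head from the plain "lm_head" path,
    # independent of the tied-embeddings suffix. Otherwise (a wrapper
    # such as a VLM supplying its own prefix), keep the qualified path.
    return "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"

# Standalone Llama-3.2-1B (tied embeddings): the removed expression
# returned suffix here; the new one returns "lm_head".
assert resolve_lm_head_prefix("", "model", "model.embed_tokens") == "lm_head"

# Text tower wrapped by a VLM (hypothetical wrapper prefix
# "language_model"): behavior is unchanged, so VLMs keep working.
assert (
    resolve_lm_head_prefix("language_model", "model", "lm_head")
    == "language_model.lm_head"
)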