fix the crash of meta-llama/Llama-3.2-1B

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A committed 2025-01-16 04:23:16 -08:00
parent 922cc38fbc
commit 0a48e5624c


@@ -634,17 +634,18 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
         if config.tie_word_embeddings:
-            suffix = "model.embed_tokens"
+            prefix = "model.embed_tokens"
         else:
             suffix = "lm_head"
+            prefix = (
+                "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
+            )
 
         # Used in Granite
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-
-        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
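
For context (an editorial note, not part of the commit): with tie_word_embeddings enabled, checkpoints such as Llama-3.2-1B reuse the embedding matrix as the output head and ship no standalone lm_head tensor, so the old code's unconditional fallback prefix = "lm_head" pointed SpeculativeHead.load at a weight that does not exist. Below is a minimal sketch of the corrected prefix selection, expressed as a hypothetical standalone function (resolve_lm_head_prefix is not TGI's API; prefix and name mirror the caller-supplied arguments in FlashLlamaForCausalLM):

def resolve_lm_head_prefix(tie_word_embeddings: bool, prefix: str, name: str) -> str:
    # Hypothetical standalone rendering of the decision made in this commit.
    if tie_word_embeddings:
        # Tied checkpoints (e.g. Llama-3.2-1B) load the head from the
        # embedding weights, since no separate lm_head tensor exists.
        return "model.embed_tokens"
    suffix = "lm_head"
    # Untied checkpoints keep the old behavior: honor the caller's prefix
    # only when loading a nested "model" submodule.
    return "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"

# Standalone Llama-3.2-1B: tied embeddings, empty prefix.
assert resolve_lm_head_prefix(True, "", "model") == "model.embed_tokens"
# An untied model loaded without a parent prefix falls back to lm_head.
assert resolve_lm_head_prefix(False, "", "model") == "lm_head"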