Fix the crash of meta-llama/Llama-3.2-1B

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-01-16 04:23:16 -08:00
parent 922cc38fbc
commit 0a48e5624c

View File

@ -634,17 +634,18 @@ class FlashLlamaForCausalLM(torch.nn.Module):
weights=weights,
)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
prefix = "model.embed_tokens"
else:
suffix = "lm_head"
prefix = (
"lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
)
# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,