Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Apply suggestions from code review
Simpler fix (which doesn't break VLMs).
commit 3e4ca5032b
parent 0a48e5624c
@@ -634,18 +634,15 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         if config.tie_word_embeddings:
-            prefix = "model.embed_tokens"
+            suffix = "model.embed_tokens"
         else:
             suffix = "lm_head"
-            prefix = (
-                "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-            )
 
         # Used in Granite
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
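
For context, here is a minimal standalone sketch of the prefix-resolution logic the commit collapses into a single line. In TGI this runs inline in FlashLlamaForCausalLM.__init__, where prefix and name are constructor arguments and tie_word_embeddings comes from the model config; the helper name resolve_lm_head_prefix and the example argument values below are illustrative, not part of the codebase.

    # Standalone sketch of the new prefix resolution; the function name and
    # the example values are hypothetical, not taken from TGI itself.

    def resolve_lm_head_prefix(prefix: str, name: str, tie_word_embeddings: bool) -> str:
        # Tied models reuse the input embedding matrix for the LM head;
        # untied models load a dedicated lm_head tensor.
        if tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"
        # The single line the commit introduces: the qualified name is used
        # only when a wrapper prefix is set and the module is named "model";
        # every other case falls back to the bare tensor name.
        return suffix if not prefix or name != "model" else f"{prefix}.{suffix}"

    # Illustrative calls (hypothetical values):
    print(resolve_lm_head_prefix("", "model", tie_word_embeddings=True))   # model.embed_tokens
    print(resolve_lm_head_prefix("", "model", tie_word_embeddings=False))  # lm_head
    print(resolve_lm_head_prefix("language_model", "model", tie_word_embeddings=False))  # language_model.lm_head

The commit message suggests the earlier branched version resolved these names incorrectly for VLM-style wrapped models; the examples above only demonstrate how each branch of the new one-liner behaves.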