Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Apply suggestions from code review
Simpler fix (which doesn't break vlms).
parent 0a48e5624c
commit 3e4ca5032b
@@ -634,18 +634,15 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         if config.tie_word_embeddings:
-            prefix = "model.embed_tokens"
+            suffix = "model.embed_tokens"
         else:
             suffix = "lm_head"
-        prefix = (
-            "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
-        )
 
         # Used in Granite
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
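Reading the diff: the old code assigned prefix directly on the tied-embeddings branch and then unconditionally recomputed it from a suffix that is never set on that branch; the new code always assigns suffix and derives prefix in a single expression just before loading the head. A minimal standalone sketch of that resolution logic is below; the function name and the example prefix/name values are hypothetical illustrations, not from the repository:

    def resolve_lm_head_prefix(prefix: str, name: str, tie_word_embeddings: bool) -> str:
        """Standalone sketch of the prefix logic in the diff above."""
        # With tied embeddings the lm_head reuses the input embedding
        # weights, so the target tensor lives under "model.embed_tokens";
        # otherwise it is a separate "lm_head" tensor.
        suffix = "model.embed_tokens" if tie_word_embeddings else "lm_head"
        # No outer prefix (or a non-"model" name): load the suffix directly.
        # Otherwise keep the outer prefix so the weight name resolves inside
        # the parent checkpoint (e.g. when the model is nested, as in a VLM).
        return suffix if not prefix or name != "model" else f"{prefix}.{suffix}"

    # Hypothetical example values, for illustration only:
    assert resolve_lm_head_prefix("", "model", False) == "lm_head"
    assert resolve_lm_head_prefix("", "model", True) == "model.embed_tokens"
    assert resolve_lm_head_prefix("text_model", "model", True) == "text_model.model.embed_tokens"

This reading also matches the commit message: deriving prefix from suffix in one place avoids clobbering a caller-supplied prefix, which is what broke VLM loading in the earlier version of the fix.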