mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: debug avoid scaling embed
This commit is contained in:
parent
e13c08f57f
commit
d503007fcf
@ -367,8 +367,15 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
prefix=pvalue,
|
||||
weights=weights,
|
||||
)
|
||||
self.embed_tokens.weight = torch.nn.Parameter(
|
||||
self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
|
||||
)
|
||||
|
||||
# TODO: avoid making a copy of the embedding matrix. added for debugging
|
||||
self.unscaled_embed_tokens = torch.nn.Parameter(
|
||||
self.embed_tokens.weight.clone()
|
||||
)
|
||||
|
||||
# TODO: double check why this is needed
|
||||
self.embed_tokens.weight *= embed_norm
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
|
@ -206,7 +206,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
pixel_attention_mask=None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
inputs_embeds = torch.nn.functional.embedding(
|
||||
input_ids, self.language_model.model.unscaled_embed_tokens
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user