Mirror of https://github.com/huggingface/text-generation-inference.git
fix: debug avoid scaling embed
parent e13c08f57f
commit d503007fcf
@@ -367,8 +367,15 @@ class FlashGemmaModel(torch.nn.Module):
             prefix=pvalue,
             weights=weights,
         )
+        self.embed_tokens.weight = torch.nn.Parameter(
+            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
+        )
 
-        # TODO: double check why this is needed
+        # TODO: avoid making a copy of the embedding matrix. added for debugging
+        self.unscaled_embed_tokens = torch.nn.Parameter(
+            self.embed_tokens.weight.clone()
+        )
+
         self.embed_tokens.weight *= embed_norm
 
         self.layers = nn.ModuleList(
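For context, the `embed_norm` multiply above is Gemma's sqrt(hidden_size) scaling of the input embeddings; the new `unscaled_embed_tokens` parameter keeps a pre-scaling copy of the table. The following standalone sketch reproduces that pattern with a plain nn.Embedding and made-up sizes; TGI's actual TensorParallelEmbedding, config, and weight loading are not shown, so everything here other than the names taken from the diff is illustrative only.

import torch
import torch.nn as nn

# Toy stand-ins for the real config values.
vocab_size, hidden_size = 8, 4
embed_norm = hidden_size**0.5  # Gemma's embedding scale factor

# Plain nn.Embedding in place of TGI's TensorParallelEmbedding.
embed_tokens = nn.Embedding(vocab_size, hidden_size)

# Keep an unscaled copy before scaling in place, as the hunk above does.
# (detach() is only needed here because this toy weight requires grad.)
unscaled_embed_tokens = nn.Parameter(embed_tokens.weight.detach().clone())

# Scale the table the language model itself will read from.
with torch.no_grad():
    embed_tokens.weight *= embed_norm

# The two tables now differ exactly by the scale factor.
assert torch.allclose(embed_tokens.weight, unscaled_embed_tokens * embed_norm)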
@@ -206,7 +206,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         pixel_attention_mask=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+        inputs_embeds = torch.nn.functional.embedding(
+            input_ids, self.language_model.model.unscaled_embed_tokens
+        )
 
         if pixel_values is not None:
             pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
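This second hunk swaps the text-embedding lookup from the scaled embed_tokens module to a torch.nn.functional.embedding call over the unscaled copy, so inputs_embeds no longer carries the sqrt(hidden_size) factor when it reaches the pixel_values handling below. A self-contained sketch of the before/after difference, using toy tensors and hypothetical variable names in place of the real model state:

import torch
import torch.nn.functional as F

# Toy stand-ins for the real model state.
vocab_size, hidden_size = 8, 4
embed_norm = hidden_size**0.5

unscaled = torch.randn(vocab_size, hidden_size)  # analogue of unscaled_embed_tokens
scaled = unscaled * embed_norm                   # analogue of embed_tokens.weight after scaling

input_ids = torch.tensor([[1, 3, 5]])

# Old line: lookup in the scaled table (what embed_tokens(input_ids) returned).
old_inputs_embeds = F.embedding(input_ids, scaled)

# New lines: lookup in the unscaled copy, as in the hunk above.
new_inputs_embeds = F.embedding(input_ids, unscaled)

# The change removes exactly the embed_norm factor from inputs_embeds.
assert torch.allclose(old_inputs_embeds, new_inputs_embeds * embed_norm)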