fix: debug avoid scaling embed

drbh 2024-05-10 03:44:51 +00:00 committed by Nicolas Patry
parent e13c08f57f
commit d503007fcf
2 changed files with 11 additions and 2 deletions


@@ -367,8 +367,15 @@ class FlashGemmaModel(torch.nn.Module):
            prefix=pvalue,
            weights=weights,
        )
        self.embed_tokens.weight = torch.nn.Parameter(
            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
        )
        # TODO: avoid making a copy of the embedding matrix. added for debugging
        self.unscaled_embed_tokens = torch.nn.Parameter(
            self.embed_tokens.weight.clone()
        )
        # TODO: double check why this is needed
        self.embed_tokens.weight *= embed_norm
        self.layers = nn.ModuleList(

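Context on the change above: Gemma multiplies its token embeddings by `embed_norm` (the square root of the hidden size) before they enter the decoder stack, and this scaling is baked into `embed_tokens.weight` at load time. Keeping `unscaled_embed_tokens` as a separate parameter preserves the raw embedding matrix for callers that need it without the scaling. A minimal sketch of the pattern, using a plain `nn.Embedding` and toy sizes in place of the repository's `TensorParallelEmbedding` (all names and shapes here are illustrative assumptions, not the actual API):

```python
import torch


class ToyGemmaEmbeddings(torch.nn.Module):
    """Minimal sketch, assuming embed_norm = hidden_size ** 0.5 as in Gemma."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
        super().__init__()
        # Stand-in for TensorParallelEmbedding; illustrative only.
        self.embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
        embed_norm = hidden_size ** 0.5

        # Keep a raw copy before scaling the weights in place, mirroring
        # `self.unscaled_embed_tokens` in the diff above. detach() is only
        # needed here because nn.Embedding weights track gradients by default.
        self.unscaled_embed_tokens = torch.nn.Parameter(
            self.embed_tokens.weight.detach().clone(), requires_grad=False
        )
        self.embed_tokens.weight.data *= embed_norm

    def forward(self, input_ids: torch.Tensor, scaled: bool = True) -> torch.Tensor:
        if scaled:
            # Normal text path: embeddings already multiplied by embed_norm.
            return self.embed_tokens(input_ids)
        # Raw path: plain row lookup against the unscaled copy.
        return torch.nn.functional.embedding(input_ids, self.unscaled_embed_tokens)


toy = ToyGemmaEmbeddings()
ids = torch.tensor([[1, 2, 3]])
assert torch.allclose(toy(ids), toy(ids, scaled=False) * 16 ** 0.5)
```
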

@@ -206,7 +206,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        pixel_attention_mask=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
        inputs_embeds = torch.nn.functional.embedding(
            input_ids, self.language_model.model.unscaled_embed_tokens
        )
        if pixel_values is not None:
            pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
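
In the hunk above, the first `inputs_embeds` assignment is the line being removed and the functional lookup is its replacement (the hunk grows from 7 to 9 lines): rather than going through `self.language_model.model.embed_tokens`, whose weight was scaled by `embed_norm` at load time, the forward pass now indexes the `unscaled_embed_tokens` copy directly. `torch.nn.functional.embedding` is a plain row lookup, so the only behavioural difference is which weight matrix supplies the rows. A toy sketch of that equivalence (sizes and module are illustrative, not the repository's classes):

```python
import torch

embed = torch.nn.Embedding(8, 16)      # stand-in for the language model's embedding
input_ids = torch.tensor([[0, 3, 5]])

# Module call and functional lookup return identical rows when the weights match,
# so swapping in `unscaled_embed_tokens` only changes which matrix is indexed.
module_out = embed(input_ids)
functional_out = torch.nn.functional.embedding(input_ids, embed.weight)
assert torch.equal(module_out, functional_out)
```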