diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 816fb196..3d700610 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -367,8 +367,15 @@ class FlashGemmaModel(torch.nn.Module):
             prefix=pvalue,
             weights=weights,
         )
+        self.embed_tokens.weight = torch.nn.Parameter(
+            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
+        )
+
+        # TODO: avoid making a copy of the embedding matrix; the clone was added only for debugging.
+        self.unscaled_embed_tokens = torch.nn.Parameter(
+            self.embed_tokens.weight.clone()
+        )
 
-        # TODO: double check why this is needed
         self.embed_tokens.weight *= embed_norm
 
         self.layers = nn.ModuleList(
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 3e33032a..bb0a55cf 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -206,7 +206,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         pixel_attention_mask=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+        inputs_embeds = torch.nn.functional.embedding(
+            input_ids, self.language_model.model.unscaled_embed_tokens
+        )
 
         if pixel_values is not None:
             pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
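
For context, a minimal standalone sketch of what these two hunks accomplish together: FlashGemmaModel scales its token-embedding weights in place, so an unscaled clone is kept, and the PaliGemma forward pass looks up raw embeddings from that clone via torch.nn.functional.embedding instead of calling the scaled embedding module. The value of embed_norm is not shown in this diff; sqrt(hidden_size), Gemma's usual normalizer, is assumed below, and the shapes are toy values for illustration.

import math
import torch

vocab_size, hidden_size = 8, 4  # toy shapes for illustration
weight = torch.randn(vocab_size, hidden_size)

# Assumption: embed_norm is Gemma's sqrt(hidden_size) normalizer (not shown in the diff).
embed_norm = math.sqrt(hidden_size)

# First hunk: keep an unscaled copy before scaling the weights in place.
unscaled_embed_tokens = torch.nn.Parameter(weight.clone())
embed_tokens_weight = torch.nn.Parameter(weight * embed_norm)

input_ids = torch.tensor([1, 5, 2])

# Scaled lookup: what FlashGemmaModel's embed_tokens now returns.
scaled = torch.nn.functional.embedding(input_ids, embed_tokens_weight)

# Unscaled lookup: what the PaliGemma forward pass uses instead (second hunk).
unscaled = torch.nn.functional.embedding(input_ids, unscaled_embed_tokens)

# The two lookups differ exactly by the embedding scale factor.
assert torch.allclose(scaled, unscaled * embed_norm)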