mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: typo and lint
This commit is contained in:
parent
36fb4b5a7a
commit
4df1b25ddb
@ -198,18 +198,20 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# TODO: avoid these casts upstream
|
||||
pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
|
||||
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
|
||||
# TODO: now we scale them? maybe we can do this up or downstream
|
||||
scaled_image_features = image_features / (2048**0.5)
|
||||
scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
||||
|
||||
# mask where image or padding tokens
|
||||
mask = input_ids == self.config.image_token_index | (input_ids == 2)
|
||||
|
||||
# insert image features into input embeddings
|
||||
inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])
|
||||
inputs_embeds[mask] = scaled_image_features.view(
|
||||
-1, scaled_image_features.shape[-1]
|
||||
)
|
||||
|
||||
if input_ids.size(0) != 3000:
|
||||
# import ipdb
|
||||
@ -219,7 +221,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
|
||||
|
||||
# NOTE: scale back up since we dont normalize inside the model like transformers
|
||||
# TODO: simplify all the rescaling
|
||||
inputs_embeds = inputs_embeds * (2048**0.5)
|
||||
inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5)
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
Loading…
Reference in New Issue
Block a user