Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
fix: typo and lint
parent 36fb4b5a7a
commit 4df1b25ddb
@@ -198,18 +198,20 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         if pixel_values is not None and len(pixel_values) > 0:
             # TODO: avoid these casts upstream
             pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)

             image_outputs = self.vision_tower(pixel_values)
             image_features = self.multi_modal_projector(image_outputs.last_hidden_state)

             # TODO: now we scale them? maybe we can do this up or downstream
-            scaled_image_features = image_features / (2048**0.5)
+            scaled_image_features = image_features / (self.config.hidden_size**0.5)

             # mask where image or padding tokens
             mask = input_ids == self.config.image_token_index | (input_ids == 2)

             # insert image features into input embeddings
-            inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])
+            inputs_embeds[mask] = scaled_image_features.view(
+                -1, scaled_image_features.shape[-1]
+            )

         if input_ids.size(0) != 3000:
             # import ipdb
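For context: the first change in this hunk replaces the hardcoded `2048**0.5` (2048 is Gemma-2B's hidden size, so the constant only happened to work for that checkpoint) with `self.config.hidden_size**0.5`; the second is a line-length lint fix. A minimal sketch of the merge step the hunk implements, assuming `config.hidden_size` and `config.image_token_index` as in the diff (the standalone function and the `pad_token_id` argument are illustrative, not TGI API):

import torch

def merge_image_features(inputs_embeds, input_ids, image_features, config, pad_token_id=2):
    # Down-scale projected image features by sqrt(hidden_size); the embedding
    # tensor is scaled back up by the same factor before the language model.
    scaled = image_features / (config.hidden_size**0.5)
    # Positions holding image or padding tokens. The comparisons are
    # parenthesized here because `|` binds tighter than `==` in Python.
    mask = (input_ids == config.image_token_index) | (input_ids == pad_token_id)
    # Overwrite those positions with the scaled image features.
    inputs_embeds[mask] = scaled.view(-1, scaled.shape[-1])
    return inputs_embeds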
@@ -219,7 +221,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):

         # NOTE: scale back up since we dont normalize inside the model like transformers
         # TODO: simplify all the rescaling
-        inputs_embeds = inputs_embeds * (2048**0.5)
+        inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5)

         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
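The two scaling sites mirror each other: transformers' Gemma multiplies token embeddings by sqrt(hidden_size) inside the model, while this flash implementation applies the factor outside, so image features are divided by it before insertion and the full embedding tensor is multiplied back up just before the decoder. A quick numeric sanity check of that round trip (hidden_size = 2048 stands in for Gemma-2B, the value the old constant assumed):

import torch

hidden_size = 2048
normalizer = hidden_size**0.5  # sqrt(2048) ≈ 45.25

image_features = torch.randn(4, hidden_size)
scaled = image_features / normalizer   # applied before inserting into inputs_embeds
restored = scaled * normalizer         # applied just before the language model
assert torch.allclose(restored, image_features, atol=1e-6)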