mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: adjust image and text merge logic
This commit is contained in:
parent
d503007fcf
commit
36fb4b5a7a
@ -175,21 +175,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
config.pad_token_id if config.pad_token_id is not None else -1
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
)
|
)
|
||||||
|
|
||||||
def _merge_input_ids_with_image_features(
|
|
||||||
self, image_features, inputs_embeds, input_ids
|
|
||||||
):
|
|
||||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
|
||||||
mask = input_ids == self.config.image_token_index
|
|
||||||
# Let's pray we have enabled enough slots !
|
|
||||||
try:
|
|
||||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -210,20 +195,21 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
input_ids, self.language_model.model.unscaled_embed_tokens
|
input_ids, self.language_model.model.unscaled_embed_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
|
|
||||||
# merge text and images
|
|
||||||
if pixel_values is not None and len(pixel_values) > 0:
|
if pixel_values is not None and len(pixel_values) > 0:
|
||||||
|
# TODO: avoid these casts upstream
|
||||||
|
pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
|
||||||
image_outputs = self.vision_tower(pixel_values)
|
image_outputs = self.vision_tower(pixel_values)
|
||||||
selected_image_feature = image_outputs.last_hidden_state
|
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||||
image_features = self.multi_modal_projector(selected_image_feature)
|
|
||||||
# NOTE: image_features returns the exact values as transformers
|
|
||||||
|
|
||||||
# TODO: correctly merge inputs_embeds with image_features
|
# TODO: now we scale them? maybe we can do this up or downstream
|
||||||
merged_inputs_embeds = self._merge_input_ids_with_image_features(
|
scaled_image_features = image_features / (2048**0.5)
|
||||||
image_features, inputs_embeds, input_ids
|
|
||||||
)
|
# mask where image or padding tokens
|
||||||
|
mask = input_ids == self.config.image_token_index | (input_ids == 2)
|
||||||
|
|
||||||
|
# insert image features into input embeddings
|
||||||
|
inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])
|
||||||
|
|
||||||
if input_ids.size(0) != 3000:
|
if input_ids.size(0) != 3000:
|
||||||
# import ipdb
|
# import ipdb
|
||||||
@ -231,6 +217,10 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
# ipdb.set_trace()
|
# ipdb.set_trace()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# NOTE: scale back up since we dont normalize inside the model like transformers
|
||||||
|
# TODO: simplify all the rescaling
|
||||||
|
inputs_embeds = inputs_embeds * (2048**0.5)
|
||||||
|
|
||||||
hidden_states = self.language_model.model(
|
hidden_states = self.language_model.model(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@ -242,6 +232,10 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if input_ids.size(0) != 3000:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
pass
|
||||||
|
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits, speculative_logits = self.language_model.lm_head(hidden_states)
|
logits, speculative_logits = self.language_model.lm_head(hidden_states)
|
||||||
|
Loading…
Reference in New Issue
Block a user