mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-05-06 17:52:07 +00:00

fix idefics

This commit is contained in:
parent b86919a87a
commit 52e4186c2a
@@ -116,11 +116,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        if hasattr(config, "head_dim"):
+        if hasattr(config, "head_dim") and config.head_dim is not None:
             self.head_size = config.head_dim
         else:
             self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
             dim=self.head_size,
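Note: the tightened guard matters for configs that define `head_dim` but leave it as `None`. A minimal, self-contained sketch of the fallback behaviour (the `DummyConfig` here is hypothetical, not a real TGI class):

class DummyConfig:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute exists, but is unset

config = DummyConfig()
if hasattr(config, "head_dim") and config.head_dim is not None:
    head_size = config.head_dim
else:
    # hasattr alone would have picked head_dim=None and broken downstream math
    head_size = config.hidden_size // config.num_attention_heads

assert head_size == 128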
@@ -733,6 +733,97 @@ class Idefics2ForConditionalGeneration(nn.Module):
         inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
         return inputs_embeds

+    def get_vision_embeds(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: torch.BoolTensor,
+        **kwargs,
+    ):
+        assert pixel_values is not None
+        batch_size, num_images, num_channels, height, width = pixel_values.shape
+        all_states = []
+        all_pixel_values = pixel_values
+        all_pixel_mask = pixel_attention_mask
+        for i in range(batch_size):
+            pixel_values = all_pixel_values.to(dtype=self.dtype)  # fp16 compatibility
+            pixel_values = pixel_values[i : i + 1]
+            pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+            # Remove padding images - padding images are full 0.
+            nb_values_per_image = pixel_values.shape[1:].numel()
+            real_images_inds = (pixel_values == 0.0).sum(
+                dim=(-1, -2, -3)
+            ) != nb_values_per_image
+            pixel_values = pixel_values[real_images_inds].contiguous()
+
+            # Handle the vision attention mask
+            if pixel_attention_mask is None:
+                pixel_attention_mask = torch.ones(
+                    size=(
+                        pixel_values.size(0),
+                        pixel_values.size(2),
+                        pixel_values.size(3),
+                    ),
+                    dtype=torch.bool,
+                    device=pixel_values.device,
+                )
+            else:
+                # Remove padding images from the mask/pP p
+                pixel_attention_mask = all_pixel_mask[i : i + 1]
+                pixel_attention_mask = pixel_attention_mask.view(
+                    1 * num_images, *pixel_attention_mask.shape[2:]
+                )
+                pixel_attention_mask = pixel_attention_mask[
+                    real_images_inds
+                ].contiguous()
+
+            patch_size = self.config.vision_config.patch_size
+            patches_subgrid = pixel_attention_mask.unfold(
+                dimension=1, size=patch_size, step=patch_size
+            )
+            patches_subgrid = patches_subgrid.unfold(
+                dimension=2, size=patch_size, step=patch_size
+            )
+            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+            # Get sequence from the vision encoder
+            image_hidden_states = self.vision_model(
+                pixel_values=pixel_values,
+                patch_attention_mask=patch_attention_mask,
+            )
+
+            # Modality projection & resampling
+            image_hidden_states = self.connector(
+                image_hidden_states,
+                attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+            )
+            all_states.append(image_hidden_states)
+        image_hidden_states = torch.stack(all_states, dim=0)
+        return image_hidden_states.view(-1, image_hidden_states.shape[-1])
+
+    def get_input_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        pixel_attention_mask: torch.BoolTensor = None,
+        **kwargs,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
+        if vision_embeds is None and pixel_values is not None:
+            vision_embeds = self.get_vision_embeds(
+                pixel_values=pixel_values,
+                pixel_attention_mask=pixel_attention_mask,
+            )
+
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            inputs_embeds = self._merge_input_ids_with_image_features(
+                input_ids, inputs_embeds, vision_embeds
+            )
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
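For context, a rough sketch of how the two new methods are meant to compose: vision embeddings are computed once, then merged over the image placeholder tokens. The `embed_prompt` helper and its arguments are illustrative only, not part of this commit:

def embed_prompt(model, input_ids, pixel_values=None, pixel_attention_mask=None):
    # Hypothetical call site: compute image embeddings only when images exist,
    # then let the model splice them into the token embeddings.
    vision_embeds = None
    if pixel_values is not None:
        vision_embeds = model.get_vision_embeds(
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
        )
    return model.get_input_embeds(input_ids, vision_embeds=vision_embeds)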
@@ -751,76 +842,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ):
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            batch_size, num_images, num_channels, height, width = pixel_values.shape
-            all_states = []
-            all_pixel_values = pixel_values
-            all_pixel_mask = pixel_attention_mask
-            for i in range(batch_size):
-                pixel_values = all_pixel_values.to(
-                    dtype=self.dtype
-                )  # fp16 compatibility
-                pixel_values = pixel_values[i : i + 1]
-                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
-
-                # Remove padding images - padding images are full 0.
-                nb_values_per_image = pixel_values.shape[1:].numel()
-                real_images_inds = (pixel_values == 0.0).sum(
-                    dim=(-1, -2, -3)
-                ) != nb_values_per_image
-                pixel_values = pixel_values[real_images_inds].contiguous()
-
-                # Handle the vision attention mask
-                if pixel_attention_mask is None:
-                    pixel_attention_mask = torch.ones(
-                        size=(
-                            pixel_values.size(0),
-                            pixel_values.size(2),
-                            pixel_values.size(3),
-                        ),
-                        dtype=torch.bool,
-                        device=pixel_values.device,
-                    )
-                else:
-                    # Remove padding images from the mask/pP p
-                    pixel_attention_mask = all_pixel_mask[i : i + 1]
-                    pixel_attention_mask = pixel_attention_mask.view(
-                        1 * num_images, *pixel_attention_mask.shape[2:]
-                    )
-                    pixel_attention_mask = pixel_attention_mask[
-                        real_images_inds
-                    ].contiguous()
-
-                patch_size = self.config.vision_config.patch_size
-                patches_subgrid = pixel_attention_mask.unfold(
-                    dimension=1, size=patch_size, step=patch_size
-                )
-                patches_subgrid = patches_subgrid.unfold(
-                    dimension=2, size=patch_size, step=patch_size
-                )
-                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-                # Get sequence from the vision encoder
-                image_hidden_states = self.vision_model(
-                    pixel_values=pixel_values,
-                    patch_attention_mask=patch_attention_mask,
-                )
-
-                # Modality projection & resampling
-                image_hidden_states = self.connector(
-                    image_hidden_states,
-                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
-                )
-                all_states.append(image_hidden_states)
-            image_hidden_states = torch.stack(all_states, dim=0)
-            # When we generate, we don't want to replace the potential image_token_id that we generated by images
-            # that simply don't exist
-            inputs_embeds = self._merge_input_ids_with_image_features(
-                input_ids, inputs_embeds, image_hidden_states
-            )
-
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
@@ -1300,7 +1300,7 @@ class FlashCausalLM(Model):
         if head_size is None:
             # Some models use GQA and different sizes for o_proj
             # and q_proj, that allows for that.
-            if hasattr(config, "head_dim"):
+            if hasattr(config, "head_dim") and config.head_dim is not None:
                 self.head_size = config.head_dim
             else:
                 self.head_size = config.hidden_size // config.num_attention_heads
@@ -1899,9 +1899,6 @@ class FlashCausalLM(Model):
         batch.prepare_for_prefill()

         self.get_input_embeddings(batch)
-        from pdb import set_trace
-
-        set_trace()
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)
@@ -574,7 +574,8 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
         return inputs

-    def get_vision_embeds(self, pixel_values, image_sizes=None):
+    def get_vision_embeds(self, pixel_values, **kwargs):
+        image_sizes = kwargs.get("image_sizes", None)
         image_features = self.model.get_image_features(
             pixel_values=pixel_values,
             vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
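Presumably the switch to `**kwargs` lets one generic caller pass the same keyword set (`pixel_attention_mask`, `image_sizes`, `image_grid_thw`, ...) to every model, with each implementation picking what it needs and ignoring the rest. A toy sketch of that pattern (function names and values are illustrative):

def get_vision_embeds_a(pixel_values, **kwargs):
    image_sizes = kwargs.get("image_sizes", None)  # this model uses image_sizes
    return ("a", pixel_values, image_sizes)


def get_vision_embeds_b(pixel_values, **kwargs):
    return ("b", pixel_values)  # ignores image_sizes / image_grid_thw entirely


for fn in (get_vision_embeds_a, get_vision_embeds_b):
    # One generic call site works for both, regardless of which extras each consumes.
    fn("pixels", pixel_attention_mask=None, image_sizes=(224, 224), image_grid_thw=None)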
@@ -586,10 +587,10 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         projected_vision_flat = self.model.multi_modal_projector(vision_flat)
         return projected_vision_flat

-    def get_input_embeds(self, input_ids, vision_embeddings=None):
+    def get_input_embeds(self, input_ids, vision_embeds=None):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)

-        if vision_embeddings is not None:
+        if vision_embeds is not None:
             original_inputs_embeds_shape = inputs_embeds.shape
             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                 -1
@@ -600,17 +601,15 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
             final_mask_1d = final_mask[..., 0].reshape(-1)
             num_tokens_to_fill = final_mask_1d.sum()

-            if num_tokens_to_fill != vision_embeddings.size(0):
+            if num_tokens_to_fill != vision_embeds.size(0):
                 raise ValueError(
                     f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
-                    f"but multi_modal_projector returned {vision_embeddings.size(0)}"
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                 )

             expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                 -1, inputs_embeds.size(-1)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                expanded_mask, vision_embeddings
-            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)

             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
             return inputs_embeds
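The collapsed `masked_scatter` call keeps the same semantics as the multi-line version: elements of `vision_embeds` fill the `True` positions of the mask, in order. A tiny sketch with made-up shapes:

import torch

inputs_embeds = torch.zeros(4, 2)
expanded_mask = torch.tensor([[0, 0], [1, 1], [0, 0], [1, 1]], dtype=torch.bool)
vision_embeds = torch.arange(4, dtype=torch.float).reshape(2, 2)

out = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
# rows 1 and 3 now hold [0., 1.] and [2., 3.]; everything else stays zero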
@@ -376,7 +376,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 image = image
             else:
                 image = [image]
-
             pixel_values = processor.image_processor(
                 [image], return_tensors="pt", **kwargs
             )
@@ -387,9 +386,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

             if len(image_inputs) > 0:
                 batch_image_inputs[i] = image_inputs
-        from pdb import set_trace
+        # pixel_values = processor.image_processor(
+        #     all_images, return_tensors="pt", **kwargs
+        # )

-        set_trace()
         batch_image_positions = []
         batch_tokenized_inputs = []
         max_length = 0
@@ -410,6 +410,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

             if config.model_type == "gemma3":
                 image_text = image_text.replace("\n\n", "")
+            elif config.model_type == "idefics2":
+                image_text = image_text_replacement_fixup(config, image_text)

             image_tokens.append(
                 (
|
|||||||
|
|
||||||
prev = 0
|
prev = 0
|
||||||
image_positions = []
|
image_positions = []
|
||||||
|
idefics_replacement = False
|
||||||
|
|
||||||
|
config.image_token_index = (
|
||||||
|
config.image_token_index
|
||||||
|
if hasattr(config, "image_token_index")
|
||||||
|
else config.image_token_id
|
||||||
|
)
|
||||||
for image_id, id_token, tokens in image_tokens:
|
for image_id, id_token, tokens in image_tokens:
|
||||||
id_token = tokenizer.get_vocab()[id_token]
|
id_token = tokenizer.get_vocab()[id_token]
|
||||||
|
|
||||||
|
if config.model_type == "idefics2" and idefics_replacement:
|
||||||
|
id_token = config.image_token_index
|
||||||
|
tokens = tokens[1:]
|
||||||
|
|
||||||
length = len(tokens)
|
length = len(tokens)
|
||||||
index = input_ids[prev:].index(id_token)
|
index = input_ids[prev:].index(id_token)
|
||||||
|
|
||||||
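The `image_token_index` normalization covers configs (Idefics2 among them) that only expose `image_token_id`. A self-contained sketch with a hypothetical config object and a made-up token id:

class Cfg:
    image_token_id = 32001  # no image_token_index attribute on this config


config = Cfg()
config.image_token_index = (
    config.image_token_index
    if hasattr(config, "image_token_index")
    else config.image_token_id
)
assert config.image_token_index == 32001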
@@ -463,6 +476,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             image_positions.append(pos)
             prev = index + length

+            if config.model_type == "idefics2" and prev != len(input_ids):
+                if input_ids[prev] == config.image_token_id:
+                    # means this is replaced by image_text_replacement_fixup
+                    idefics_replacement = True
+                else:
+                    idefics_replacement = False
+
         if len(image_positions) > 0:
             batch_image_positions[i] = image_positions

@@ -540,40 +560,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     self.image_inputs[i][j] = None

         if self.has_image and len(scheduled_image_pixel_values):
-            self.pixel_values = torch.cat(
-                [d["pixel_values"] for d in scheduled_image_pixel_values], dim=0
-            ).to(device)
+            self.pixel_values = [
+                d["pixel_values"].to(device) for d in scheduled_image_pixel_values
+            ]

             if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
-                self.pixel_attention_mask = torch.cat(
-                    [d["pixel_attention_mask"] for d in scheduled_image_pixel_values],
-                    dim=0,
-                ).to(device)
+                self.pixel_attention_mask = [
+                    d["pixel_attention_mask"].to(device)
+                    for d in scheduled_image_pixel_values
+                ]

             if "image_sizes" in scheduled_image_pixel_values[0]:
-                self.image_sizes = torch.cat(
-                    [d["image_sizes"] for d in scheduled_image_pixel_values], dim=0
-                ).to(device)
+                self.image_sizes = [
+                    d["image_sizes"].to(device) for d in scheduled_image_pixel_values
+                ]

             if "image_grid_thw" in scheduled_image_pixel_values[0]:
-                self.image_grid_thw = torch.cat(
-                    [d["image_grid_thw"] for d in scheduled_image_pixel_values], dim=0
-                ).to(device)
+                self.image_grid_thw = [
+                    d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
+                ]
         else:
             self.pixel_values = None
             self.pixel_attention_mask = None
             self.image_sizes = None
             self.image_grid_thw = None

-    def update_encoder_cache(self, encoder_outputs):
-        prev = 0
-        for i, input in self.scheduled_image_input:
-            length = input.num_placeholder_tokens
-            self.encoder_cache[i][input.id] = scatter_image_embeds(
-                encoder_outputs[prev : prev + length], input.is_embed
-            )
-
-            prev = prev + length
+    def update_encoder_cache(self, encoder_outputs, batch, input):
+        self.encoder_cache[batch][input.id] = scatter_image_embeds(
+            encoder_outputs, input.is_embed
+        )

     def get_mm_embeddings(self):
         device = self.input_ids.device
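Keeping the per-request tensors in Python lists instead of `torch.cat` presumably allows requests with differently shaped image tensors to share a batch, and lets the per-image path below index each entry individually. A sketch with made-up shapes:

import torch

scheduled_image_pixel_values = [
    {"pixel_values": torch.zeros(1, 3, 336, 336)},
    {"pixel_values": torch.zeros(2, 3, 224, 224)},  # torch.cat along dim 0 would fail here
]
pixel_values = [d["pixel_values"] for d in scheduled_image_pixel_values]
print([tuple(t.shape) for t in pixel_values])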
@@ -675,22 +690,75 @@ class VlmCausalLM(FlashCausalLM):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return self.batch_class

+    def get_vision_embeds(
+        self,
+        pixel_values: torch.Tensor,
+        pixel_attention_mask: torch.Tensor,
+        image_sizes: torch.Tensor,
+        image_grid_thw: torch.Tensor,
+    ):
+        embeds = self.model.get_vision_embeds(
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            image_sizes=image_sizes,
+            image_grid_thw=image_grid_thw,
+        )
+        return embeds
+
+    def get_input_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: Optional[torch.Tensor] = None,
+    ):
+        return self.model.get_input_embeds(
+            input_ids=input_ids,
+            vision_embeds=vision_embeds,
+        )
+
     def get_mm_embeddings(self, batch):
         if batch.pixel_values is not None:
-            encoder_outputs = self.get_vision_embeds(batch.pixel_values)
-            batch.update_encoder_cache(encoder_outputs)
-            batch.pixel_values = None
+            for i, image_input in batch.scheduled_image_input:
+                from pdb import set_trace
+
+                set_trace()
+                pixel_values = batch.pixel_values[i]
+                pixel_attention_mask = (
+                    batch.pixel_attention_mask[i]
+                    if batch.pixel_attention_mask is not None
+                    else None
+                )
+                image_sizes = (
+                    batch.image_sizes[i] if batch.image_sizes is not None else None
+                )
+                image_grid_thw = (
+                    batch.image_grid_thw[i]
+                    if batch.image_grid_thw is not None
+                    else None
+                )
+
+                encoder_outputs = self.get_vision_embeds(
+                    pixel_values=pixel_values,
+                    pixel_attention_mask=pixel_attention_mask,
+                    image_sizes=image_sizes,
+                    image_grid_thw=image_grid_thw,
+                )
+                batch.update_encoder_cache(encoder_outputs, i, image_input)
+
+            batch.pixel_values = None
+            batch.pixel_attention_mask = None
+            batch.image_sizes = None
+            batch.image_grid_thw = None
         return batch.get_mm_embeddings()

     def get_input_embeddings(self, batch):
         if batch.has_image:
-            vision_embeddings = self.get_mm_embeddings(batch)
+            vision_embeds = self.get_mm_embeddings(batch)
             batch.has_image = False
         else:
-            vision_embeddings = None
+            vision_embeds = None

         inputs_embeds = self.get_input_embeds(
-            batch.input_ids, vision_embeddings=vision_embeddings
+            batch.input_ids, vision_embeds=vision_embeds
         )

         batch.inputs_embeds = inputs_embeds
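With `update_encoder_cache(encoder_outputs, batch, input)` the cache is written one image at a time instead of slicing a single concatenated encoder output. A rough, illustrative sketch of the indexing, with plain dicts standing in for the real batch and cache objects:

encoder_cache = {}


def update_encoder_cache(encoder_outputs, batch_idx, image_input):
    # One cache entry per (request index, image id); no offsets to track.
    encoder_cache.setdefault(batch_idx, {})[image_input["id"]] = encoder_outputs


images = [(0, {"id": 0}, "embeds-0"), (1, {"id": 0}, "embeds-1")]
for batch_idx, image_input, embeds in images:
    update_encoder_cache(embeds, batch_idx, image_input)

assert encoder_cache == {0: {0: "embeds-0"}, 1: {0: "embeds-1"}}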