Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-05-05 17:22:06 +00:00

fix idefics

commit 52e4186c2a
parent b86919a87a
@@ -116,11 +116,10 @@ class MistralAttention(torch.nn.Module):
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        if hasattr(config, "head_dim"):
        if hasattr(config, "head_dim") and config.head_dim is not None:
            self.head_size = config.head_dim
        else:
            self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
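Note (illustrative, not part of the diff): the extra `config.head_dim is not None` check matters because some configs declare a `head_dim` field but leave it unset, in which case the head size must still fall back to `hidden_size // num_attention_heads`. A minimal sketch of the resolution logic, with a made-up config:

from types import SimpleNamespace


def resolve_head_size(config) -> int:
    # Mirrors the diff's logic: prefer an explicit head_dim, but only when it
    # is actually set; otherwise derive it from hidden_size / num_attention_heads.
    if hasattr(config, "head_dim") and config.head_dim is not None:
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


# Example: a config that declares head_dim but leaves it as None (the case the fix targets).
cfg = SimpleNamespace(head_dim=None, hidden_size=4096, num_attention_heads=32)
assert resolve_head_size(cfg) == 128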
@@ -733,6 +733,97 @@ class Idefics2ForConditionalGeneration(nn.Module):
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.BoolTensor,
        **kwargs,
    ):
        assert pixel_values is not None
        batch_size, num_images, num_channels, height, width = pixel_values.shape
        all_states = []
        all_pixel_values = pixel_values
        all_pixel_mask = pixel_attention_mask
        for i in range(batch_size):
            pixel_values = all_pixel_values.to(dtype=self.dtype)  # fp16 compatibility
            pixel_values = pixel_values[i : i + 1]
            pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(
                dim=(-1, -2, -3)
            ) != nb_values_per_image
            pixel_values = pixel_values[real_images_inds].contiguous()

            # Handle the vision attention mask
            if pixel_attention_mask is None:
                pixel_attention_mask = torch.ones(
                    size=(
                        pixel_values.size(0),
                        pixel_values.size(2),
                        pixel_values.size(3),
                    ),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                # Remove padding images from the mask
                pixel_attention_mask = all_pixel_mask[i : i + 1]
                pixel_attention_mask = pixel_attention_mask.view(
                    1 * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[
                    real_images_inds
                ].contiguous()

            patch_size = self.config.vision_config.patch_size
            patches_subgrid = pixel_attention_mask.unfold(
                dimension=1, size=patch_size, step=patch_size
            )
            patches_subgrid = patches_subgrid.unfold(
                dimension=2, size=patch_size, step=patch_size
            )
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            )

            # Modality projection & resampling
            image_hidden_states = self.connector(
                image_hidden_states,
                attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
            )
            all_states.append(image_hidden_states)
        image_hidden_states = torch.stack(all_states, dim=0)
        return image_hidden_states.view(-1, image_hidden_states.shape[-1])
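Note (illustrative, not part of the diff): the two tricks in `get_vision_embeds` above are dropping all-zero padding images and reducing the pixel-level attention mask to a per-patch mask with `unfold`. A reduced, self-contained sketch with toy shapes (patch size 2, 4x4 single-channel "images"):

import torch

pixel_values = torch.zeros(2, 1, 4, 4)  # two images, the second is all zeros (padding)
pixel_values[0] = 1.0

# 1) Drop padding images: an image is padding iff every value is 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds]  # keeps only the first image

# 2) Reduce a (batch, H, W) pixel mask to a (batch, H/p, W/p) patch mask:
#    a patch participates in attention if any of its pixels is valid.
patch_size = 2
pixel_mask = torch.ones(pixel_values.size(0), 4, 4, dtype=torch.bool)
subgrid = pixel_mask.unfold(dimension=1, size=patch_size, step=patch_size)
subgrid = subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = subgrid.sum(dim=(-1, -2)) > 0

assert pixel_values.shape == (1, 1, 4, 4)
assert patch_attention_mask.shape == (1, 2, 2)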

    def get_input_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: torch.BoolTensor = None,
        **kwargs,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if vision_embeds is None and pixel_values is not None:
            vision_embeds = self.get_vision_embeds(
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
            )

        if vision_embeds is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, vision_embeds
            )
        return inputs_embeds

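Note (illustrative, not part of the diff): the merge step scatters the precomputed vision embeddings into the rows of `inputs_embeds` that correspond to image placeholder tokens, in order of appearance, exactly as the masked assignment at the top of this hunk does. A toy example with assumed sizes and a made-up image token id:

import torch

hidden_size = 4
image_token_id = 7  # hypothetical placeholder id
input_ids = torch.tensor([1, 7, 2, 7, 3, 4])
inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)
vision_embeds = torch.randn(2, hidden_size)

# Rows at image-token positions take the vision embeddings; all other rows
# keep their text embeddings.
mask = input_ids == image_token_id
inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1])

assert torch.equal(inputs_embeds[1], vision_embeds[0])
assert torch.equal(inputs_embeds[3], vision_embeds[1])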
    def forward(
        self,
        input_ids: torch.Tensor,
@@ -751,76 +842,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            all_states = []
            all_pixel_values = pixel_values
            all_pixel_mask = pixel_attention_mask
            for i in range(batch_size):
                pixel_values = all_pixel_values.to(
                    dtype=self.dtype
                )  # fp16 compatibility
                pixel_values = pixel_values[i : i + 1]
                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])

                # Remove padding images - padding images are full 0.
                nb_values_per_image = pixel_values.shape[1:].numel()
                real_images_inds = (pixel_values == 0.0).sum(
                    dim=(-1, -2, -3)
                ) != nb_values_per_image
                pixel_values = pixel_values[real_images_inds].contiguous()

                # Handle the vision attention mask
                if pixel_attention_mask is None:
                    pixel_attention_mask = torch.ones(
                        size=(
                            pixel_values.size(0),
                            pixel_values.size(2),
                            pixel_values.size(3),
                        ),
                        dtype=torch.bool,
                        device=pixel_values.device,
                    )
                else:
                    # Remove padding images from the mask
                    pixel_attention_mask = all_pixel_mask[i : i + 1]
                    pixel_attention_mask = pixel_attention_mask.view(
                        1 * num_images, *pixel_attention_mask.shape[2:]
                    )
                    pixel_attention_mask = pixel_attention_mask[
                        real_images_inds
                    ].contiguous()

                patch_size = self.config.vision_config.patch_size
                patches_subgrid = pixel_attention_mask.unfold(
                    dimension=1, size=patch_size, step=patch_size
                )
                patches_subgrid = patches_subgrid.unfold(
                    dimension=2, size=patch_size, step=patch_size
                )
                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

                # Get sequence from the vision encoder
                image_hidden_states = self.vision_model(
                    pixel_values=pixel_values,
                    patch_attention_mask=patch_attention_mask,
                )

                # Modality projection & resampling
                image_hidden_states = self.connector(
                    image_hidden_states,
                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
                )
                all_states.append(image_hidden_states)
            image_hidden_states = torch.stack(all_states, dim=0)
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_hidden_states
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
@@ -1300,7 +1300,7 @@ class FlashCausalLM(Model):
        if head_size is None:
            # Some models use GQA and different sizes for o_proj
            # and q_proj, that allows for that.
            if hasattr(config, "head_dim"):
            if hasattr(config, "head_dim") and config.head_dim is not None:
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
@@ -1899,9 +1899,6 @@ class FlashCausalLM(Model):
        batch.prepare_for_prefill()

        self.get_input_embeddings(batch)
        from pdb import set_trace

        set_trace()
        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
@@ -574,7 +574,8 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
        inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
        return inputs

    def get_vision_embeds(self, pixel_values, image_sizes=None):
    def get_vision_embeds(self, pixel_values, **kwargs):
        image_sizes = kwargs.get("image_sizes", None)
        image_features = self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
@@ -586,10 +587,10 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
        projected_vision_flat = self.model.multi_modal_projector(vision_flat)
        return projected_vision_flat

    def get_input_embeds(self, input_ids, vision_embeddings=None):
    def get_input_embeds(self, input_ids, vision_embeds=None):
        inputs_embeds = self.model.get_input_embeddings()(input_ids)

        if vision_embeddings is not None:
        if vision_embeds is not None:
            original_inputs_embeds_shape = inputs_embeds.shape
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                -1
@@ -600,17 +601,15 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
            final_mask_1d = final_mask[..., 0].reshape(-1)
            num_tokens_to_fill = final_mask_1d.sum()

            if num_tokens_to_fill != vision_embeddings.size(0):
            if num_tokens_to_fill != vision_embeds.size(0):
                raise ValueError(
                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
                    f"but multi_modal_projector returned {vision_embeddings.size(0)}"
                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                )

            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                -1, inputs_embeds.size(-1)
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                expanded_mask, vision_embeddings
            )
            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
        return inputs_embeds
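Note (illustrative, not part of the diff): the signature changes above normalize the model-side hooks so every VLM exposes `get_vision_embeds(pixel_values, **kwargs)` and `get_input_embeds(input_ids, vision_embeds=...)`; models pull whatever extras they need (image_sizes, pixel_attention_mask, image_grid_thw) out of `kwargs` and ignore the rest. A minimal sketch of that calling convention with a hypothetical class, not a TGI one:

import torch

class DummyVlm:
    # Hypothetical model showing the normalized interface.
    def get_vision_embeds(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        image_sizes = kwargs.get("image_sizes", None)  # optional, model-specific extra
        # pretend each image yields one 8-dim embedding
        return torch.randn(pixel_values.shape[0], 8)

    def get_input_embeds(self, input_ids: torch.Tensor, vision_embeds=None) -> torch.Tensor:
        embeds = torch.randn(input_ids.shape[0], 8)
        if vision_embeds is not None:
            # toy merge; real models scatter at image-token positions
            embeds[: vision_embeds.shape[0]] = vision_embeds
        return embeds

model = DummyVlm()
vision = model.get_vision_embeds(torch.randn(2, 3, 4, 4), image_sizes=None, image_grid_thw=None)
embeds = model.get_input_embeds(torch.tensor([0, 1, 2]), vision_embeds=vision)

The caller can therefore pass one uniform set of keyword arguments regardless of which model is loaded.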
@@ -376,7 +376,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                    image = image
                else:
                    image = [image]

                pixel_values = processor.image_processor(
                    [image], return_tensors="pt", **kwargs
                )
@@ -387,9 +386,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

            if len(image_inputs) > 0:
                batch_image_inputs[i] = image_inputs
        from pdb import set_trace

        # pixel_values = processor.image_processor(
        #     all_images, return_tensors="pt", **kwargs
        # )

        set_trace()
        batch_image_positions = []
        batch_tokenized_inputs = []
        max_length = 0
@@ -410,6 +410,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

                if config.model_type == "gemma3":
                    image_text = image_text.replace("\n\n", "")
                elif config.model_type == "idefics2":
                    image_text = image_text_replacement_fixup(config, image_text)

                image_tokens.append(
                    (
@@ -434,9 +436,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

        prev = 0
        image_positions = []
        idefics_replacement = False

        config.image_token_index = (
            config.image_token_index
            if hasattr(config, "image_token_index")
            else config.image_token_id
        )
        for image_id, id_token, tokens in image_tokens:
            id_token = tokenizer.get_vocab()[id_token]

            if config.model_type == "idefics2" and idefics_replacement:
                id_token = config.image_token_index
                tokens = tokens[1:]

            length = len(tokens)
            index = input_ids[prev:].index(id_token)
@@ -463,6 +476,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
            image_positions.append(pos)
            prev = index + length

            if config.model_type == "idefics2" and prev != len(input_ids):
                if input_ids[prev] == config.image_token_id:
                    # means this is replaced by image_text_replacement_fixup
                    idefics_replacement = True
                else:
                    idefics_replacement = False

        if len(image_positions) > 0:
            batch_image_positions[i] = image_positions
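Note (illustrative, not part of the diff; the exact replacement strings live in image_text_replacement / image_text_replacement_fixup and the token names below are assumptions): the idefics_replacement bookkeeping appears to handle adjacent images whose placeholder text gets collapsed by the fixup, so the next image's expected token run starts one token short and begins directly with the image token. A string-level illustration of that assumption:

FAKE = "<fake_token_around_image>"   # assumed bracket token
IMAGE = "<image>"                    # assumed placeholder token

per_image = FAKE + IMAGE * 3 + FAKE   # one image's placeholder text
two_images = per_image + per_image    # adjacent images double the FAKE token

fixed = two_images.replace(FAKE + FAKE, FAKE)  # what the fixup is assumed to collapse

# After the fixup the second image's run no longer starts with FAKE, so the
# position search looks for IMAGE instead and expects one token fewer -
# the role of idefics_replacement and tokens[1:] above.
assert fixed == FAKE + IMAGE * 3 + FAKE + IMAGE * 3 + FAKE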
@@ -540,40 +560,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                    self.image_inputs[i][j] = None

        if self.has_image and len(scheduled_image_pixel_values):
            self.pixel_values = torch.cat(
                [d["pixel_values"] for d in scheduled_image_pixel_values], dim=0
            ).to(device)
            self.pixel_values = [
                d["pixel_values"].to(device) for d in scheduled_image_pixel_values
            ]

            if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
                self.pixel_attention_mask = torch.cat(
                    [d["pixel_attention_mask"] for d in scheduled_image_pixel_values],
                    dim=0,
                ).to(device)
                self.pixel_attention_mask = [
                    d["pixel_attention_mask"].to(device)
                    for d in scheduled_image_pixel_values
                ]

            if "image_sizes" in scheduled_image_pixel_values[0]:
                self.image_sizes = torch.cat(
                    [d["image_sizes"] for d in scheduled_image_pixel_values], dim=0
                ).to(device)
                self.image_sizes = [
                    d["image_sizes"].to(device) for d in scheduled_image_pixel_values
                ]

            if "image_grid_thw" in scheduled_image_pixel_values[0]:
                self.image_grid_thw = torch.cat(
                    [d["image_grid_thw"] for d in scheduled_image_pixel_values], dim=0
                ).to(device)
                self.image_grid_thw = [
                    d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
                ]
        else:
            self.pixel_values = None
            self.pixel_attention_mask = None
            self.image_sizes = None
            self.image_grid_thw = None

    def update_encoder_cache(self, encoder_outputs):
        prev = 0
        for i, input in self.scheduled_image_input:
            length = input.num_placeholder_tokens
            self.encoder_cache[i][input.id] = scatter_image_embeds(
                encoder_outputs[prev : prev + length], input.is_embed
            )

            prev = prev + length
    def update_encoder_cache(self, encoder_outputs, batch, input):
        self.encoder_cache[batch][input.id] = scatter_image_embeds(
            encoder_outputs, input.is_embed
        )

    def get_mm_embeddings(self):
        device = self.input_ids.device
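Note (illustrative, not part of the diff): this hunk switches pixel_values, pixel_attention_mask, image_sizes and image_grid_thw from one concatenated tensor to a per-image list, so images with different shapes can be scheduled and encoded independently, and update_encoder_cache now stores one image's output at a time. A minimal sketch of that per-image cache shape, with made-up names rather than the real TGI classes:

import torch

# Cache keyed by (request index, image id); one entry per encoded image.
encoder_cache: dict[int, dict[int, torch.Tensor]] = {0: {}}

scheduled = [
    # (request index, image id, per-image pixel tensor); shapes may differ per image
    (0, 0, torch.randn(1, 3, 224, 224)),
    (0, 1, torch.randn(1, 3, 336, 336)),
]

def fake_vision_encoder(pixel_values: torch.Tensor) -> torch.Tensor:
    # stand-in for get_vision_embeds: one 16-dim embedding per image
    return torch.randn(pixel_values.shape[0], 16)

for req_idx, image_id, pixel_values in scheduled:
    encoder_cache[req_idx][image_id] = fake_vision_encoder(pixel_values)

assert set(encoder_cache[0]) == {0, 1}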
@@ -675,22 +690,75 @@ class VlmCausalLM(FlashCausalLM):
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def get_vision_embeds(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: torch.Tensor,
        image_sizes: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ):
        embeds = self.model.get_vision_embeds(
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_sizes=image_sizes,
            image_grid_thw=image_grid_thw,
        )
        return embeds

    def get_input_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: Optional[torch.Tensor] = None,
    ):
        return self.model.get_input_embeds(
            input_ids=input_ids,
            vision_embeds=vision_embeds,
        )

    def get_mm_embeddings(self, batch):
        if batch.pixel_values is not None:
            encoder_outputs = self.get_vision_embeds(batch.pixel_values)
            batch.update_encoder_cache(encoder_outputs)
            batch.pixel_values = None
        for i, image_input in batch.scheduled_image_input:
            from pdb import set_trace

            set_trace()
            pixel_values = batch.pixel_values[i]
            pixel_attention_mask = (
                batch.pixel_attention_mask[i]
                if batch.pixel_attention_mask is not None
                else None
            )
            image_sizes = (
                batch.image_sizes[i] if batch.image_sizes is not None else None
            )
            image_grid_thw = (
                batch.image_grid_thw[i]
                if batch.image_grid_thw is not None
                else None
            )

            encoder_outputs = self.get_vision_embeds(
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                image_sizes=image_sizes,
                image_grid_thw=image_grid_thw,
            )
            batch.update_encoder_cache(encoder_outputs, i, image_input)

        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch.get_mm_embeddings()

    def get_input_embeddings(self, batch):
        if batch.has_image:
            vision_embeddings = self.get_mm_embeddings(batch)
            vision_embeds = self.get_mm_embeddings(batch)
            batch.has_image = False
        else:
            vision_embeddings = None
            vision_embeds = None

        inputs_embeds = self.get_input_embeds(
            batch.input_ids, vision_embeddings=vision_embeddings
            batch.input_ids, vision_embeds=vision_embeds
        )

        batch.inputs_embeds = inputs_embeds
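Note (illustrative, not part of the diff): taken together, these hunks split multimodal prefill into "encode images once, cache, then build inputs_embeds": get_input_embeddings(batch) calls get_mm_embeddings (which runs get_vision_embeds per scheduled image and fills the encoder cache), then get_input_embeds, and stores the result on batch.inputs_embeds before the forward pass. A compressed sketch of that call order with hypothetical stand-ins, not the real TGI classes:

import torch

class Batch:
    def __init__(self):
        self.has_image = True
        self.input_ids = torch.tensor([1, 7, 2])  # 7 = assumed image token
        self.inputs_embeds = None

class Model:
    hidden = 8

    def get_mm_embeddings(self, batch: Batch) -> torch.Tensor:
        # per-image encode + encoder-cache update would happen here
        return torch.randn(1, self.hidden)

    def get_input_embeds(self, input_ids, vision_embeds=None):
        embeds = torch.randn(input_ids.shape[0], self.hidden)
        if vision_embeds is not None:
            embeds[input_ids == 7] = vision_embeds
        return embeds

    def get_input_embeddings(self, batch: Batch) -> None:
        vision_embeds = self.get_mm_embeddings(batch) if batch.has_image else None
        batch.has_image = False
        batch.inputs_embeds = self.get_input_embeds(batch.input_ids, vision_embeds=vision_embeds)

batch, model = Batch(), Model()
model.get_input_embeddings(batch)  # called right before prefill, as in the diff
assert batch.inputs_embeds.shape == (3, 8)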