mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: integrate image tokens into inputs embeds
This commit is contained in:
parent
305db7ea1e
commit
a59b7faf0c
@ -112,9 +112,20 @@ pub struct ClipVisionModel {
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Idefics3 {
|
||||
pub(crate) vision_encoder_max_image_size: usize,
|
||||
pub(crate) image_seq_len: usize,
|
||||
pub struct Idefics3 {}
|
||||
|
||||
impl Idefics3 {
|
||||
pub fn get_max_longest_edge(&self) -> usize {
|
||||
364
|
||||
}
|
||||
|
||||
pub fn get_number_of_features(&self) -> usize {
|
||||
169
|
||||
}
|
||||
|
||||
pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
|
||||
1456
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
@ -753,6 +753,19 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
# mask = input_ids == self.config.image_token_index
|
||||
mask = input_ids == self.config.image_token_id
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -835,25 +848,22 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
# TODO: finish implementing the image token replacement
|
||||
# TODO: remove when prefill image tokens are handled correctly
|
||||
# * for now dummy tokens are added instead of the image tokens output byt the vision model
|
||||
mask_size = (input_ids == self.config.image_token_id).sum().item()
|
||||
unrolled_image_size = (
|
||||
image_hidden_states.shape[1] * image_hidden_states.shape[2]
|
||||
)
|
||||
diff = mask_size - unrolled_image_size
|
||||
if diff > 0:
|
||||
print(
|
||||
f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
|
||||
)
|
||||
|
||||
# inputs_embeds = self.inputs_merger(
|
||||
# input_ids=input_ids,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# image_hidden_states=image_hidden_states,
|
||||
# )
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# num_images, _, vision_hidden_size = image_hidden_states.shape
|
||||
# special_image_token_mask = input_ids == self.image_token_id
|
||||
# new_inputs_embeds = inputs_embeds.clone()
|
||||
# reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
|
||||
# inputs_embeds.dtype
|
||||
# ) # cast to the dtype of the input_embeds to support quantized models
|
||||
# new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
|
||||
# inputs_embeds = new_inputs_embeds
|
||||
if mask_size == unrolled_image_size:
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -23,6 +23,75 @@ tracer = trace.get_tracer(__name__)
|
||||
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
|
||||
IDEFICS2_IMAGE_TOKEN = "<image>"
|
||||
|
||||
IDEFICS3_IMAGE_TOKEN = "<image>"
|
||||
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
|
||||
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
|
||||
|
||||
|
||||
def _prompt_split_image(
|
||||
image_seq_len,
|
||||
image_rows,
|
||||
image_cols,
|
||||
fake_token_around_image,
|
||||
image_token,
|
||||
global_img_token,
|
||||
):
|
||||
"""Prompt with expanded image tokens for when the image is split into patches."""
|
||||
text_split_images = ""
|
||||
for n_h in range(image_rows):
|
||||
for n_w in range(image_cols):
|
||||
text_split_images += (
|
||||
f"{fake_token_around_image}"
|
||||
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
)
|
||||
text_split_images += "\n"
|
||||
|
||||
text_split_images += (
|
||||
f"\n{fake_token_around_image}"
|
||||
+ f"{global_img_token}"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
+ f"{fake_token_around_image}"
|
||||
)
|
||||
return text_split_images
|
||||
|
||||
|
||||
def _prompt_single_image(
|
||||
image_seq_len, fake_token_around_image, image_token, global_img_token
|
||||
):
|
||||
"""Prompt with expanded image tokens for a single image."""
|
||||
return (
|
||||
f"{fake_token_around_image}"
|
||||
+ f"{global_img_token}"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
+ f"{fake_token_around_image}"
|
||||
)
|
||||
|
||||
|
||||
def get_image_prompt_string(
|
||||
image_rows,
|
||||
image_cols,
|
||||
image_seq_len,
|
||||
fake_token_around_image,
|
||||
image_token,
|
||||
global_img_token,
|
||||
):
|
||||
if image_rows == 0 and image_cols == 0:
|
||||
return _prompt_single_image(
|
||||
image_seq_len,
|
||||
fake_token_around_image=fake_token_around_image,
|
||||
image_token=image_token,
|
||||
global_img_token=global_img_token,
|
||||
)
|
||||
return _prompt_split_image(
|
||||
image_seq_len,
|
||||
image_rows,
|
||||
image_cols,
|
||||
fake_token_around_image,
|
||||
image_token,
|
||||
global_img_token,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
image_str *= 5
|
||||
return image_str
|
||||
if config.model_type == "idefics3":
|
||||
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
|
||||
image_str = ""
|
||||
# TODO: implement this in a more general way
|
||||
n_rows = image_input["rows"][0][image_id]
|
||||
n_cols = image_input["cols"][0][image_id]
|
||||
|
||||
# TODO: avoid using hardcoded values
|
||||
image_seq_len = 169 # default value
|
||||
# image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
|
||||
|
||||
image_str = get_image_prompt_string(
|
||||
n_rows,
|
||||
n_cols,
|
||||
image_seq_len,
|
||||
image_token=IDEFICS3_IMAGE_TOKEN,
|
||||
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
|
||||
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
|
||||
)
|
||||
return image_str
|
||||
elif config.model_type == "llava_next":
|
||||
height, width = image_input["image_sizes"][image_id]
|
||||
@ -85,6 +168,10 @@ def image_text_replacement_fixup(config, text: str) -> str:
|
||||
return text.replace(
|
||||
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
|
||||
)
|
||||
if config.model_type == "idefics3":
|
||||
return text.replace(
|
||||
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
@ -198,7 +285,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
|
||||
if images:
|
||||
image_inputs = processor.image_processor(images, return_tensors="pt")
|
||||
image_inputs = processor.image_processor(
|
||||
images, return_tensors="pt", return_row_col_info=True
|
||||
)
|
||||
else:
|
||||
image_inputs = None
|
||||
|
||||
@ -212,9 +301,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
if chunk_type == "text":
|
||||
full_text += chunk.text
|
||||
elif chunk_type == "image":
|
||||
full_text += image_text_replacement(
|
||||
replacement_text = image_text_replacement(
|
||||
processor, image_inputs, config, image_id
|
||||
)
|
||||
full_text += replacement_text
|
||||
image_id += 1
|
||||
|
||||
full_text = image_text_replacement_fixup(config, full_text)
|
||||
@ -289,7 +379,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**processor_kwargs,
|
||||
# **processor_kwargs,
|
||||
)
|
||||
self.batch_class = batch_class
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
Loading…
Reference in New Issue
Block a user