mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

feat: integrate image tokens into inputs embeds

This commit is contained in:
parent 305db7ea1e
commit a59b7faf0c
@@ -112,9 +112,20 @@ pub struct ClipVisionModel {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
-pub struct Idefics3 {
-    pub(crate) vision_encoder_max_image_size: usize,
-    pub(crate) image_seq_len: usize,
+pub struct Idefics3 {}
+
+impl Idefics3 {
+    pub fn get_max_longest_edge(&self) -> usize {
+        364
+    }
+
+    pub fn get_number_of_features(&self) -> usize {
+        169
+    }
+
+    pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
+        1456
+    }
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
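The config-driven fields are replaced by hardcoded getters. A minimal sketch of where the hardcoded 169 plausibly comes from, using the `image_seq_len` formula that appears commented out later in this diff; the patch size (14) and pixel-shuffle scale factor (2) are assumptions for illustration, not values taken from this commit:

```python
# Hypothetical derivation of image_seq_len = 169 from the formula
# commented out further down in this diff. patch_size and scale_factor
# are assumed values, not part of this commit.
def image_seq_len(image_size: int, patch_size: int = 14, scale_factor: int = 2) -> int:
    return int(((image_size // patch_size) ** 2) / (scale_factor**2))

# (364 // 14) ** 2 / 2 ** 2 = 26 * 26 / 4 = 169
assert image_seq_len(364) == 169
```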
@@ -753,6 +753,19 @@ class Idefics3ForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )
 
+    def _merge_input_ids_with_image_features(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+    ):
+        """In-place merge of vision embeddings into inputs_embeds."""
+        # mask = input_ids == self.config.image_token_index
+        mask = input_ids == self.config.image_token_id
+        # Let's pray we have enabled enough slots!
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
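A toy demonstration (not TGI code) of the boolean-mask scatter this helper performs: positions in `inputs_embeds` whose `input_ids` entry equals the image token id are overwritten, in place, with rows of the flattened image features. The token id and shapes below are made up:

```python
import torch

IMAGE_TOKEN_ID = 5  # hypothetical id for illustration
input_ids = torch.tensor([1, 5, 5, 2])            # two image-token slots
inputs_embeds = torch.zeros(4, 3)                 # (seq_len, hidden_size)
image_features = torch.arange(6.0).view(1, 2, 3)  # (num_images, seq, hidden)

# Flatten the image features to (num_image_tokens, hidden) and scatter them
# into the embedding rows selected by the mask.
mask = input_ids == IMAGE_TOKEN_ID
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

print(inputs_embeds)
# tensor([[0., 0., 0.],
#         [0., 1., 2.],
#         [3., 4., 5.],
#         [0., 0., 0.]])
```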
@@ -835,25 +848,22 @@ class Idefics3ForConditionalGeneration(nn.Module):
 
             all_states.append(image_hidden_states)
         image_hidden_states = torch.stack(all_states, dim=0)
-        # When we generate, we don't want to replace the potential image_token_id that we generated by images
-        # that simply don't exist
-        # TODO: finish implementing the image token replacement
+        # TODO: remove when prefill image tokens are handled correctly
+        # * for now dummy tokens are added instead of the image tokens output by the vision model
+        mask_size = (input_ids == self.config.image_token_id).sum().item()
+        unrolled_image_size = (
+            image_hidden_states.shape[1] * image_hidden_states.shape[2]
+        )
+        diff = mask_size - unrolled_image_size
+        if diff > 0:
+            print(
+                f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
+            )
 
-        # inputs_embeds = self.inputs_merger(
-        #     input_ids=input_ids,
-        #     inputs_embeds=inputs_embeds,
-        #     image_hidden_states=image_hidden_states,
-        # )
-        # import ipdb; ipdb.set_trace()
-        # num_images, _, vision_hidden_size = image_hidden_states.shape
-        # special_image_token_mask = input_ids == self.image_token_id
-        # new_inputs_embeds = inputs_embeds.clone()
-        # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
-        #     inputs_embeds.dtype
-        # ) # cast to the dtype of the input_embeds to support quantized models
-        # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
-        # inputs_embeds = new_inputs_embeds
+        if mask_size == unrolled_image_size:
+            inputs_embeds = self._merge_input_ids_with_image_features(
+                input_ids, inputs_embeds, image_hidden_states
+            )
 
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
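A sketch of the guard's shape accounting under assumed shapes: if `image_hidden_states` stacks to (images, patches, tokens-per-patch, hidden), then `shape[1] * shape[2]` counts the vision embeddings that must line up one-to-one with image tokens in the prompt. The concrete numbers below are illustrative, not from the commit:

```python
import torch

# Assumed, illustrative shapes: one image split into 5 patches
# (a 2x2 grid plus the global view), 169 embeddings per patch, hidden size 8.
image_hidden_states = torch.randn(1, 5, 169, 8)

unrolled_image_size = (
    image_hidden_states.shape[1] * image_hidden_states.shape[2]
)  # 5 * 169 = 845

# The prompt must contain exactly 845 image-token ids for the in-place
# merge to be shape-safe; otherwise the merge above is skipped.
assert unrolled_image_size == 845
```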
@@ -23,6 +23,75 @@ tracer = trace.get_tracer(__name__)
 IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
 IDEFICS2_IMAGE_TOKEN = "<image>"
 
+IDEFICS3_IMAGE_TOKEN = "<image>"
+IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
+IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
+
+
+def _prompt_split_image(
+    image_seq_len,
+    image_rows,
+    image_cols,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}"
+                + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
+
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+    return text_split_images
+
+
+def _prompt_single_image(
+    image_seq_len, fake_token_around_image, image_token, global_img_token
+):
+    """Prompt with expanded image tokens for a single image."""
+    return (
+        f"{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+
+
+def get_image_prompt_string(
+    image_rows,
+    image_cols,
+    image_seq_len,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    if image_rows == 0 and image_cols == 0:
+        return _prompt_single_image(
+            image_seq_len,
+            fake_token_around_image=fake_token_around_image,
+            image_token=image_token,
+            global_img_token=global_img_token,
+        )
+    return _prompt_split_image(
+        image_seq_len,
+        image_rows,
+        image_cols,
+        fake_token_around_image,
+        image_token,
+        global_img_token,
+    )
+
+
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
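A usage sketch of `get_image_prompt_string` as defined above, assuming the helpers are in scope. It shows the layout for an image split into a 2x2 grid with `image_seq_len = 1`; the short token strings stand in for the real Idefics3 tokens to keep the output readable:

```python
prompt = get_image_prompt_string(
    image_rows=2,
    image_cols=2,
    image_seq_len=1,
    fake_token_around_image="<fake>",
    image_token="<img>",
    global_img_token="<global-img>",
)
print(prompt)
# <fake><row_1_col_1><img><fake><row_1_col_2><img>
# <fake><row_2_col_1><img><fake><row_2_col_2><img>
#
# <fake><global-img><img><fake>
```

Each tile gets its own row/column tag followed by `image_seq_len` image placeholders, and the global view is appended on its own line at the end.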
@@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         image_str *= 5
         return image_str
     if config.model_type == "idefics3":
-        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
-        image_str = ""
+        # TODO: implement this in a more general way
+        n_rows = image_input["rows"][0][image_id]
+        n_cols = image_input["cols"][0][image_id]
+
+        # TODO: avoid using hardcoded values
+        image_seq_len = 169  # default value
+        # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
+
+        image_str = get_image_prompt_string(
+            n_rows,
+            n_cols,
+            image_seq_len,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+        )
         return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
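When the processor reports `rows == cols == 0` (an unsplit image), the helper falls back to the single-image form. A small check, toy code rather than part of the commit, of what that string contains with the default `image_seq_len` of 169:

```python
FAKE = "<fake_token_around_image>"
single = FAKE + "<global-img>" + "<image>" * 169 + FAKE

# 169 image placeholders, wrapped by the fake token and tagged as the global view
assert single.count("<image>") == 169
```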
@@ -85,6 +168,10 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
+    if config.model_type == "idefics3":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
     return text
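A toy illustration (not TGI code) of why the fixup collapses doubled fake tokens: two adjacent image expansions abut at a fake-token boundary, and the single `str.replace` merges the seam:

```python
FAKE = "<fake_token_around_image>"

# Two single-image expansions placed back to back produce FAKE + FAKE at the seam.
two_images = (FAKE + "<image>" + FAKE) + (FAKE + "<image>" + FAKE)
fixed = two_images.replace(FAKE + FAKE, FAKE)

assert fixed == FAKE + "<image>" + FAKE + "<image>" + FAKE
```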
@@ -198,7 +285,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
         if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
+            image_inputs = processor.image_processor(
+                images, return_tensors="pt", return_row_col_info=True
+            )
         else:
             image_inputs = None
 
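A hedged sketch of the structure this flag is assumed to add to the processor output; the only access pattern actually shown in this diff is `image_input["rows"][0][image_id]`, so the field layout below is an assumption with illustrative values:

```python
# Assumed shape of the extra fields returned with return_row_col_info=True
# (illustrative values, not from the commit).
image_inputs = {
    "pixel_values": "...",  # tensors from the image processor
    "rows": [[2, 0]],       # per-batch list: image 0 split into 2 rows, image 1 unsplit
    "cols": [[2, 0]],
}

n_rows = image_inputs["rows"][0][0]  # -> 2, as read by image_text_replacement
```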
@@ -212,9 +301,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             if chunk_type == "text":
                 full_text += chunk.text
             elif chunk_type == "image":
-                full_text += image_text_replacement(
+                replacement_text = image_text_replacement(
                     processor, image_inputs, config, image_id
                 )
+                full_text += replacement_text
                 image_id += 1
 
         full_text = image_text_replacement_fixup(config, full_text)
@@ -289,7 +379,7 @@ class VlmCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            **processor_kwargs,
+            # **processor_kwargs,
         )
         self.batch_class = batch_class
         # import ipdb; ipdb.set_trace()