Mohit Sharma 2025-04-18 14:57:37 +00:00
parent 92909f3f33
commit 44ed5efbcc
3 changed files with 271 additions and 123 deletions

View File

@ -207,6 +207,8 @@ class FlashCausalLMBatch(Batch):
# Maximum number of blocks
max_blocks: int
inputs_embeds: Optional[torch.Tensor] = None
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
@ -1896,6 +1898,8 @@ class FlashCausalLM(Model):
if prefill:
batch.prepare_for_prefill()
self.get_input_embeddings(batch)
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)

View File

@ -368,6 +368,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
):
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
@ -377,9 +378,12 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
)
inputs["input_ids"] = None
# This is equivalent to `self.model.forward`, see the monkey patch in __init__
logits = self.model.original_forward(
input_ids=inputs["input_ids"],
inputs_embeds=inputs_embeds.unsqueeze(0),
position_ids=inputs["position_ids"],
past_key_values=None, # we use self.kv_cache instead of transformers cache object
use_cache=False, # we use self.kv_cache instead of transformers cache object
@ -568,3 +572,44 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
inputs["cache_position"] = position_ids
inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
return inputs
def get_vision_embeds(self, pixel_values, image_sizes=None):
image_features = self.model.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
vision_feature_select_strategy=self.model.config.vision_config.vision_feature_select_strategy,
image_sizes=image_sizes,
)
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.model.multi_modal_projector(vision_flat)
return projected_vision_flat
def get_input_embeds(self, input_ids, vision_embeddings=None):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if vision_embeddings is not None:
original_inputs_embeds_shape = inputs_embeds.shape
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
)
final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
final_mask_1d = final_mask[..., 0].reshape(-1)
num_tokens_to_fill = final_mask_1d.sum()
if num_tokens_to_fill != vision_embeddings.size(0):
raise ValueError(
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
f"but multi_modal_projector returned {vision_embeddings.size(0)}"
)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(
-1, inputs_embeds.size(-1)
)
inputs_embeds = inputs_embeds.masked_scatter(
expanded_mask, vision_embeddings
)
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
return inputs_embeds
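
The fusion in get_input_embeds above hinges on masked_scatter: the text embeddings are flattened, a boolean mask marks the rows belonging to image placeholder tokens, and the projected vision features are written into exactly those rows, in order. A minimal, self-contained sketch of that behaviour with toy sizes (hidden size 4, a hypothetical image token id of 99):

import torch

hidden_size = 4
image_token_id = 99  # hypothetical placeholder id, not a real model constant

# Pretend prompt: two text tokens, three image placeholders, one text token
input_ids = torch.tensor([5, 7, 99, 99, 99, 8])
inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)

# Stand-in for the projected vision features: one row per placeholder token
vision_embeddings = torch.ones(3, hidden_size)

mask = (input_ids == image_token_id).unsqueeze(-1)    # (seq_len, 1)
expanded_mask = mask.expand(-1, hidden_size)           # (seq_len, hidden_size)

# masked_scatter consumes vision_embeddings row by row wherever the mask is True
fused = inputs_embeds.masked_scatter(expanded_mask, vision_embeddings)

assert torch.equal(fused[2:5], vision_embeddings)      # image rows replaced
assert torch.equal(fused[:2], inputs_embeds[:2])       # text rows untouched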

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
import torch
from PIL import Image
from io import BytesIO
@ -115,7 +116,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
if processor.image_processor.do_image_splitting:
image_str *= 5
return image_str
return image_str, IDEFICS2_FAKE_TOKEN
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
@ -132,7 +133,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
@ -141,32 +142,32 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
logger.info,
f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features
return "<image>" * num_features, "<image>"
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
return "<image>" * config.text_config.num_image_tokens, "<image>"
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "gemma3":
# TODO: get the correct number of features by reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
elif config.model_type == "llama4":
patch_size = config.vision_config.patch_size
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
aspect_ratios = image_input["aspect_ratios"][image_id]
image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
aspect_ratios = image_input[image_id]["aspect_ratios"][0]
image_height, image_width = image_input[image_id]["pixel_values"][0].shape[-2:]
num_patches_per_chunk = int(
(image_height // patch_size)
@ -177,7 +178,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
aspect_ratios, num_patches_per_chunk
)
return tokens_for_this_image
return tokens_for_this_image, "<|image_start|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
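
Each branch above now returns a pair: the full replacement string and the single marker token that is later used to locate the image span inside the tokenized prompt. The pad-count arithmetic is branch specific; for the qwen2_vl branch it follows directly from the vision grid (Qwen2-VL merges 2x2 patches into one visual token, hence the // 4). A small sketch with an assumed grid of (1, 16, 16):

# Hypothetical image_grid_thw entry for one image: 1 temporal step, 16x16 patches
grid_t, grid_h, grid_w = 1, 16, 16

num_pads = grid_t * grid_h * grid_w // 4               # 2x2 patch merge -> 64 pads
padding = "<|image_pad|>" * num_pads

image_text = f"<|vision_start|>{padding}<|vision_end|>"
id_token = "<|vision_start|>"                          # marker used to find the span later

assert num_pads == 64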
@ -244,7 +245,42 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features
def scatter_image_embeds(embeds, is_embed):
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def gather_image_embeds(embeds, is_embed):
if is_embed is None:
return embeds
gathered = embeds[is_embed]
if len(gathered) == 0:
return None
return gathered
@dataclass
class ImagePositions:
offset: int
length: int
id: int
num_placeholder_tokens: int
is_embed: Optional[torch.Tensor] = None
class VlmCausalLMBatch(FlashCausalLMBatch):
image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
image_positions: Optional[List[List[ImagePositions]]]
encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
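
scatter_image_embeds pads a per-image encoder output back to the full placeholder span, filling non-embedding positions with NaN, and gather_image_embeds inverts that; this is what lets a partially prefilled span be sliced by token position later on. A round-trip sketch over a hypothetical 5-token span in which only 3 positions carry real embeddings:

import torch

def scatter_image_embeds(embeds, is_embed):
    if is_embed is None:
        return embeds
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]), fill_value=torch.nan
    )
    placeholders[is_embed] = embeds
    return placeholders

def gather_image_embeds(embeds, is_embed):
    if is_embed is None:
        return embeds
    gathered = embeds[is_embed]
    return gathered if len(gathered) > 0 else None

# e.g. a leading marker, three patch placeholders, a trailing marker
is_embed = torch.tensor([False, True, True, True, False])
embeds = torch.randn(3, 8)                     # encoder output, one row per patch

padded = scatter_image_embeds(embeds, is_embed)    # (5, 8), NaN on rows 0 and 4
assert torch.equal(gather_image_embeds(padded, is_embed), embeds)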
@ -276,8 +312,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
# Process images first. We need all of them so that the processor
# can make the image splits the same size, and we need the final
# sizes to insert the correct number of image tokens.
images = []
for r in requests:
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
batch_image_inputs = []
for i, r in enumerate(requests):
image_inputs = []
batch_image_inputs.append(None)
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
@ -292,46 +337,54 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
h = image.height * 2
image = image.resize((w, h))
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
elif config.model_type == "llama4":
images.append(image)
if config.model_type in {"llava_next", "gemma3", "llama4"}:
image = image
else:
images.append([image])
image = [image]
pixel_values = processor.image_processor(
[image], return_tensors="pt", **kwargs
)
image_inputs.append(pixel_values)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images:
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else:
image_inputs = None
if len(image_inputs) > 0:
batch_image_inputs[i] = image_inputs
batch_image_positions = []
batch_tokenized_inputs = []
max_length = 0
image_id = 0
for r in requests:
for i, r in enumerate(requests):
full_text = ""
image_tokens = []
batch_image_positions.append(None)
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
elif chunk_type == "image":
full_text += image_text_replacement(
processor, image_inputs, config, image_id
image_text, id_token = image_text_replacement(
processor, batch_image_inputs[i], config, image_id
)
full_text += image_text
if config.model_type == "gemma3":
image_text = image_text.replace("\n\n", "")
image_tokens.append(
(
image_id,
id_token,
tokenizer(image_text, add_special_tokens=False)[
"input_ids"
],
)
)
image_id += 1
full_text = image_text_replacement_fixup(config, full_text)
input_ids = tokenizer(
full_text,
@ -342,7 +395,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
max_length = max(max_length, len(input_ids))
batch_tokenized_inputs.append(input_ids)
return batch_tokenized_inputs, image_inputs
prev = 0
image_positions = []
for image_id, id_token, tokens in image_tokens:
id_token = tokenizer.get_vocab()[id_token]
length = len(tokens)
index = input_ids[prev:].index(id_token)
index = index + prev
assert (
input_ids[index : index + length] == tokens
), "Image token not found in input_ids"
if config.model_type in {"llava_next", "paligemma"}:
is_embed = None
num_placeholder_tokens = length
else:
is_embed = torch.tensor(tokens) == config.image_token_index
num_placeholder_tokens = is_embed.sum().item()
pos = ImagePositions(
offset=index,
length=length,
id=image_id,
num_placeholder_tokens=num_placeholder_tokens,
is_embed=is_embed,
)
image_positions.append(pos)
prev = index + length
if len(image_positions) > 0:
batch_image_positions[i] = image_positions
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
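
The loop above locates each image inside the tokenized prompt by searching input_ids for the span's leading marker token and verifying that the expected expansion follows; the offset, length, and is_embed mask recorded here drive all later slicing. A sketch of that bookkeeping with hypothetical token ids (1000 for the marker, 2000 for the per-patch placeholder):

import torch

# Hypothetical tokenization: two text tokens, one image span, one text token
input_ids = [11, 12, 1000, 2000, 2000, 2000, 13]
image_tokens = [(0, 1000, [1000, 2000, 2000, 2000])]   # (image_id, marker id, expansion ids)
image_token_index = 2000                                # id of the per-patch placeholder

prev = 0
image_positions = []
for image_id, marker_id, tokens in image_tokens:
    length = len(tokens)
    index = input_ids[prev:].index(marker_id) + prev
    assert input_ids[index : index + length] == tokens

    is_embed = torch.tensor(tokens) == image_token_index   # the marker itself carries no embedding
    image_positions.append(
        dict(
            offset=index,
            length=length,
            id=image_id,
            num_placeholder_tokens=int(is_embed.sum()),
        )
    )
    prev = index + length

assert image_positions[0]["offset"] == 2
assert image_positions[0]["num_placeholder_tokens"] == 3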
@classmethod
def from_pb_processor(
@ -354,27 +441,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
batch_tokenized_inputs, image_inputs, image_positions = (
cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.image_inputs = image_inputs
batch.image_positions = image_positions
batch.encoder_cache = [{} for _ in range(len(pb.requests))]
if len(image_inputs):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
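
After this constructor runs, image data is held per request instead of as one pre-flattened batch tensor: image_inputs[i] keeps the processor output for each image of request i, image_positions[i] keeps the matching spans, and encoder_cache[i] starts empty and is filled lazily during prefill; pixel_values and related fields are only materialised per step from the scheduled images. A shape-only sketch of that layout for an assumed batch of two requests, the first with one image and the second text-only:

# Hypothetical layout, not produced by a real processor call
batch_image_inputs = [
    [{"pixel_values": "tensor of shape (1, 3, 336, 336)"}],   # request 0: one image
    None,                                                      # request 1: text only
]
batch_image_positions = [
    [dict(offset=2, length=4, id=0, num_placeholder_tokens=3)],
    None,
]
encoder_cache = [{}, {}]   # later: encoder_cache[0][0] = cached image embeddings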
@ -394,58 +468,77 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
r,
cache_length,
input_length,
prompt_length,
request_prefilling,
) in enumerate(
zip(
self.requests,
self.cache_lengths,
self.input_lengths,
self.prompt_lengths,
self.prefilling_mask,
)
):
if not request_prefilling:
if not request_prefilling or self.image_positions[i] is None:
continue
for mm_inputs in batch_mm_inputs[i]:
for j, mm_input in enumerate(mm_input):
image_id = mm_input.id
pixel_values = self.all_pixel_values[i][j].pixel_values
for j, image_position in enumerate(self.image_positions[i]):
image_id = image_position.id
pixel_values = self.image_inputs[i][j]
start_pos = mm_input.offset
length = mm_input.length
num_placeholder_tokens = mm_input.num_placeholder_tokens
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encoder input is already processed
continue
self.has_image = True
if image_id not in self.encoder_cache[i][image_id]:
self.scheduled_image_input.append((i, mm_input))
scheduled_image_pixel_values.append(pixel_values)
if image_id not in self.encoder_cache[i]:
self.scheduled_image_input.append((i, image_position))
scheduled_image_pixel_values.append(pixel_values)
if self.has_image and len(scheduled_image_pixel_values):
self.pixel_values = torch.cat([scheduled_image_pixel_values], dim=0).to(device)
self.pixel_values = torch.cat(
[d["pixel_values"] for d in scheduled_image_pixel_values], dim=0
).to(device)
if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
self.pixel_attention_mask = torch.cat(
[d["pixel_attention_mask"] for d in scheduled_image_pixel_values],
dim=0,
).to(device)
if "image_sizes" in scheduled_image_pixel_values[0]:
self.image_sizes = torch.cat(
[d["image_sizes"] for d in scheduled_image_pixel_values], dim=0
).to(device)
if "image_grid_thw" in scheduled_image_pixel_values[0]:
self.image_grid_thw = torch.cat(
[d["image_grid_thw"] for d in scheduled_image_pixel_values], dim=0
).to(device)
else:
self.pixel_values = None
self.pixel_attention_mask = None
self.image_sizes = None
self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs):
prev = 0
for i, input in self.scheduled_image_input:
length = input.num_placeholder_tokens
output = encoder_outputs[prev:length]
batch.encoder_cache[i][image_id] = self.scatter_image_embed(output, input.is_embed)
self.encoder_cache[i][input.id] = scatter_image_embeds(
encoder_outputs[prev : prev + length], input.is_embed
)
prev = length
prev = prev + length
def get_mm_embeddings(self):
device = self.input_ids.device
mm_embeds = []
for i, (
r,
cache_length,
@ -461,51 +554,43 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
self.prefilling_mask,
)
):
if not request_prefilling:
if not request_prefilling or self.image_positions[i] is None:
continue
for mm_inputs in batch_mm_inputs[i]:
for j, mm_input in enumerate(mm_input):
image_id = mm_input.id
pixel_values = self.all_pixel_values[i][j].pixel_values
for j, image_position in enumerate(self.image_positions[i]):
image_id = image_position.id
start_pos = mm_input.offset
length = mm_input.length
num_placeholder_tokens = mm_input.num_placeholder_tokens
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encoder input is already processed
continue
self.has_image = True
start_idx = max(cache_length - start_pos, 0)
end_idx = min(cache_length - start_pos + input_length, length)
if image_id not in self.encoder_cache[i][image_id]:
self.scheduled_image_input.append(mm_input)
scheduled_image_pixel_values.append(pixel_values)
assert (
image_id in self.encoder_cache[i]
), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
encoder_output = self.encoder_cache[i][image_id]
is_embed = image_position.is_embed
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
start_idx = max(cache_length - start_pos, 0)
end_idx = min(
cache_length - start_pos + input_length,
length)
encoder_output = self.encoder_cache[i][image_id]
is_embed = pos_info.is_embed
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_embeds(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
mm_embeds_item = gather_image_embeds(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return torch.cat(mm_embeds, dim=0).to(device)
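
Scheduling (in prepare_for_prefill) and gathering (in get_mm_embeddings) rely on the same window arithmetic: an image span occupying prompt positions [start_pos, start_pos + length) contributes to the current step only if it overlaps the chunk being prefilled, [cache_length, cache_length + input_length), and only the overlapping slice of the cached encoder output is gathered. A small sketch of that arithmetic with hypothetical numbers, chunked prefill of 8 tokens at a time over a 10-token image span starting at position 6:

start_pos, length = 6, 10       # image span occupies prompt positions [6, 16)
input_length = 8                # hypothetical prefill chunk size

for cache_length in (0, 8, 16):     # prompt tokens already processed before each step
    if start_pos >= cache_length + input_length:
        overlap = "not reached yet"          # span lies entirely after this chunk
    elif start_pos + length <= cache_length:
        overlap = "already consumed"         # span lies entirely before this chunk
    else:
        start_idx = max(cache_length - start_pos, 0)
        end_idx = min(cache_length - start_pos + input_length, length)
        overlap = f"rows [{start_idx}, {end_idx}) of the cached encoder output"
    print(cache_length, overlap)

# cache_length=0 -> rows [0, 2); cache_length=8 -> rows [2, 10); cache_length=16 -> already consumed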
class VlmCausalLM(FlashCausalLM):
def __init__(
self,
@ -544,12 +629,24 @@ class VlmCausalLM(FlashCausalLM):
def get_mm_embeddings(self, batch):
if batch.pixel_values is not None:
encoder_outputs = self.model.get_mm_embeddings(batch.pixel_values)
encoder_outputs = self.get_vision_embeds(batch.pixel_values)
batch.update_encoder_cache(encoder_outputs)
batch.pixel_values = None
return batch.get_mm_embeddings()
def get_input_embeddings(self, batch):
if batch.has_image:
vision_embeddings = self.get_mm_embeddings(batch)
batch.has_image = False
else:
vision_embeddings = None
inputs_embeds = self.get_input_embeds(
batch.input_ids, vision_embeddings=vision_embeddings
)
batch.inputs_embeds = inputs_embeds
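
Together with the FlashCausalLM hook added at the top of this commit, prefill now runs in two stages: encode whatever images fall inside the current chunk and cache the result, then splice the cached embeddings into the text embeddings and hand inputs_embeds to the language model. A rough outline of that control flow, with method names taken from this diff and the model/batch objects assumed:

def prefill_step(model, batch):
    # Stage 1: encode scheduled images and fill encoder_cache (per request, per image_id)
    if batch.has_image:
        vision_embeddings = model.get_mm_embeddings(batch)
        batch.has_image = False
    else:
        vision_embeddings = None

    # Stage 2: splice cached vision embeddings into the token embeddings for this chunk
    batch.inputs_embeds = model.get_input_embeds(
        batch.input_ids, vision_embeddings=vision_embeddings
    )

    # The language model then consumes inputs_embeds rather than embedding input_ids itself
    return model.forward(batch)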
def forward(
self,
batch: VlmCausalLMBatch,
@ -600,6 +697,7 @@ class VlmCausalLM(FlashCausalLM):
position_ids = new_position_ids
else:
input_ids = batch.input_ids
inputs_embeds = batch.inputs_embeds
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
@ -659,6 +757,7 @@ class VlmCausalLM(FlashCausalLM):
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,