update pixel_values

This commit is contained in:
Mohit Sharma 2025-04-19 17:12:23 +00:00
parent 52e4186c2a
commit 7237e8e6bf

View File

@ -517,10 +517,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
self.has_image = False
self.encoder_cache_to_free = []
self.scheduled_image_input = []
scheduled_image_pixel_values = []
device = self.input_ids.device
self.pixel_values = []
for i, (
r,
@ -540,7 +538,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
for j, image_position in enumerate(self.image_positions[i]):
image_id = image_position.id
pixel_values = self.image_inputs[i][j]
image_inputs = self.image_inputs[i][j]
start_pos = image_position.offset
length = image_position.length
@ -555,38 +553,38 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
self.has_image = True
if image_id not in self.encoder_cache[i]:
self.scheduled_image_input.append((i, image_position))
scheduled_image_pixel_values.append(pixel_values)
self.pixel_values.append((i, image_position, image_inputs))
# scheduled_image_pixel_values.append(image_inputs)
self.image_inputs[i][j] = None
if self.has_image and len(scheduled_image_pixel_values):
self.pixel_values = [
d["pixel_values"].to(device) for d in scheduled_image_pixel_values
]
# if self.has_image and len(scheduled_image_pixel_values):
# self.pixel_values = [
# d["pixel_values"].to(device) for d in scheduled_image_pixel_values
# ]
if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
self.pixel_attention_mask = [
d["pixel_attention_mask"].to(device)
for d in scheduled_image_pixel_values
]
# if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
# self.pixel_attention_mask = [
# d["pixel_attention_mask"].to(device)
# for d in scheduled_image_pixel_values
# ]
if "image_sizes" in scheduled_image_pixel_values[0]:
self.image_sizes = [
d["image_sizes"].to(device) for d in scheduled_image_pixel_values
]
# if "image_sizes" in scheduled_image_pixel_values[0]:
# self.image_sizes = [
# d["image_sizes"].to(device) for d in scheduled_image_pixel_values
# ]
if "image_grid_thw" in scheduled_image_pixel_values[0]:
self.image_grid_thw = [
d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
]
else:
# if "image_grid_thw" in scheduled_image_pixel_values[0]:
# self.image_grid_thw = [
# d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
# ]
if not self.has_image:
self.pixel_values = None
self.pixel_attention_mask = None
self.image_sizes = None
self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs, batch, input):
self.encoder_cache[batch][input.id] = scatter_image_embeds(
def update_encoder_cache(self, encoder_outputs, request_id, input):
self.encoder_cache[request_id][input.id] = scatter_image_embeds(
encoder_outputs, input.is_embed
)
@ -717,24 +715,26 @@ class VlmCausalLM(FlashCausalLM):
def get_mm_embeddings(self, batch):
if batch.pixel_values is not None:
for i, image_input in batch.scheduled_image_input:
from pdb import set_trace
device = batch.input_ids.device
for request_id, image_position, image_input in batch.pixel_values:
pixel_values = image_input["pixel_values"].to(device)
set_trace()
pixel_values = batch.pixel_values[i]
pixel_attention_mask = (
batch.pixel_attention_mask[i]
if batch.pixel_attention_mask is not None
else None
)
image_sizes = (
batch.image_sizes[i] if batch.image_sizes is not None else None
)
image_grid_thw = (
batch.image_grid_thw[i]
if batch.image_grid_thw is not None
else None
)
if "pixel_attention_mask" in image_input:
pixel_attention_mask = image_input["pixel_attention_mask"].to(
device
)
else:
pixel_attention_mask = None
if "image_sizes" in image_input:
image_sizes = image_input["image_sizes"].to(device)
else:
image_sizes = None
if "image_grid_thw" in image_input:
image_grid_thw = image_input["image_grid_thw"].to(device)
else:
image_grid_thw = None
encoder_outputs = self.get_vision_embeds(
pixel_values=pixel_values,
@ -742,12 +742,9 @@ class VlmCausalLM(FlashCausalLM):
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
batch.update_encoder_cache(encoder_outputs, i, image_input)
batch.update_encoder_cache(encoder_outputs, request_id, image_position)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch.get_mm_embeddings()
def get_input_embeddings(self, batch):