update pixel_values

This commit is contained in:
Mohit Sharma 2025-04-19 17:12:23 +00:00
parent 52e4186c2a
commit 7237e8e6bf

View File

@@ -517,10 +517,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
self.has_image = False self.has_image = False
self.encoder_cache_to_free = [] self.encoder_cache_to_free = []
self.scheduled_image_input = []
scheduled_image_pixel_values = []
device = self.input_ids.device self.pixel_values = []
for i, ( for i, (
r, r,
@@ -540,7 +538,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
for j, image_position in enumerate(self.image_positions[i]): for j, image_position in enumerate(self.image_positions[i]):
image_id = image_position.id image_id = image_position.id
pixel_values = self.image_inputs[i][j] image_inputs = self.image_inputs[i][j]
start_pos = image_position.offset start_pos = image_position.offset
length = image_position.length length = image_position.length
@@ -555,38 +553,38 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
self.has_image = True self.has_image = True
if image_id not in self.encoder_cache[i]: if image_id not in self.encoder_cache[i]:
self.scheduled_image_input.append((i, image_position)) self.pixel_values.append((i, image_position, image_inputs))
scheduled_image_pixel_values.append(pixel_values) # scheduled_image_pixel_values.append(image_inputs)
self.image_inputs[i][j] = None self.image_inputs[i][j] = None
if self.has_image and len(scheduled_image_pixel_values): # if self.has_image and len(scheduled_image_pixel_values):
self.pixel_values = [ # self.pixel_values = [
d["pixel_values"].to(device) for d in scheduled_image_pixel_values # d["pixel_values"].to(device) for d in scheduled_image_pixel_values
] # ]
if "pixel_attention_mask" in scheduled_image_pixel_values[0]: # if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
self.pixel_attention_mask = [ # self.pixel_attention_mask = [
d["pixel_attention_mask"].to(device) # d["pixel_attention_mask"].to(device)
for d in scheduled_image_pixel_values # for d in scheduled_image_pixel_values
] # ]
if "image_sizes" in scheduled_image_pixel_values[0]: # if "image_sizes" in scheduled_image_pixel_values[0]:
self.image_sizes = [ # self.image_sizes = [
d["image_sizes"].to(device) for d in scheduled_image_pixel_values # d["image_sizes"].to(device) for d in scheduled_image_pixel_values
] # ]
if "image_grid_thw" in scheduled_image_pixel_values[0]: # if "image_grid_thw" in scheduled_image_pixel_values[0]:
self.image_grid_thw = [ # self.image_grid_thw = [
d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values # d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
] # ]
else: if not self.has_image:
self.pixel_values = None self.pixel_values = None
self.pixel_attention_mask = None self.pixel_attention_mask = None
self.image_sizes = None self.image_sizes = None
self.image_grid_thw = None self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs, batch, input): def update_encoder_cache(self, encoder_outputs, request_id, input):
self.encoder_cache[batch][input.id] = scatter_image_embeds( self.encoder_cache[request_id][input.id] = scatter_image_embeds(
encoder_outputs, input.is_embed encoder_outputs, input.is_embed
) )
@@ -717,24 +715,26 @@ class VlmCausalLM(FlashCausalLM):
def get_mm_embeddings(self, batch): def get_mm_embeddings(self, batch):
if batch.pixel_values is not None: if batch.pixel_values is not None:
for i, image_input in batch.scheduled_image_input: device = batch.input_ids.device
from pdb import set_trace for request_id, image_position, image_input in batch.pixel_values:
pixel_values = image_input["pixel_values"].to(device)
set_trace() if "pixel_attention_mask" in image_input:
pixel_values = batch.pixel_values[i] pixel_attention_mask = image_input["pixel_attention_mask"].to(
pixel_attention_mask = ( device
batch.pixel_attention_mask[i] )
if batch.pixel_attention_mask is not None else:
else None pixel_attention_mask = None
)
image_sizes = ( if "image_sizes" in image_input:
batch.image_sizes[i] if batch.image_sizes is not None else None image_sizes = image_input["image_sizes"].to(device)
) else:
image_grid_thw = ( image_sizes = None
batch.image_grid_thw[i]
if batch.image_grid_thw is not None if "image_grid_thw" in image_input:
else None image_grid_thw = image_input["image_grid_thw"].to(device)
) else:
image_grid_thw = None
encoder_outputs = self.get_vision_embeds( encoder_outputs = self.get_vision_embeds(
pixel_values=pixel_values, pixel_values=pixel_values,
@@ -742,12 +742,9 @@ class VlmCausalLM(FlashCausalLM):
image_sizes=image_sizes, image_sizes=image_sizes,
image_grid_thw=image_grid_thw, image_grid_thw=image_grid_thw,
) )
batch.update_encoder_cache(encoder_outputs, i, image_input) batch.update_encoder_cache(encoder_outputs, request_id, image_position)
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch.get_mm_embeddings() return batch.get_mm_embeddings()
def get_input_embeddings(self, batch): def get_input_embeddings(self, batch):