From 7237e8e6bfab9afff7609847f9d3fa7f52d2d794 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Sat, 19 Apr 2025 17:12:23 +0000
Subject: [PATCH] update pixel_values

---
 .../models/vlm_causal_lm.py | 91 +++++++++----------
 1 file changed, 44 insertions(+), 47 deletions(-)

diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index ed2649b4..bcc67134 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -517,10 +517,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         self.has_image = False
         self.encoder_cache_to_free = []
 
-        self.scheduled_image_input = []
-        scheduled_image_pixel_values = []
-        device = self.input_ids.device
+        self.pixel_values = []
 
         for i, (
             r,
@@ -540,7 +538,7 @@
 
             for j, image_position in enumerate(self.image_positions[i]):
                 image_id = image_position.id
-                pixel_values = self.image_inputs[i][j]
+                image_inputs = self.image_inputs[i][j]
 
                 start_pos = image_position.offset
                 length = image_position.length
@@ -555,38 +553,38 @@
                 self.has_image = True
 
                 if image_id not in self.encoder_cache[i]:
-                    self.scheduled_image_input.append((i, image_position))
-                    scheduled_image_pixel_values.append(pixel_values)
+                    self.pixel_values.append((i, image_position, image_inputs))
+                    # scheduled_image_pixel_values.append(image_inputs)
                     self.image_inputs[i][j] = None
 
-        if self.has_image and len(scheduled_image_pixel_values):
-            self.pixel_values = [
-                d["pixel_values"].to(device) for d in scheduled_image_pixel_values
-            ]
+        # if self.has_image and len(scheduled_image_pixel_values):
+        #     self.pixel_values = [
+        #         d["pixel_values"].to(device) for d in scheduled_image_pixel_values
+        #     ]
 
-            if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
-                self.pixel_attention_mask = [
-                    d["pixel_attention_mask"].to(device)
-                    for d in scheduled_image_pixel_values
-                ]
+        # if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
+        #     self.pixel_attention_mask = [
+        #         d["pixel_attention_mask"].to(device)
+        #         for d in scheduled_image_pixel_values
+        #     ]
 
-            if "image_sizes" in scheduled_image_pixel_values[0]:
-                self.image_sizes = [
-                    d["image_sizes"].to(device) for d in scheduled_image_pixel_values
-                ]
+        # if "image_sizes" in scheduled_image_pixel_values[0]:
+        #     self.image_sizes = [
+        #         d["image_sizes"].to(device) for d in scheduled_image_pixel_values
+        #     ]
 
-            if "image_grid_thw" in scheduled_image_pixel_values[0]:
-                self.image_grid_thw = [
-                    d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
-                ]
-        else:
+        # if "image_grid_thw" in scheduled_image_pixel_values[0]:
+        #     self.image_grid_thw = [
+        #         d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
+        #     ]
+        if not self.has_image:
             self.pixel_values = None
             self.pixel_attention_mask = None
             self.image_sizes = None
             self.image_grid_thw = None
 
-    def update_encoder_cache(self, encoder_outputs, batch, input):
-        self.encoder_cache[batch][input.id] = scatter_image_embeds(
+    def update_encoder_cache(self, encoder_outputs, request_id, input):
+        self.encoder_cache[request_id][input.id] = scatter_image_embeds(
            encoder_outputs, input.is_embed
         )
 
@@ -717,24 +715,26 @@
 
     def get_mm_embeddings(self, batch):
        if batch.pixel_values is not None:
-            for i, image_input in batch.scheduled_image_input:
-                from pdb import set_trace
+            device = batch.input_ids.device
+            for request_id, image_position, image_input in batch.pixel_values:
+                pixel_values = image_input["pixel_values"].to(device)
 
-                set_trace()
-                pixel_values = batch.pixel_values[i]
-                pixel_attention_mask = (
-                    batch.pixel_attention_mask[i]
-                    if batch.pixel_attention_mask is not None
-                    else None
-                )
-                image_sizes = (
-                    batch.image_sizes[i] if batch.image_sizes is not None else None
-                )
-                image_grid_thw = (
-                    batch.image_grid_thw[i]
-                    if batch.image_grid_thw is not None
-                    else None
-                )
+                if "pixel_attention_mask" in image_input:
+                    pixel_attention_mask = image_input["pixel_attention_mask"].to(
+                        device
+                    )
+                else:
+                    pixel_attention_mask = None
+
+                if "image_sizes" in image_input:
+                    image_sizes = image_input["image_sizes"].to(device)
+                else:
+                    image_sizes = None
+
+                if "image_grid_thw" in image_input:
+                    image_grid_thw = image_input["image_grid_thw"].to(device)
+                else:
+                    image_grid_thw = None
 
                 encoder_outputs = self.get_vision_embeds(
                     pixel_values=pixel_values,
@@ -742,12 +742,9 @@
                     image_sizes=image_sizes,
                     image_grid_thw=image_grid_thw,
                 )
-                batch.update_encoder_cache(encoder_outputs, i, image_input)
+                batch.update_encoder_cache(encoder_outputs, request_id, image_position)
 
         batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        batch.image_grid_thw = None
         return batch.get_mm_embeddings()
 
     def get_input_embeddings(self, batch):
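
For illustration, the standalone Python sketch below shows the data flow this patch moves to: scheduling stores (request_index, image_position, processor_output) tuples directly on batch.pixel_values, and get_mm_embeddings then moves each image's processor outputs to the device lazily and writes the vision-encoder output into that request's encoder cache under the image id. The names ImagePosition, FakeVisionModel, schedule_images and run_mm_embeddings, and the tensor shapes, are invented for this sketch, and the scatter_image_embeds/is_embed step is omitted; this is not the TGI implementation.

# Standalone sketch (not the TGI implementation): ImagePosition, FakeVisionModel,
# schedule_images, run_mm_embeddings and all shapes below are invented to illustrate
# the flow this patch moves to; scatter_image_embeds/is_embed handling is omitted.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ImagePosition:
    id: int          # image id within the request
    offset: int      # first image-token position in the prompt
    length: int      # number of image tokens
    is_embed: Optional[torch.Tensor] = None


class FakeVisionModel:
    """Stand-in for get_vision_embeds(): one embedding row per image token."""

    hidden = 8

    def get_vision_embeds(self, pixel_values, **kwargs):
        return torch.randn(pixel_values.shape[0], self.hidden)


def schedule_images(image_inputs, image_positions, encoder_cache):
    """Prepare-time side: collect images whose embeddings are not cached yet.

    Returns (request_index, ImagePosition, processor_output_dict) tuples,
    i.e. the new layout of batch.pixel_values in this patch."""
    scheduled = []
    for i, positions in enumerate(image_positions):
        for j, position in enumerate(positions):
            if position.id not in encoder_cache[i]:
                scheduled.append((i, position, image_inputs[i][j]))
    return scheduled


def run_mm_embeddings(model, scheduled, encoder_cache, device):
    """get_mm_embeddings side: move each image's inputs to the device lazily,
    encode, and cache the result per request under the image id."""
    for request_index, position, image_input in scheduled:
        pixel_values = image_input["pixel_values"].to(device)
        # Optional extras are now looked up per image instead of per batch.
        image_grid_thw = (
            image_input["image_grid_thw"].to(device)
            if "image_grid_thw" in image_input
            else None
        )
        encoder_outputs = model.get_vision_embeds(
            pixel_values=pixel_values, image_grid_thw=image_grid_thw
        )
        encoder_cache[request_index][position.id] = encoder_outputs


if __name__ == "__main__":
    device = torch.device("cpu")
    model = FakeVisionModel()
    # Two requests: the first has one uncached image, the second has none.
    image_positions = [[ImagePosition(id=0, offset=3, length=4)], []]
    image_inputs = [[{"pixel_values": torch.randn(4, 16)}], []]
    encoder_cache = [{}, {}]

    scheduled = schedule_images(image_inputs, image_positions, encoder_cache)
    run_mm_embeddings(model, scheduled, encoder_cache, device)
    print(encoder_cache[0][0].shape)  # torch.Size([4, 8])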