fix qwen test

This commit is contained in:
Mohit Sharma 2025-04-30 10:08:55 +00:00
parent 996473164a
commit 6a5955a78c

View File

@ -603,6 +603,18 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
self.pixel_attention_mask = None self.pixel_attention_mask = None
self.image_sizes = None self.image_sizes = None
self.image_grid_thw = None self.image_grid_thw = None
else:
image_grid_thw_list = [
x[2]["image_grid_thw"]
for x in self.pixel_values
if "image_grid_thw" in x[2]
]
if image_grid_thw_list:
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
self.input_ids.device
)
else:
self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs, request_id, img_pos): def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds( self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
@ -886,7 +898,6 @@ class VlmCausalLM(FlashCausalLM):
) )
def encode_images(self, batch): def encode_images(self, batch):
image_grid_thw = None
if batch.pixel_values is not None: if batch.pixel_values is not None:
device = batch.input_ids.device device = batch.input_ids.device
for request_id, image_id, image_input in batch.pixel_values: for request_id, image_id, image_input in batch.pixel_values:
@ -924,7 +935,6 @@ class VlmCausalLM(FlashCausalLM):
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
batch.image_grid_thw = image_grid_thw
def set_inputs_embeds(self, batch): def set_inputs_embeds(self, batch):
if batch.has_image_inputs: if batch.has_image_inputs: