fix qwen test

This commit is contained in:
Mohit Sharma 2025-04-30 09:57:22 +00:00
parent d1cf64abc4
commit 996473164a

View File

@ -886,6 +886,7 @@ class VlmCausalLM(FlashCausalLM):
)
def encode_images(self, batch):
image_grid_thw = None
if batch.pixel_values is not None:
device = batch.input_ids.device
for request_id, image_id, image_input in batch.pixel_values:
@ -923,7 +924,7 @@ class VlmCausalLM(FlashCausalLM):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.image_grid_thw = image_grid_thw
def set_inputs_embeds(self, batch):
if batch.has_image_inputs:
@ -1066,6 +1067,7 @@ class VlmCausalLM(FlashCausalLM):
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
batch.image_grid_thw = None
batch.free_encoder_cache()
return logits, speculative_logits