From 6a5955a78cdb8d001ba456d0071f60aad161b90e Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Wed, 30 Apr 2025 10:08:55 +0000
Subject: [PATCH] fix qwen test

---
 .../text_generation_server/models/vlm_causal_lm.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 09cdb8a9..b76dbe68 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -603,6 +603,18 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             self.pixel_attention_mask = None
             self.image_sizes = None
             self.image_grid_thw = None
+        else:
+            image_grid_thw_list = [
+                x[2]["image_grid_thw"]
+                for x in self.pixel_values
+                if "image_grid_thw" in x[2]
+            ]
+            if image_grid_thw_list:
+                self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
+                    self.input_ids.device
+                )
+            else:
+                self.image_grid_thw = None

     def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
         self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
@@ -886,7 +898,6 @@ class VlmCausalLM(FlashCausalLM):
         )

     def encode_images(self, batch):
-        image_grid_thw = None
         if batch.pixel_values is not None:
             device = batch.input_ids.device
             for request_id, image_id, image_input in batch.pixel_values:
@@ -924,7 +935,6 @@ class VlmCausalLM(FlashCausalLM):
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
-        batch.image_grid_thw = image_grid_thw

     def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:
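
Below is a minimal, standalone sketch of what the new else-branch in the first hunk computes: rebuilding the batch-level image_grid_thw tensor from the remaining pixel_values entries instead of clearing it in encode_images. It assumes, based only on the diff context, that each pixel_values entry is a (request_id, image_id, image_input) tuple whose third element is a dict that may carry an "image_grid_thw" tensor (Qwen2-VL style). The FakeVlmBatch class and rebuild_image_grid_thw method are hypothetical names used for illustration, not part of the upstream code.

from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class FakeVlmBatch:
    # Stand-in for the real batch object; only the fields touched by the hunk.
    input_ids: torch.Tensor
    # Each entry mirrors the diff: (request_id, image_id, image_input), where
    # image_input is a dict that may contain an "image_grid_thw" tensor.
    pixel_values: list = field(default_factory=list)
    image_grid_thw: Optional[torch.Tensor] = None

    def rebuild_image_grid_thw(self) -> None:
        # Collect the per-image (t, h, w) grids that are present, then
        # concatenate them along dim 0 so one tensor covers the whole batch.
        grids = [
            x[2]["image_grid_thw"]
            for x in self.pixel_values
            if "image_grid_thw" in x[2]
        ]
        if grids:
            self.image_grid_thw = torch.cat(grids, dim=0).to(self.input_ids.device)
        else:
            self.image_grid_thw = None


if __name__ == "__main__":
    batch = FakeVlmBatch(
        input_ids=torch.zeros(4, dtype=torch.long),
        pixel_values=[
            ("req-0", 0, {"image_grid_thw": torch.tensor([[1, 16, 16]])}),
            ("req-1", 0, {"image_grid_thw": torch.tensor([[1, 32, 24]])}),
        ],
    )
    batch.rebuild_image_grid_thw()
    print(batch.image_grid_thw)  # tensor of shape (2, 3), one row per image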