diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 09cdb8a9..b76dbe68 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -603,6 +603,18 @@ class VlmCausalLMBatch(FlashCausalLMBatch): self.pixel_attention_mask = None self.image_sizes = None self.image_grid_thw = None + else: + image_grid_thw_list = [ + x[2]["image_grid_thw"] + for x in self.pixel_values + if "image_grid_thw" in x[2] + ] + if image_grid_thw_list: + self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to( + self.input_ids.device + ) + else: + self.image_grid_thw = None def update_encoder_cache(self, encoder_outputs, request_id, img_pos): self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds( @@ -886,7 +898,6 @@ class VlmCausalLM(FlashCausalLM): ) def encode_images(self, batch): - image_grid_thw = None if batch.pixel_values is not None: device = batch.input_ids.device for request_id, image_id, image_input in batch.pixel_values: @@ -924,7 +935,6 @@ class VlmCausalLM(FlashCausalLM): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None - batch.image_grid_thw = image_grid_thw def set_inputs_embeds(self, batch): if batch.has_image_inputs: