From 7ede61bca65d7b99111e6a79eef4a6ba847b6ed0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 1 Oct 2024 10:18:41 +0200 Subject: [PATCH] Force ignore all images but last. --- .../models/mllama_causal_lm.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index ef12b621..9e19e171 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -91,9 +91,12 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): texts = [] image_indices = [] batch_tokenized_inputs = [] + for i, r in enumerate(requests): # Each input is encoded into a list, where each element of this input list is either a string or a URL curr_text = "" + curr_image = None + curr_i = None for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": @@ -103,11 +106,16 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): # TODO unsure about BOS curr_text += "<|image|>" image_input = processor.image_processor(image, return_tensors="pt") - image_inputs.append(image_input) - image_indices.append(i) + curr_image = image_input + curr_i = i + # image_inputs.append(image_input) + # image_indices.append(i) else: raise RuntimeError(f"Invalid chunk type {chunk_type}") texts.append(curr_text) + if curr_image is not None: + image_inputs.append(curr_image) + image_indices.append(curr_i) input_ids = tokenizer( curr_text,