From 46ff01649077910f98d061110688be4a963b7960 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Tue, 22 Apr 2025 01:40:42 +0530
Subject: [PATCH] improve

---
 .../models/vlm_causal_lm.py | 63 ++++++++++---------
 1 file changed, 35 insertions(+), 28 deletions(-)

diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 5394e1c1..41218d7c 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -190,6 +190,22 @@ def image_text_replacement_fixup(config, text: str) -> str:
     )
     return text
 
+def preprocess_text(config, text: str) -> str:
+    if config.model_type == "paligemma":
+        return "<bos>" + text + "\n"
+    return text
+
+def preprocess_image(config, img):
+    model_type = config.model_type
+
+    if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
+        img = img.resize((img.width * 2, img.height * 2))
+    elif model_type == "paligemma":
+        img = img.convert("RGB")
+    elif model_type not in {"llava_next", "gemma3", "llama4"}:
+        img = [img]
+
+    return img
 
 def get_unpadded_features(
     original_height: int,
@@ -371,39 +387,30 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             image_id = 0
             for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
-                    text_parts.append(chunk.text)
-                    continue
+                    text = preprocess_text(config, chunk.text)
+                    text_parts.append(text)
+                elif chunk_type == "image":
+                    img = Image.open(BytesIO(chunk.image.data))
+                    img = preprocess_image(config, img)
 
-                if chunk_type != "image":
+                    image_input = processor.image_processor(
+                        [img], return_tensors="pt", **kwargs
+                    )
+                    image_inputs.append(image_input)
+
+                    img_text, img_start_token_str = image_text_replacement(
+                        processor, image_input, config, 0
+                    )
+                    text_parts.append(img_text)
+
+                    image_texts.append([image_id, img_start_token_str, img_text])
+                    image_id += 1
+                else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
-                img = Image.open(BytesIO(chunk.image.data))
-
-                if config.model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
-                    img = img.resize((img.width * 2, img.height * 2))
-
-                if config.model_type in {"paligemma"}:
-                    img = img.convert("RGB")
-
-                if config.model_type not in {"llava_next", "gemma3", "llama4"}:
-                    img = [img]
-
-                image_input = processor.image_processor(
-                    [img], return_tensors="pt", **kwargs
-                )
-                image_inputs.append(image_input)
-
-                img_text, id_token_str = image_text_replacement(
-                    processor, image_input, config, 0
-                )
-
-                text_parts.append(img_text)
-
-                image_texts.append([image_id, id_token_str, img_text])
-                image_id += 1
-
             full_text = image_text_replacement_fixup(config, "".join(text_parts))
             input_ids = tokenizer(
                 full_text,
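
Not part of the patch: a minimal, self-contained sketch of how the two new
helpers behave. The function bodies are copied from the hunk above;
FakeConfig is a hypothetical stand-in for the transformers config object,
of which only model_type is consulted here.

    from dataclasses import dataclass

    from PIL import Image


    @dataclass
    class FakeConfig:
        # Hypothetical stand-in: the real object is a transformers
        # PretrainedConfig; these helpers only read model_type.
        model_type: str


    def preprocess_text(config, text: str) -> str:
        # Copied from the patch: paligemma prompts get a BOS prefix and a
        # trailing newline; every other model type passes through unchanged.
        if config.model_type == "paligemma":
            return "<bos>" + text + "\n"
        return text


    def preprocess_image(config, img):
        # Copied from the patch: per-model image normalization.
        model_type = config.model_type

        if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
            img = img.resize((img.width * 2, img.height * 2))
        elif model_type == "paligemma":
            img = img.convert("RGB")
        elif model_type not in {"llava_next", "gemma3", "llama4"}:
            img = [img]

        return img


    assert preprocess_text(FakeConfig("paligemma"), "hi") == "<bos>hi\n"
    assert preprocess_text(FakeConfig("qwen2_vl"), "hi") == "hi"

    tiny = Image.new("RGB", (16, 16))
    # qwen2 images at most 20px wide are upscaled 2x (and, via the elif
    # chain, returned bare rather than wrapped in a list)
    assert preprocess_image(FakeConfig("qwen2_vl"), tiny).size == (32, 32)
    # model types outside the elif chain get the single-image list wrapping
    assert isinstance(preprocess_image(FakeConfig("idefics2"), tiny), list)
    # llava_next / gemma3 / llama4 keep the bare image
    assert isinstance(preprocess_image(FakeConfig("gemma3"), tiny), Image.Image)

One behavioral observation on the refactor: the deleted code used three
independent if statements, so a resized qwen2 image or an RGB-converted
paligemma image was additionally wrapped in a list, while the new elif chain
returns those images bare; since the call site passes [img] to
processor.image_processor either way, the nesting the processor sees differs
between the two versions for those model types.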
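Also not part of the patch: a toy sketch of the reshaped dispatch loop in the
second hunk, showing the guard-and-continue style being replaced by a single
if/elif/else over the protobuf oneof tag. FakeChunk and FakeImagePayload are
hypothetical stand-ins for the generated InputChunk message; only
WhichOneof("chunk") and the text/image fields are mimicked.

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class FakeImagePayload:
        data: bytes


    @dataclass
    class FakeChunk:
        # Hypothetical stand-in for the generated protobuf InputChunk; the
        # real code calls chunk.WhichOneof("chunk") to learn which oneof
        # field is set.
        text: Optional[str] = None
        image: Optional[FakeImagePayload] = None

        def WhichOneof(self, _oneof_name: str) -> str:
            return "text" if self.text is not None else "image"


    def collect_parts(chunks: List[FakeChunk]) -> List[str]:
        text_parts = []
        for chunk in chunks:
            chunk_type = chunk.WhichOneof("chunk")
            if chunk_type == "text":
                text_parts.append(chunk.text)
            elif chunk_type == "image":
                # the real loop runs the image processor here and appends
                # the model-specific image placeholder text instead
                text_parts.append(f"<image:{len(chunk.image.data)} bytes>")
            else:
                raise RuntimeError(f"Invalid chunk type {chunk_type}")
        return text_parts


    parts = collect_parts(
        [FakeChunk(text="hi "), FakeChunk(image=FakeImagePayload(b"\x89PNG"))]
    )
    print(parts)  # ['hi ', '<image:4 bytes>']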