From 5cfd4b168a69d5b2f871d0bf67102f9c7884d8a3 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Wed, 30 Apr 2025 07:47:14 +0000
Subject: [PATCH] fix paligemma text

---
 .../flash_pali_gemma_modeling.py              |  2 +-
 .../models/custom_modeling/qwen2_5_vl.py      |  2 +-
 .../models/custom_modeling/qwen2_vl.py        |  2 +-
 .../models/pali_gemma.py                      | 71 -------------------
 .../models/vlm_causal_lm.py                   |  6 +-
 server/text_generation_server/server.py       |  2 -
 6 files changed, 7 insertions(+), 78 deletions(-)
 delete mode 100644 server/text_generation_server/models/pali_gemma.py

diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 0ea3a868..41af2b9e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -89,7 +89,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 
         if vision_embeds is not None:
             mask = input_ids == self.config.image_token_index
-            inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1])
+            inputs_embeds[mask] = vision_embeds
 
         return inputs_embeds
 
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index c1af3c28..e2fc60b1 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -940,7 +940,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
-        if vision_embeds is not None and len(vision_embeds) > 0:
+        if vision_embeds is not None:
             inputs_embeds[input_ids == self.image_token_id] = vision_embeds
 
         return inputs_embeds
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 05d13786..75f718bd 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -518,7 +518,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
-        if vision_embeds is not None and len(vision_embeds) > 0:
+        if vision_embeds is not None:
             inputs_embeds[input_ids == self.image_token_id] = vision_embeds
 
         return inputs_embeds
diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py
deleted file mode 100644
index fe75570e..00000000
--- a/server/text_generation_server/models/pali_gemma.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from io import BytesIO
-from PIL import Image
-import torch
-import torch.distributed
-from opentelemetry import trace
-from typing import Iterable
-from text_generation_server.models.vlm_causal_lm import (
-    VlmCausalLMBatch,
-    image_text_replacement,
-)
-
-from text_generation_server.pb.generate_pb2 import Request
-
-tracer = trace.get_tracer(__name__)
-
-
-class PaliGemmaBatch(VlmCausalLMBatch):
-    @classmethod
-    def batch_tokenized_inputs(
-        cls, requests: Iterable[Request], tokenizer, processor, config
-    ):
-        batch_inputs = []
-        image_inputs = []
-        max_truncation = 0
-        for r in requests:
-            full_text = ""
-            image_id = 0
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    full_text += "" + chunk.text + "\n"
-                elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    # TODO do_convert_RGB should be on by default ?
-                    image = image.convert("RGB")
-                    image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(
-                        processor, image_input, config, image_id
-                    )
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=False,
-        )["input_ids"]
-        if image_inputs:
-            image_input = image_inputs[0]
-            new_image_inputs = {
-                "pixel_values": torch.cat(
-                    [img["pixel_values"] for img in image_inputs], dim=0
-                ),
-            }
-            if "pixel_attention_mask" in image_input:
-                new_image_inputs["pixel_attention_mask"] = torch.cat(
-                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
-                )
-            if "image_sizes" in image_input:
-                new_image_inputs["image_sizes"] = torch.cat(
-                    [img["image_sizes"] for img in image_inputs], dim=0
-                )
-            image_inputs = new_image_inputs
-        else:
-            image_inputs = None
-        return batch_tokenized_inputs, image_inputs
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 42ba15bf..8ac3d65a 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -202,7 +202,7 @@ def preprocess_image(config, img):
     if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
         img = img.resize((img.width * 2, img.height * 2))
 
-    elif model_type == "paligemma":
+    if model_type == "paligemma":
         img = img.convert("RGB")
 
     if model_type not in {"llava_next", "gemma3", "llama4"}:
@@ -432,7 +432,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 full_text,
                 truncation=True,
                 max_length=r.truncate,
-                add_special_tokens=r.add_special_tokens,
+                add_special_tokens=(
+                    r.add_special_tokens if config.model_type != "paligemma" else False
+                ),
             )["input_ids"]
             max_length = max(max_length, len(input_ids))
 
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 935e0985..87c51eb9 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -18,7 +18,6 @@ from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
 try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
     from text_generation_server.models.vlm_causal_lm import (
         VlmCausalLMBatch,
     )
@@ -26,7 +25,6 @@ try:
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
 
     VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
         VlmCausalLMBatch,
         IdeficsCausalLMBatch,
         MllamaCausalLMBatch,
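
Reviewer note, not part of the patch: the modeling changes above all lean on the same masked-scatter pattern, where vision_embeds already holds exactly one row per image placeholder token in input_ids, so it can be assigned through a boolean mask without the earlier .view(-1, ...) reshape or len(vision_embeds) > 0 guard. The sketch below is a minimal, self-contained illustration of that pattern; the function and variable names are hypothetical, not the server's API.

    # Illustrative sketch of the masked scatter used in the modeling changes above.
    import torch

    def merge_vision_embeds(input_ids, inputs_embeds, vision_embeds, image_token_id):
        # Every position holding the image placeholder token receives one row of
        # vision_embeds, so the placeholder count must equal vision_embeds.shape[0].
        mask = input_ids == image_token_id
        assert int(mask.sum()) == vision_embeds.shape[0]
        inputs_embeds[mask] = vision_embeds
        return inputs_embeds

    # Toy usage: 6 tokens, two of which (id 99) are image placeholders, hidden size 8.
    input_ids = torch.tensor([5, 99, 99, 7, 11, 3])
    merged = merge_vision_embeds(
        input_ids, torch.zeros(6, 8), torch.randn(2, 8), image_token_id=99
    )

If the placeholder count and the number of vision embedding rows ever disagree, the assignment fails loudly instead of silently reshaping, which is why keeping the prompt-side token expansion (handled in VlmCausalLMBatch) consistent with the vision tower output matters here.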