fix paligemma text

This commit is contained in:
Mohit Sharma 2025-04-30 07:47:14 +00:00
parent 61ccbf6bbd
commit 5cfd4b168a
6 changed files with 7 additions and 78 deletions

View File

@ -89,7 +89,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1])
inputs_embeds[mask] = vision_embeds
return inputs_embeds

View File

@ -940,7 +940,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
if vision_embeds is not None:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds
return inputs_embeds

View File

@ -518,7 +518,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None and len(vision_embeds) > 0:
if vision_embeds is not None:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds
return inputs_embeds

View File

@ -1,71 +0,0 @@
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
image_text_replacement,
)
from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(VlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
full_text = ""
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += "<bos>" + chunk.text + "\n"
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO do_convert_RGB should be on by default ?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(
processor, image_input, config, image_id
)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs

View File

@ -202,7 +202,7 @@ def preprocess_image(config, img):
if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
img = img.resize((img.width * 2, img.height * 2))
elif model_type == "paligemma":
if model_type == "paligemma":
img = img.convert("RGB")
if model_type not in {"llava_next", "gemma3", "llama4"}:
@ -432,7 +432,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
add_special_tokens=(
r.add_special_tokens if config.model_type != "paligemma" else False
),
)["input_ids"]
max_length = max(max_length, len(input_ids))

View File

@ -18,7 +18,6 @@ from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
@ -26,7 +25,6 @@ try:
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
VLM_BATCH_TYPES = {
PaliGemmaBatch,
VlmCausalLMBatch,
IdeficsCausalLMBatch,
MllamaCausalLMBatch,