Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

commit 5cfd4b168a (parent 61ccbf6bbd)

fix paligemma text
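
Note: this commit removes the dedicated PaliGemmaBatch and routes PaliGemma through the generic VlmCausalLMBatch path. The PaliGemma-specific behavior that class carried (converting images to RGB, tokenizing with add_special_tokens=False) moves into the shared preprocessing and batching code, the PaliGemma modeling code now expects vision embeddings that arrive already flattened, and the Qwen2-VL/Qwen2.5-VL guards against empty vision embeddings are relaxed to plain None checks.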
@@ -89,7 +89,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 
         if vision_embeds is not None:
             mask = input_ids == self.config.image_token_index
-            inputs_embeds[mask] = vision_embeds.view(-1, vision_embeds.shape[-1])
+            inputs_embeds[mask] = vision_embeds
 
         return inputs_embeds
 
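Note: the masked assignment above requires vision_embeds to already be a flat [num_image_tokens, hidden_size] tensor, which is presumably why the explicit .view(-1, ...) could be dropped once the caller flattens the embeddings. A minimal, self-contained sketch of the scatter, with toy shapes and a hypothetical image_token_index of 99:

    import torch

    hidden_size = 4
    input_ids = torch.tensor([1, 99, 99, 2])     # 99 stands in for image_token_index
    inputs_embeds = torch.zeros(4, hidden_size)  # placeholder text-token embeddings
    vision_embeds = torch.ones(2, hidden_size)   # one pre-flattened row per image token
    mask = input_ids == 99                       # True at the two image-token positions
    inputs_embeds[mask] = vision_embeds          # vision rows replace those positions
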
@@ -940,7 +940,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
-        if vision_embeds is not None and len(vision_embeds) > 0:
+        if vision_embeds is not None:
             inputs_embeds[input_ids == self.image_token_id] = vision_embeds
 
         return inputs_embeds
@@ -518,7 +518,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
-        if vision_embeds is not None and len(vision_embeds) > 0:
+        if vision_embeds is not None:
             inputs_embeds[input_ids == self.image_token_id] = vision_embeds
 
         return inputs_embeds
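
Note: this hunk is identical to the Qwen2.5-VL change above. Dropping the len(vision_embeds) > 0 guard assumes callers now pass None, rather than an empty tensor, when a batch carries no images; with a consistent mask an empty tensor would select zero rows and the assignment would be a no-op anyway.
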
@@ -1,71 +0,0 @@
-from io import BytesIO
-from PIL import Image
-import torch
-import torch.distributed
-from opentelemetry import trace
-from typing import Iterable
-from text_generation_server.models.vlm_causal_lm import (
-    VlmCausalLMBatch,
-    image_text_replacement,
-)
-
-from text_generation_server.pb.generate_pb2 import Request
-
-tracer = trace.get_tracer(__name__)
-
-
-class PaliGemmaBatch(VlmCausalLMBatch):
-    @classmethod
-    def batch_tokenized_inputs(
-        cls, requests: Iterable[Request], tokenizer, processor, config
-    ):
-        batch_inputs = []
-        image_inputs = []
-        max_truncation = 0
-        for r in requests:
-            full_text = ""
-            image_id = 0
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    full_text += "<bos>" + chunk.text + "\n"
-                elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    # TODO do_convert_RGB should be on by default ?
-                    image = image.convert("RGB")
-                    image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(
-                        processor, image_input, config, image_id
-                    )
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=False,
-        )["input_ids"]
-        if image_inputs:
-            image_input = image_inputs[0]
-            new_image_inputs = {
-                "pixel_values": torch.cat(
-                    [img["pixel_values"] for img in image_inputs], dim=0
-                ),
-            }
-            if "pixel_attention_mask" in image_input:
-                new_image_inputs["pixel_attention_mask"] = torch.cat(
-                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
-                )
-            if "image_sizes" in image_input:
-                new_image_inputs["image_sizes"] = torch.cat(
-                    [img["image_sizes"] for img in image_inputs], dim=0
-                )
-            image_inputs = new_image_inputs
-        else:
-            image_inputs = None
-        return batch_tokenized_inputs, image_inputs
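
Note: with this module deleted, PaliGemma requests flow through the generic VlmCausalLMBatch. Two behaviors the class enforced at batch time reappear in the shared code below: image.convert("RGB") moves into preprocess_image, and add_special_tokens=False becomes a model_type check in the common tokenizer call. The per-chunk "<bos>" prefixing presumably moves into the shared prompt construction, which this diff does not show.
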
@@ -202,7 +202,7 @@ def preprocess_image(config, img):
 
     if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
         img = img.resize((img.width * 2, img.height * 2))
-    elif model_type == "paligemma":
+    if model_type == "paligemma":
         img = img.convert("RGB")
 
     if model_type not in {"llava_next", "gemma3", "llama4"}:
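
Note: switching elif to if is behaviorally equivalent here, since model_type cannot match both branches, but it decouples the PaliGemma RGB conversion from the Qwen small-image upscaling check. This restores the convert("RGB") previously done inside the deleted PaliGemmaBatch.
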
@@ -432,7 +432,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 full_text,
                 truncation=True,
                 max_length=r.truncate,
-                add_special_tokens=r.add_special_tokens,
+                add_special_tokens=(
+                    r.add_special_tokens if config.model_type != "paligemma" else False
+                ),
             )["input_ids"]
             max_length = max(max_length, len(input_ids))
 
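Note: this conditional restores the deleted batch's unconditional add_special_tokens=False for PaliGemma, on the apparent assumption that PaliGemma prompts carry their own "<bos>" token, while leaving the per-request flag intact for every other model type.
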
@@ -18,7 +18,6 @@ from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
 try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
     from text_generation_server.models.vlm_causal_lm import (
         VlmCausalLMBatch,
     )
@@ -26,7 +25,6 @@ try:
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
 
     VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
         VlmCausalLMBatch,
         IdeficsCausalLMBatch,
         MllamaCausalLMBatch,