commit 46ff016490 (parent 6ed540b52f)
Mohit Sharma, 2025-04-22 01:40:42 +05:30, committed by GitHub
@@ -190,6 +190,22 @@ def image_text_replacement_fixup(config, text: str) -> str:
     )
     return text


+def preprocess_text(config, text: str) -> str:
+    if config.model_type == "paligemma":
+        return "<bos>" + text + "\n"
+    return text
+
+
+def preprocess_image(config, img):
+    model_type = config.model_type
+    if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
+        img = img.resize((img.width * 2, img.height * 2))
+    elif model_type == "paligemma":
+        img = img.convert("RGB")
+    elif model_type not in {"llava_next", "gemma3", "llama4"}:
+        img = [img]
+    return img
+
+
 def get_unpadded_features(
     original_height: int,
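
Taken in isolation, the two new helpers are pure functions of config.model_type, so their behavior is easy to spot-check. A minimal sketch, assuming only Pillow and a stub config object with a model_type attribute; the helper bodies are copied from the hunk above, and the "idefics2" case stands in for any model type outside the exclusion set:

from types import SimpleNamespace

from PIL import Image


def preprocess_text(config, text: str) -> str:
    if config.model_type == "paligemma":
        return "<bos>" + text + "\n"
    return text


def preprocess_image(config, img):
    model_type = config.model_type
    if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
        img = img.resize((img.width * 2, img.height * 2))
    elif model_type == "paligemma":
        img = img.convert("RGB")
    elif model_type not in {"llava_next", "gemma3", "llama4"}:
        img = [img]
    return img


# paligemma prompts are wrapped in <bos> ... \n; every other model passes through.
paligemma = SimpleNamespace(model_type="paligemma")
assert preprocess_text(paligemma, "describe the image") == "<bos>describe the image\n"

# Tiny warmup images are upscaled 2x for the qwen2 vision models. Note that
# because the branches now form an elif chain, a resized qwen image is
# returned as-is rather than being wrapped in a list.
qwen = SimpleNamespace(model_type="qwen2_vl")
out = preprocess_image(qwen, Image.new("RGB", (16, 16)))
assert out.width == 32 and out.height == 32

# Model types outside the exclusion set are wrapped in a single-element list.
other = SimpleNamespace(model_type="idefics2")
assert isinstance(preprocess_image(other, Image.new("RGB", (64, 64))), list)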
@@ -371,38 +387,29 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         image_id = 0
         for chunk in r.input_chunks.chunks:
             chunk_type = chunk.WhichOneof("chunk")
             if chunk_type == "text":
-                text_parts.append(chunk.text)
-                continue
-
-            if chunk_type != "image":
-                raise RuntimeError(f"Invalid chunk type {chunk_type}")
+                text = preprocess_text(config, chunk.text)
+                text_parts.append(text)
+            elif chunk_type == "image":
                 img = Image.open(BytesIO(chunk.image.data))
-                if config.model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
-                    img = img.resize((img.width * 2, img.height * 2))
-                if config.model_type in {"paligemma"}:
-                    img = img.convert("RGB")
-                if config.model_type not in {"llava_next", "gemma3", "llama4"}:
-                    img = [img]
+                img = preprocess_image(config, img)
                 image_input = processor.image_processor(
                     [img], return_tensors="pt", **kwargs
                 )
                 image_inputs.append(image_input)
-                img_text, id_token_str = image_text_replacement(
+                img_text, img_start_token_str = image_text_replacement(
                     processor, image_input, config, 0
                 )
                 text_parts.append(img_text)
-                image_texts.append([image_id, id_token_str, img_text])
+                image_texts.append([image_id, img_start_token_str, img_text])
                 image_id += 1
+            else:
+                raise RuntimeError(f"Invalid chunk type {chunk_type}")

         full_text = image_text_replacement_fixup(config, "".join(text_parts))
         input_ids = tokenizer(
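
The net effect on the batching path is that the per-model special cases move out of the chunk loop, which now reads as a plain if/elif/else over the chunk type. A condensed sketch of that control flow, with the protobuf chunk and the image replacement step stubbed out; FakeChunk, assemble_prompt, and the "<image>" placeholder are illustrative stand-ins, not part of the codebase:

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional


@dataclass
class FakeChunk:
    # Stand-in for the protobuf input chunk: exactly one field is set.
    text: Optional[str] = None
    image: Optional[bytes] = None

    def WhichOneof(self, _group: str) -> str:
        return "text" if self.text is not None else "image"


def preprocess_text(config, text: str) -> str:
    if config.model_type == "paligemma":
        return "<bos>" + text + "\n"
    return text


def assemble_prompt(config, chunks) -> str:
    # Mirrors the refactored loop: text chunks go through preprocess_text,
    # image chunks are replaced by their token string (stubbed here), and
    # any other chunk type is rejected.
    text_parts = []
    for chunk in chunks:
        chunk_type = chunk.WhichOneof("chunk")
        if chunk_type == "text":
            text_parts.append(preprocess_text(config, chunk.text))
        elif chunk_type == "image":
            text_parts.append("<image>")  # image_text_replacement stand-in
        else:
            raise RuntimeError(f"Invalid chunk type {chunk_type}")
    return "".join(text_parts)


config = SimpleNamespace(model_type="llava_next")
chunks = [FakeChunk(text="What is in "), FakeChunk(image=b"..."), FakeChunk(text="?")]
assert assemble_prompt(config, chunks) == "What is in <image>?"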