Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 11:24:53 +00:00
improve

Commit: 46ff016490
Parent: 6ed540b52f
@@ -190,6 +190,22 @@ def image_text_replacement_fixup(config, text: str) -> str:
     )
     return text
 
 
+def preprocess_text(config, text: str) -> str:
+    if config.model_type == "paligemma":
+        return "<bos>" + text + "\n"
+    return text
+
+
+def preprocess_image(config, img):
+    model_type = config.model_type
+
+    if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
+        img = img.resize((img.width * 2, img.height * 2))
+    elif model_type == "paligemma":
+        img = img.convert("RGB")
+    elif model_type not in {"llava_next", "gemma3", "llama4"}:
+        img = [img]
+
+    return img
+
+
 def get_unpadded_features(
     original_height: int,
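For illustration, a minimal usage sketch of the two new helpers, assuming they are importable from text_generation_server.models.vlm_causal_lm (the path the surrounding classes suggest); the SimpleNamespace configs stand in for the real model config and are not part of the commit:

from types import SimpleNamespace

from PIL import Image

from text_generation_server.models.vlm_causal_lm import (
    preprocess_image,
    preprocess_text,
)

# Stand-in configs; the helpers only read the model_type attribute.
paligemma_cfg = SimpleNamespace(model_type="paligemma")
qwen_cfg = SimpleNamespace(model_type="qwen2_vl")

# PaliGemma prompts get a <bos> prefix and a trailing newline; other models pass through.
assert preprocess_text(paligemma_cfg, "caption this") == "<bos>caption this\n"
assert preprocess_text(qwen_cfg, "caption this") == "caption this"

# A Qwen2-VL image wider than 20px keeps its size and is wrapped in a single-element
# list, since qwen2_vl is not in {"llava_next", "gemma3", "llama4"}.
img = Image.new("RGB", (64, 64))
out = preprocess_image(qwen_cfg, img)
assert isinstance(out, list) and out[0].size == (64, 64)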
@@ -371,39 +387,30 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             image_id = 0
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
-                    text_parts.append(chunk.text)
-                elif chunk_type == "image":
-                    img = Image.open(BytesIO(chunk.image.data))
-
-                    if config.model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
-                        img = img.resize((img.width * 2, img.height * 2))
-                    if config.model_type in {"paligemma"}:
-                        img = img.convert("RGB")
-
-                    if config.model_type not in {"llava_next", "gemma3", "llama4"}:
-                        img = [img]
-
-                    image_input = processor.image_processor(
-                        [img], return_tensors="pt", **kwargs
-                    )
-                    image_inputs.append(image_input)
-
-                    img_text, id_token_str = image_text_replacement(
-                        processor, image_input, config, 0
-                    )
-                    text_parts.append(img_text)
-
-                    image_texts.append([image_id, id_token_str, img_text])
-                    image_id += 1
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
+                    text = preprocess_text(config, chunk.text)
+                    text_parts.append(text)
+                    continue
+
+                if chunk_type != "image":
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
+
+                img = Image.open(BytesIO(chunk.image.data))
+                img = preprocess_image(config, img)
+
+                image_input = processor.image_processor(
+                    [img], return_tensors="pt", **kwargs
+                )
+                image_inputs.append(image_input)
+
+                img_text, img_start_token_str = image_text_replacement(
+                    processor, image_input, config, 0
+                )
+                text_parts.append(img_text)
+
+                image_texts.append([image_id, img_start_token_str, img_text])
+                image_id += 1
 
             full_text = image_text_replacement_fixup(config, "".join(text_parts))
             input_ids = tokenizer(
                 full_text,
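As a side note on the reshaped control flow (an early continue for text chunks and a guard clause in place of the old if/elif/else), here is a self-contained sketch over mock chunks; the Chunk dataclass and the toy text_fn/image_fn callables are illustrative stand-ins, not TGI APIs:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Chunk:
    # Mock stand-in for the protobuf input chunk: exactly one of text/image is set.
    text: Optional[str] = None
    image: Optional[bytes] = None

    def which(self) -> str:
        return "text" if self.text is not None else "image"


def collect_text_parts(chunks, text_fn, image_fn):
    """Mirrors the refactored loop shape: handle text chunks and continue,
    reject unknown chunk types with a guard clause, then do the image work
    without an extra nesting level."""
    text_parts = []
    for chunk in chunks:
        chunk_type = chunk.which()
        if chunk_type == "text":
            text_parts.append(text_fn(chunk.text))
            continue

        if chunk_type != "image":
            raise RuntimeError(f"Invalid chunk type {chunk_type}")

        img = image_fn(chunk.image)
        text_parts.append(f"<image:{len(img)} bytes>")
    return text_parts


# Example: two text chunks and one image chunk.
parts = collect_text_parts(
    [Chunk(text="hello "), Chunk(image=b"\x89PNG"), Chunk(text=" world")],
    text_fn=str.strip,
    image_fn=lambda b: b,
)
print(parts)  # ['hello', '<image:4 bytes>', 'world']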