diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 46657a7a..49617cbb 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -190,11 +190,13 @@ def image_text_replacement_fixup(config, text: str) -> str:
         )
     return text
 
+
 def preprocess_text(config, text: str) -> str:
     if config.model_type == "paligemma":
         return "<bos>" + text + "\n"
     return text
 
+
 def preprocess_image(config, img):
     model_type = config.model_type
 
@@ -203,10 +205,11 @@ def preprocess_image(config, img):
     elif model_type == "paligemma":
         img = img.convert("RGB")
     elif model_type not in {"llava_next", "gemma3", "llama4"}:
-        img = [img]
+        img = [img]
 
     return img
 
+
 def get_unpadded_features(
     original_height: int,
     original_width: int,
@@ -379,7 +382,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch_image_inputs: List[Optional[List[dict]]] = []
         batch_image_positions: List[Optional[List[ImagePositions]]] = []
 
-        for i, r in enumerate(requests):
+        for r in requests:
             text_parts = []
             image_inputs = []
             image_texts = []
@@ -457,16 +460,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     img_text = img_text.replace("\n\n", "")
 
                 tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
+                length = len(tokens)
 
                 pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
                 index = img_start_token_pos[pos]
-                is_embed = torch.tensor(tokens) == config.image_token_index
-                num_placeholder_tokens = is_embed.sum().item()
+                assert (
+                    input_ids[index : index + length] == tokens
+                ), "Image tokens not found in input_ids"
 
-                length = len(tokens)
+                num_placeholder_tokens = tokens.count(config.image_token_index)
                 if num_placeholder_tokens == length:
                     is_embed = None
+                else:
+                    is_embed = torch.as_tensor(tokens) == config.image_token_index
 
                 pos = ImagePositions(
                     offset=index,
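
Note (illustration only, not part of the patch): the last hunk defers building the boolean is_embed mask until it is actually needed. It first asserts that the tokenized image text appears verbatim in input_ids at the offset found via torch.searchsorted, then counts placeholder tokens with tokens.count(...); the mask is only materialized when placeholder and ordinary text tokens are mixed. The standalone sketch below replays that logic; image_token_index, input_ids, img_start_token_pos, last_pos, and tokens are all made-up example values, not taken from the repository.

    import torch

    # All concrete values here are invented for illustration; in the patched
    # code they come from the model config, the tokenizer, and the batch.
    image_token_index = 32000                   # hypothetical placeholder id
    input_ids = [5, 7, 1, 32000, 32000, 9, 8]   # full prompt token ids
    img_start_token_pos = torch.tensor([2])     # start offset of each image text
    last_pos = 0
    tokens = [1, 32000, 32000, 9]               # tokenized image text
    length = len(tokens)

    # Locate the first recorded image start position at or after last_pos.
    pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
    index = img_start_token_pos[pos]

    # Sanity check added by the patch: the tokenized image text must appear
    # verbatim in input_ids at the expected offset.
    assert (
        input_ids[index : index + length] == tokens
    ), "Image tokens not found in input_ids"

    # Count placeholders first; only build the mask when text and placeholder
    # tokens are mixed.
    num_placeholder_tokens = tokens.count(image_token_index)
    if num_placeholder_tokens == length:
        is_embed = None
    else:
        is_embed = torch.as_tensor(tokens) == image_token_index

    print(num_placeholder_tokens, is_embed)
    # 2 tensor([False,  True,  True, False])

Returning None when every token is a placeholder spares downstream consumers from carrying an all-True mask; they can scatter image embeddings over the whole span directly.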