From f1da19df41190a98f30d4ae0f89696099241b245 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 22 Apr 2025 13:54:39 +0000 Subject: [PATCH] rename vars --- .../models/vlm_causal_lm.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 4aca153e..17474eb4 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -459,8 +459,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch): image_positions = [] num_images = len(image_texts) - input_ids_t = torch.as_tensor(input_ids, dtype=torch.int32) + input_ids_t = torch.as_tensor(input_ids) img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0] + num_tokens = input_ids_t.numel() last_pos = 0 for i in range(num_images): @@ -470,25 +471,25 @@ class VlmCausalLMBatch(FlashCausalLMBatch): if config.model_type == "gemma3": img_text = img_text.replace("\n\n", "") - tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"] - length = len(tokens) + tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ][0] + length = tokens.numel() - assert length <= len( - input_ids - ), f"{length} > {len(input_ids)} Image is truncated, try increasing --max-batch-prefill-tokens" + assert ( + length <= num_tokens + ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens" pos = torch.searchsorted(img_start_token_pos, last_pos, right=False) index = img_start_token_pos[pos] - - assert ( - input_ids[index : index + length] == tokens + assert torch.equal( + input_ids_t[index : index + length], tokens ), "Image tokens not found in input_ids" - num_placeholder_tokens = tokens.count(config.image_token_index) + is_embed = tokens == config.image_token_index + num_placeholder_tokens = int(is_embed.sum()) if num_placeholder_tokens == length: is_embed = None - else: - is_embed = torch.as_tensor(tokens) == config.image_token_index pos = ImagePositions( offset=index,