rename vars

This commit is contained in:
Mohit Sharma 2025-04-22 13:54:39 +00:00
parent 63ddba24b4
commit f1da19df41

View File

@ -459,8 +459,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
image_positions = []
num_images = len(image_texts)
input_ids_t = torch.as_tensor(input_ids, dtype=torch.int32)
input_ids_t = torch.as_tensor(input_ids)
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
num_tokens = input_ids_t.numel()
last_pos = 0
for i in range(num_images):
@ -470,25 +471,25 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if config.model_type == "gemma3":
img_text = img_text.replace("\n\n", "")
tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
length = len(tokens)
tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
"input_ids"
][0]
length = tokens.numel()
assert length <= len(
input_ids
), f"{length} > {len(input_ids)} Image is truncated, try increasing --max-batch-prefill-tokens"
assert (
length <= num_tokens
), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]
assert (
input_ids[index : index + length] == tokens
assert torch.equal(
input_ids_t[index : index + length], tokens
), "Image tokens not found in input_ids"
num_placeholder_tokens = tokens.count(config.image_token_index)
is_embed = tokens == config.image_token_index
num_placeholder_tokens = int(is_embed.sum())
if num_placeholder_tokens == length:
is_embed = None
else:
is_embed = torch.as_tensor(tokens) == config.image_token_index
pos = ImagePositions(
offset=index,