optimizations

Mohit Sharma 2025-04-22 07:49:45 +00:00
parent 2f67c53075
commit 6545cdde0d


@@ -190,11 +190,13 @@ def image_text_replacement_fixup(config, text: str) -> str:
        )
    return text


def preprocess_text(config, text: str) -> str:
    if config.model_type == "paligemma":
        return "<bos>" + text + "\n"
    return text


def preprocess_image(config, img):
    model_type = config.model_type
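
For context on the first hunk: `preprocess_text` only rewrites the prompt for PaliGemma-style models, prefixing `<bos>` and appending a newline, and passes everything else through untouched. A minimal sketch of that behavior, using throwaway config stubs rather than the real TGI config object:

from types import SimpleNamespace

def preprocess_text(config, text: str) -> str:
    # PaliGemma prompts need an explicit <bos> prefix and a trailing newline;
    # other model types are returned unchanged.
    if config.model_type == "paligemma":
        return "<bos>" + text + "\n"
    return text

# Throwaway config stubs for illustration only.
paligemma_cfg = SimpleNamespace(model_type="paligemma")
llava_cfg = SimpleNamespace(model_type="llava_next")

assert preprocess_text(paligemma_cfg, "caption en") == "<bos>caption en\n"
assert preprocess_text(llava_cfg, "caption en") == "caption en"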
@@ -207,6 +209,7 @@ def preprocess_image(config, img):
    return img


def get_unpadded_features(
    original_height: int,
    original_width: int,
@@ -379,7 +382,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        batch_image_inputs: List[Optional[List[dict]]] = []
        batch_image_positions: List[Optional[List[ImagePositions]]] = []
-        for i, r in enumerate(requests):
+        for r in requests:
            text_parts = []
            image_inputs = []
            image_texts = []
@@ -457,16 +460,20 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                img_text = img_text.replace("\n\n", "")
            tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
            length = len(tokens)
            pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
            index = img_start_token_pos[pos]
-            is_embed = torch.tensor(tokens) == config.image_token_index
-            num_placeholder_tokens = is_embed.sum().item()
            assert (
                input_ids[index : index + length] == tokens
            ), "Image tokens not found in input_ids"
+            length = len(tokens)
+            num_placeholder_tokens = tokens.count(config.image_token_index)
+            if num_placeholder_tokens == length:
+                is_embed = None
+            else:
+                is_embed = torch.as_tensor(tokens) == config.image_token_index
            pos = ImagePositions(
                offset=index,
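
The last hunk is the core of the optimization: the placeholder count is taken directly from the Python token list instead of materializing a boolean tensor and summing it, and the `is_embed` mask is skipped entirely when every token in the image text is a placeholder. A minimal standalone sketch of the two approaches (the token id and count below are made-up values for illustration):

import torch

image_token_index = 32000            # made-up placeholder id for this sketch
tokens = [image_token_index] * 576   # e.g. an image expanded to 576 placeholder tokens

# Old approach: allocate a boolean tensor per image, then reduce it.
is_embed_old = torch.tensor(tokens) == image_token_index
num_placeholder_old = is_embed_old.sum().item()

# New approach from the diff: count on the plain list, and only build the
# mask when some of the tokens are not placeholders.
length = len(tokens)
num_placeholder_new = tokens.count(image_token_index)
if num_placeholder_new == length:
    is_embed_new = None              # all-placeholder case: no mask needed downstream
else:
    is_embed_new = torch.as_tensor(tokens) == image_token_index

assert num_placeholder_old == num_placeholder_new == 576
assert is_embed_new is None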