add logic

Mohit Sharma 2025-04-18 12:37:40 +05:30 committed by GitHub
parent 84ab88d843
commit 92909f3f33


@@ -381,6 +381,130 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        batch.image_grid_thw = None
        return batch
    def prepare_for_prefill(self):
        super().prepare_for_prefill()

        self.has_image = False
        self.scheduled_image_input = []
        scheduled_image_pixel_values = []

        device = self.input_ids.device

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling:
                continue

            # batch_mm_inputs[i] holds the multimodal inputs of request i;
            # it is referenced here but not defined in this hunk.
            for j, mm_input in enumerate(batch_mm_inputs[i]):
                image_id = mm_input.id
                pixel_values = self.all_pixel_values[i][j].pixel_values

                start_pos = mm_input.offset
                length = mm_input.length

                if start_pos >= cache_length + input_length:
                    # No encoder input required at this step.
                    break
                if start_pos + length <= cache_length:
                    # The encoder input has already been processed.
                    continue

                self.has_image = True
                if image_id not in self.encoder_cache[i]:
                    self.scheduled_image_input.append((i, mm_input))
                    scheduled_image_pixel_values.append(pixel_values)

        if self.has_image and len(scheduled_image_pixel_values):
            self.pixel_values = torch.cat(
                scheduled_image_pixel_values, dim=0
            ).to(device)
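
The window check above decides, per image, whether its placeholder span starts after the current prefill chunk, is already fully covered by the KV cache, or overlaps the chunk and needs encoding. A minimal sketch of just that decision, using a hypothetical `MMInput` record standing in for the multimodal input metadata (only `id`, `offset`, and `length` are assumed):

```python
# Sketch of the scheduling window check; MMInput is hypothetical, not
# part of this diff.
from dataclasses import dataclass

@dataclass
class MMInput:
    id: str
    offset: int  # first placeholder-token position in the prompt
    length: int  # number of placeholder tokens

def schedule_state(mm_input: MMInput, cache_length: int, input_length: int) -> str:
    if mm_input.offset >= cache_length + input_length:
        return "not needed yet"     # image starts after this prefill chunk
    if mm_input.offset + mm_input.length <= cache_length:
        return "already processed"  # image fully covered by the cache
    return "schedule for encoding"

# With a 10-token chunk prefilled per step:
img = MMInput(id="img-0", offset=12, length=8)
print(schedule_state(img, cache_length=0, input_length=10))   # not needed yet
print(schedule_state(img, cache_length=10, input_length=10))  # schedule for encoding
print(schedule_state(img, cache_length=20, input_length=10))  # already processed
```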
    def update_encoder_cache(self, encoder_outputs):
        prev = 0
        for i, mm_input in self.scheduled_image_input:
            length = mm_input.num_placeholder_tokens
            output = encoder_outputs[prev : prev + length]
            self.encoder_cache[i][mm_input.id] = self.scatter_image_embed(
                output, mm_input.is_embed
            )
            prev += length
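
`scatter_image_embed` is not shown in this diff; from the call above it appears to expand the raw encoder output to one row per placeholder token, writing real embeddings only where `is_embed` is true. A hedged sketch under that assumption:

```python
# Illustration only; not the implementation from this repository.
import torch
from typing import Optional

def scatter_image_embed_sketch(
    output: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
    # With no mask, every placeholder token carries a real embedding.
    if is_embed is None:
        return output
    # Otherwise allocate one row per placeholder token and fill the
    # real embeddings at the True positions.
    scattered = output.new_zeros(is_embed.shape[0], output.shape[-1])
    scattered[is_embed] = output
    return scattered

out = torch.randn(3, 4)                                # 3 real image embeddings
mask = torch.tensor([True, False, True, False, True])  # 5 placeholder tokens
print(scatter_image_embed_sketch(out, mask).shape)     # torch.Size([5, 4])
```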
    def get_mm_embeddings(self):
        device = self.input_ids.device
        mm_embeds = []

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling:
                continue

            for j, mm_input in enumerate(batch_mm_inputs[i]):
                image_id = mm_input.id

                start_pos = mm_input.offset
                length = mm_input.length

                if start_pos >= cache_length + input_length:
                    # No encoder input required at this step.
                    break
                if start_pos + length <= cache_length:
                    # The encoder input has already been processed.
                    continue

                # Slice of the image that overlaps the tokens prefilled
                # in this step.
                start_idx = max(cache_length - start_pos, 0)
                end_idx = min(cache_length - start_pos + input_length, length)

                encoder_output = self.encoder_cache[i][image_id]

                is_embed = mm_input.is_embed
                if is_embed is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_embeds(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)

        return torch.cat(mm_embeds, dim=0).to(device)
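
`gather_mm_embeds` is likewise not defined in this hunk; from its use above it looks like the inverse of the scatter step, keeping only the rows of the cached encoder output that carry real embeddings. A sketch under that assumption:

```python
# Illustration only; the real helper is not shown in this diff.
import torch
from typing import Optional

def gather_mm_embeds_sketch(
    encoder_output: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
    if is_embed is None:
        return encoder_output
    # Drop the padded placeholder rows, keeping real embeddings only.
    return encoder_output[is_embed]

cached = torch.randn(5, 4)
mask = torch.tensor([True, False, True, False, True])
print(gather_mm_embeds_sketch(cached, is_embed=mask).shape)  # torch.Size([3, 4])
```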
class VlmCausalLM(FlashCausalLM):
    def __init__(
@@ -418,6 +542,14 @@ class VlmCausalLM(FlashCausalLM):
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class
    def get_mm_embeddings(self, batch):
        if batch.pixel_values is not None:
            encoder_outputs = self.model.get_mm_embeddings(batch.pixel_values)
            batch.update_encoder_cache(encoder_outputs)
        return batch.get_mm_embeddings()
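
For orientation, the assumed call order per prefill step, with `batch` and `model` as stand-ins (a sketch, not code from this commit):

```python
# Assumed end-to-end flow for one prefill step (sketch, not from this commit):
#
#   batch.prepare_for_prefill()     # schedule uncached, overlapping images;
#                                   # stack their pixels into batch.pixel_values
#   model.get_mm_embeddings(batch)  # encode scheduled pixels,
#                                   # batch.update_encoder_cache(...) stores them
#                                   # per image id, then batch.get_mm_embeddings()
#                                   # gathers the slice overlapping this step
#
# The returned embeddings replace the placeholder-token embeddings before
# the language-model forward pass.
```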
    def forward(
        self,
        batch: VlmCausalLMBatch,