Mohit Sharma 2025-04-19 10:26:56 +00:00
parent 526a8785ed
commit b86919a87a


@@ -290,6 +290,24 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.image_inputs = []
        batch.image_positions = []
        batch.encoder_cache = []
        for b in batches:
            if b.image_inputs is not None:
                batch.image_inputs.extend(b.image_inputs)
            else:
                batch.image_inputs.append(None)

            if b.image_positions is not None:
                batch.image_positions.extend(b.image_positions)
            else:
                batch.image_positions.append(None)

            if b.encoder_cache is not None:
                batch.encoder_cache.extend(b.encoder_cache)
            else:
                batch.encoder_cache.append(None)

        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
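The merge above keeps one flat list per image-related field: a source batch's per-request entries are extended in, and a batch without image state contributes a None placeholder. A minimal standalone sketch of that pattern (SimpleBatch and merge_image_state are illustrative names, not TGI types):

# Illustrative stand-in for the per-request merge in concatenate().
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class SimpleBatch:
    image_inputs: Optional[List[Any]] = None
    image_positions: Optional[List[Any]] = None
    encoder_cache: Optional[List[Dict[str, Any]]] = None


def merge_image_state(batches: List[SimpleBatch]) -> SimpleBatch:
    merged = SimpleBatch(image_inputs=[], image_positions=[], encoder_cache=[])
    for b in batches:
        for name in ("image_inputs", "image_positions", "encoder_cache"):
            value = getattr(b, name)
            target = getattr(merged, name)
            if value is not None:
                # Per-request entries are carried over one-to-one.
                target.extend(value)
            else:
                # A batch without image state contributes a placeholder,
                # mirroring the append(None) branches in the diff above.
                target.append(None)
    return merged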
@@ -298,11 +316,28 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")

        image_inputs = []
        image_positions = []
        encoder_cache = []
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            image_inputs.append(self.image_inputs[idx])
            image_positions.append(self.image_positions[idx])
            encoder_cache.append(self.encoder_cache[idx])

        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        batch.image_inputs = image_inputs
        batch.image_positions = image_positions
        batch.encoder_cache = encoder_cache
        return batch

    @classmethod
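filter() gathers each surviving request's image state through requests_idx_mapping before delegating to the parent class, then re-attaches the gathered lists to the filtered batch. A rough illustration of the select step on plain lists (the helper name and sample values are made up):

# Illustrative select step behind filter(); not part of the TGI API.
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def select_for_requests(
    values: Sequence[T],
    requests_idx_mapping: Dict[int, int],
    request_ids: Sequence[int],
) -> List[T]:
    # Keep one entry per surviving request, in the order the ids are given,
    # by looking up each request's index in the old batch.
    return [values[requests_idx_mapping[rid]] for rid in request_ids]


# Example: requests 7 and 9 survive out of {5, 7, 9}.
mapping = {5: 0, 7: 1, 9: 2}
encoder_cache = [{"img-0": "emb"}, {}, {"img-2": "emb"}]
print(select_for_requests(encoder_cache, mapping, [7, 9]))
# -> [{}, {'img-2': 'emb'}]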
@@ -352,7 +387,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
            if len(image_inputs) > 0:
                batch_image_inputs[i] = image_inputs

        from pdb import set_trace
        set_trace()
        batch_image_positions = []
        batch_tokenized_inputs = []
        max_length = 0
@@ -459,6 +496,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        super().prepare_for_prefill()

        self.has_image = False
        self.encoder_cache_to_free = []
        self.scheduled_image_input = []
        scheduled_image_pixel_values = []
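Resetting encoder_cache_to_free here means every prefill step starts with empty deferred-free bookkeeping: entries are added while image inputs are scheduled and drained only after the forward pass. A compact sketch of that reset/mark/drain lifecycle, using a hypothetical helper class rather than the real batch:

# Hypothetical bookkeeping class showing the deferred-free lifecycle.
from typing import Dict, List, Tuple


class EncoderCacheBookkeeping:
    def __init__(self, num_requests: int):
        # One cache dict per request: image_id -> encoder output (stubbed here).
        self.encoder_cache: List[Dict[str, object]] = [{} for _ in range(num_requests)]
        self.encoder_cache_to_free: List[Tuple[int, str]] = []

    def prepare_for_prefill(self) -> None:
        # Fresh list each prefill step; nothing is freed speculatively.
        self.encoder_cache_to_free = []

    def mark_fully_consumed(self, request_idx: int, image_id: str) -> None:
        # Called when a prefill chunk reaches the end of an image's token span
        # (the end_idx == length case in the diff below).
        self.encoder_cache_to_free.append((request_idx, image_id))

    def free_encoder_cache(self) -> None:
        # Drained after the forward pass.
        for i, image_id in self.encoder_cache_to_free:
            self.encoder_cache[i][image_id] = None
        self.encoder_cache_to_free = []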
@@ -574,6 +612,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                start_idx = max(cache_length - start_pos, 0)
                end_idx = min(cache_length - start_pos + input_length, length)

                if end_idx == length:
                    self.encoder_cache_to_free.append((i, image_id))

                assert (
                    image_id in self.encoder_cache[i]
                ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
@@ -592,35 +633,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        return torch.cat(mm_embeds, dim=0).to(device)

    def free_encoder_cache(self):
        for i, (
            r,
            cache_length,
            input_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling or self.image_positions[i] is None:
                continue
        for i, image_id in self.encoder_cache_to_free:
            self.encoder_cache[i][image_id] = None
            for j, image_position in enumerate(self.image_positions[i]):
                image_id = image_position.id
                start_pos = image_position.offset
                length = image_position.length

                cache_length = cache_length + input_length
                if start_pos >= cache_length:
                    # No encoder input required at this step
                    break
                if start_pos + length <= cache_length:
                    self.encoder_cache[i][image_id] = None
        self.encoder_cache_to_free = []


class VlmCausalLM(FlashCausalLM):
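The removed loop freed an image once start_pos + length <= cache_length + input_length; the new path instead marks it during scheduling when end_idx == length and frees it in one pass here. Those two conditions are the same inequality rearranged, which the small brute-force check below confirms on arbitrary sample values:

# Quick check that the new end_idx == length mark and the old
# start_pos + length <= cache_length + input_length test agree.
def new_rule(cache_length, input_length, start_pos, length):
    end_idx = min(cache_length - start_pos + input_length, length)
    return end_idx == length


def old_rule(cache_length, input_length, start_pos, length):
    return start_pos + length <= cache_length + input_length


for cache_length in range(0, 40, 5):
    for input_length in range(1, 20, 3):
        for start_pos in range(0, 30, 7):
            for length in range(1, 25, 6):
                assert new_rule(cache_length, input_length, start_pos, length) == old_rule(
                    cache_length, input_length, start_pos, length
                )
print("rules agree on all sampled values")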
@@ -814,6 +830,7 @@ class VlmCausalLM(FlashCausalLM):
                batch.image_sizes = None
            if batch.image_grid_thw is not None:
                batch.image_grid_thw = None
            batch.free_encoder_cache()
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
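With the added call, forward() clears the one-shot image tensors and releases the queued encoder outputs before returning, apparently on the branch taken when no CUDA graph is replayed (per the trailing comment). A condensed, illustrative view of that tail, not the actual method body:

# Simplified sketch of the end of the non-CUDA-graph forward path.
def _finish_forward(batch, logits, speculative_logits):
    # One-shot prefill inputs are dropped once the model has consumed them.
    batch.pixel_values = None
    batch.pixel_attention_mask = None
    batch.image_sizes = None
    batch.image_grid_thw = None
    # Release encoder outputs queued in encoder_cache_to_free during prefill
    # preparation (see free_encoder_cache above).
    batch.free_encoder_cache()
    return logits, speculative_logits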