add encoder cache free

Mohit Sharma 2025-04-18 16:00:35 +00:00
parent 44ed5efbcc
commit 526a8785ed
3 changed files with 37 additions and 0 deletions


@@ -1899,7 +1899,9 @@ class FlashCausalLM(Model):
             batch.prepare_for_prefill()
             self.get_input_embeddings(batch)
+            from pdb import set_trace
+            set_trace()
 
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
         # Update adapter indices for speculative tokens (if present)


@@ -304,6 +304,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             device=device,
             rank=rank,
             world_size=world_size,
+            support_chunking=True,
         )
 
         # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
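For context, `support_chunking=True` tells the base class that this VLM can prefill a long prompt in several passes rather than one; the per-request `cache_lengths` and `input_lengths` used later in this commit track how far each prompt has progressed. A minimal sketch of that bookkeeping, using a hypothetical `chunked_prefill` helper rather than TGI's actual scheduler:

def chunked_prefill(prompt_len: int, chunk_size: int):
    # Hypothetical illustration of chunked-prefill progress tracking.
    cache_length = 0
    while cache_length < prompt_len:
        input_length = min(chunk_size, prompt_len - cache_length)
        # The model forward for tokens [cache_length, cache_length + input_length)
        # would run here; afterwards those tokens count as cached.
        cache_length += input_length
        yield cache_length, input_length

# A 10-token prompt with chunk_size=4 yields (4, 4), (8, 4), (10, 2).
print(list(chunked_prefill(prompt_len=10, chunk_size=4)))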


@@ -499,6 +499,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if image_id not in self.encoder_cache[i]:
                     self.scheduled_image_input.append((i, image_position))
                     scheduled_image_pixel_values.append(pixel_values)
+                    self.image_inputs[i][j] = None
 
         if self.has_image and len(scheduled_image_pixel_values):
             self.pixel_values = torch.cat(
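The hunk above schedules the vision encoder only for images whose output is not already in `encoder_cache[i]`, and clears the raw pixel entry once an image has been queued so it is neither re-encoded nor kept alive. A self-contained sketch of that selection step, assuming simplified per-request lists/dicts instead of the actual `VlmCausalLMBatch` fields:

import torch

def schedule_image_inputs(encoder_cache, image_inputs, image_positions):
    # encoder_cache[i]   : dict mapping image_id -> cached encoder output for request i
    # image_inputs[i][j] : raw pixel tensor for the j-th image of request i (or None)
    # image_positions[i] : list of objects with an `.id` attribute, or None
    scheduled = []  # (request index, image position) pairs to run through the encoder
    pixels = []     # raw pixel tensors for the scheduled images
    for i, positions in enumerate(image_positions):
        if positions is None:
            continue
        for j, position in enumerate(positions):
            if position.id in encoder_cache[i]:
                continue  # embeddings already computed during an earlier chunk
            scheduled.append((i, position))
            pixels.append(image_inputs[i][j])
            image_inputs[i][j] = None  # raw pixels are not needed once scheduled
    batched_pixels = torch.cat(pixels) if pixels else None
    return scheduled, batched_pixels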
@@ -590,6 +591,37 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return torch.cat(mm_embeds, dim=0).to(device)
 
+    def free_encoder_cache(self):
+        for i, (
+            r,
+            cache_length,
+            input_length,
+            request_prefilling,
+        ) in enumerate(
+            zip(
+                self.requests,
+                self.cache_lengths,
+                self.input_lengths,
+                self.prefilling_mask,
+            )
+        ):
+            if not request_prefilling or self.image_positions[i] is None:
+                continue
+
+            cache_length = cache_length + input_length
+            for j, image_position in enumerate(self.image_positions[i]):
+                image_id = image_position.id
+                start_pos = image_position.offset
+                length = image_position.length
+
+                if start_pos >= cache_length:
+                    # No encoder input required at this step
+                    break
+                if start_pos + length <= cache_length:
+                    self.encoder_cache[i][image_id] = None
+
 
 class VlmCausalLM(FlashCausalLM):
     def __init__(
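Read in isolation, the freeing rule in `free_encoder_cache` is: after a prefill step a request has processed `cache_length + input_length` prompt tokens, any image whose token span ends at or before that point has already been merged into the input embeddings, so its cached encoder output can be dropped, and an image that starts beyond that point ends the scan early. A standalone sketch of just that rule, with a hypothetical `ImagePosition` dataclass and `free_finished_images` helper standing in for the real objects:

from dataclasses import dataclass

@dataclass
class ImagePosition:
    id: str      # key used in the encoder cache
    offset: int  # first prompt position covered by the image tokens
    length: int  # number of image tokens

def free_finished_images(encoder_cache, image_positions, cache_length, input_length):
    # Prompt tokens of this request that will have been prefilled after this step.
    processed = cache_length + input_length
    for position in image_positions:
        if position.offset >= processed:
            # This image (and every later one) has not been reached yet.
            break
        if position.offset + position.length <= processed:
            # Fully consumed by the prefill: its cached encoder output can go.
            encoder_cache[position.id] = None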
@@ -831,4 +863,6 @@ class VlmCausalLM(FlashCausalLM):
             else None
         )
         logits = cuda_graph["logits"][:bs]
+        batch.free_encoder_cache()
+
         return logits, speculative_logits
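The call site is worth noting: `free_encoder_cache()` runs after the forward pass in the CUDA-graph replay branch, so cached vision-encoder outputs are held only while the chunked prefill still needs them; once the chunk covering an image has run, the entry is cleared before the next pass, which appears to be the point of the commit given its title.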