From 526a8785ed0c3870b76c5ef08161bc3235a55ab3 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Fri, 18 Apr 2025 16:00:35 +0000
Subject: [PATCH] add encoder cache free

---
 .../models/transformers_flash_vlm.py |  1 +
 .../models/vlm_causal_lm.py          | 34 +++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index 89ef2a9b..0accaaa8 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -304,6 +304,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             device=device,
             rank=rank,
             world_size=world_size,
+            support_chunking=True,
         )
 
         # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 178f736e..a4efbac4 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -499,6 +499,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if image_id not in self.encoder_cache[i]:
                     self.scheduled_image_input.append((i, image_position))
                     scheduled_image_pixel_values.append(pixel_values)
+                    self.image_inputs[i][j] = None
 
         if self.has_image and len(scheduled_image_pixel_values):
             self.pixel_values = torch.cat(
@@ -590,6 +591,37 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         return torch.cat(mm_embeds, dim=0).to(device)
 
+    def free_encoder_cache(self):
+        for i, (
+            r,
+            cache_length,
+            input_length,
+            request_prefilling,
+        ) in enumerate(
+            zip(
+                self.requests,
+                self.cache_lengths,
+                self.input_lengths,
+                self.prefilling_mask,
+            )
+        ):
+            if not request_prefilling or self.image_positions[i] is None:
+                continue
+
+            cache_length = cache_length + input_length
+            for j, image_position in enumerate(self.image_positions[i]):
+                image_id = image_position.id
+
+                start_pos = image_position.offset
+                length = image_position.length
+
+                if start_pos >= cache_length:
+                    # No encoder input required at this step
+                    break
+
+                if start_pos + length <= cache_length:
+                    self.encoder_cache[i][image_id] = None
+
 
 class VlmCausalLM(FlashCausalLM):
     def __init__(
@@ -831,4 +863,6 @@ class VlmCausalLM(FlashCausalLM):
             else None
         )
         logits = cuda_graph["logits"][:bs]
+
+        batch.free_encoder_cache()
        return logits, speculative_logits
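
Below is a minimal standalone sketch (not part of the patch) of the eviction rule that free_encoder_cache applies per request: once a prefill chunk has covered cache_length + input_length prompt positions, any image whose positions end at or before that point will never be read from the encoder cache again, so its entry can be released. ImagePosition and free_finished_images here are simplified stand-ins chosen for illustration, not the actual TGI classes or helpers.

# Standalone sketch of the eviction rule used by free_encoder_cache above.
# ImagePosition and free_finished_images are simplified stand-ins, not TGI code.
from dataclasses import dataclass


@dataclass
class ImagePosition:
    id: str      # key into the per-request encoder cache
    offset: int  # first prompt position covered by this image's embeddings
    length: int  # number of prompt positions the image occupies


def free_finished_images(
    image_positions: list,
    encoder_cache: dict,
    cache_length: int,
    input_length: int,
) -> None:
    # After this prefill chunk the request has covered `cache_length + input_length`
    # prompt positions; any image that ends at or before that point is never read
    # from the encoder cache again, so its entry can be released.
    processed = cache_length + input_length
    for pos in image_positions:
        if pos.offset >= processed:
            # This image (and any later one) has not been reached yet.
            break
        if pos.offset + pos.length <= processed:
            encoder_cache[pos.id] = None


# Example: two images; only the first is fully covered by this 16-token chunk.
cache = {"img-0": "embeds-0", "img-1": "embeds-1"}
positions = [
    ImagePosition("img-0", offset=4, length=8),
    ImagePosition("img-1", offset=20, length=8),
]
free_finished_images(positions, cache, cache_length=0, input_length=16)
print(cache)  # {'img-0': None, 'img-1': 'embeds-1'}

In the patch itself, VlmCausalLMBatch.free_encoder_cache performs the same check per request in the batch, reading self.image_positions[i] and clearing entries in self.encoder_cache[i].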