diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index a4efbac4..717e8bbb 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -290,6 +290,24 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
+
+        batch.image_inputs = []
+        batch.image_positions = []
+        batch.encoder_cache = []
+        for b in batches:
+            if b.image_inputs is not None:
+                batch.image_inputs.extend(b.image_inputs)
+            else:
+                batch.image_inputs.append(None)
+            if b.image_positions is not None:
+                batch.image_positions.extend(b.image_positions)
+            else:
+                batch.image_positions.append(None)
+            if b.encoder_cache is not None:
+                batch.encoder_cache.extend(b.encoder_cache)
+            else:
+                batch.encoder_cache.append(None)
+
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
@@ -298,11 +316,28 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
+        if len(request_ids) == 0:
+            raise ValueError("Batch must have at least one request")
+
+        image_inputs = []
+        image_positions = []
+        encoder_cache = []
+
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            image_inputs.append(self.image_inputs[idx])
+            image_positions.append(self.image_positions[idx])
+            encoder_cache.append(self.encoder_cache[idx])
+
         batch = super().filter(request_ids)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+
+        batch.image_inputs = image_inputs
+        batch.image_positions = image_positions
+        batch.encoder_cache = encoder_cache
         return batch
 
     @classmethod
@@ -459,6 +494,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         super().prepare_for_prefill()
 
         self.has_image = False
+        self.encoder_cache_to_free = []
         self.scheduled_image_input = []
         scheduled_image_pixel_values = []
 
@@ -574,6 +610,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             start_idx = max(cache_length - start_pos, 0)
             end_idx = min(cache_length - start_pos + input_length, length)
 
+            if end_idx == length:
+                self.encoder_cache_to_free.append((i, image_id))
+
             assert (
                 image_id in self.encoder_cache[i]
             ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
@@ -592,35 +631,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return torch.cat(mm_embeds, dim=0).to(device)
 
     def free_encoder_cache(self):
-        for i, (
-            r,
-            cache_length,
-            input_length,
-            request_prefilling,
-        ) in enumerate(
-            zip(
-                self.requests,
-                self.cache_lengths,
-                self.input_lengths,
-                self.prefilling_mask,
-            )
-        ):
-            if not request_prefilling or self.image_positions[i] is None:
-                continue
+        for i, image_id in self.encoder_cache_to_free:
+            self.encoder_cache[i][image_id] = None
 
-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
-
-                start_pos = image_position.offset
-                length = image_position.length
-
-                cache_length = cache_length + input_length
-                if start_pos >= cache_length:
-                    # No encoder input required at this step
-                    break
-
-                if start_pos + length <= cache_length:
-                    self.encoder_cache[i][image_id] = None
+        self.encoder_cache_to_free = []
 
 
 class VlmCausalLM(FlashCausalLM):
@@ -814,6 +828,7 @@ class VlmCausalLM(FlashCausalLM):
                 batch.image_sizes = None
             if batch.image_grid_thw is not None:
                 batch.image_grid_thw = None
+            batch.free_encoder_cache()
             return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
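
For context, a minimal self-contained sketch of the deferred-free pattern this diff moves to: image spans fully consumed by a prefill chunk are recorded in `encoder_cache_to_free`, and `free_encoder_cache()` simply drains that list instead of re-walking the image positions. This is not TGI code; `SketchBatch`, `ImagePosition`, and `mark_if_fully_consumed` are hypothetical stand-ins with simplified shapes.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import torch


@dataclass
class ImagePosition:
    id: str      # key into the per-request encoder cache
    offset: int  # first token position covered by the image embedding
    length: int  # number of image tokens


@dataclass
class SketchBatch:
    # One dict per request: image_id -> cached encoder output (None once freed).
    encoder_cache: List[Dict[str, Optional[torch.Tensor]]]
    # Filled while gathering multimodal embeds, drained by free_encoder_cache().
    encoder_cache_to_free: List[Tuple[int, str]] = field(default_factory=list)

    def mark_if_fully_consumed(
        self, i: int, position: ImagePosition, cache_length: int, input_length: int
    ) -> None:
        # Mirrors the `end_idx == length` check in the diff: once the current
        # prefill chunk reaches the end of the image span, the cached encoder
        # output is no longer needed for request i.
        end_idx = min(cache_length - position.offset + input_length, position.length)
        if end_idx == position.length:
            self.encoder_cache_to_free.append((i, position.id))

    def free_encoder_cache(self) -> None:
        # Deferred free: only entries recorded during the last prefill pass are
        # dropped, rather than re-deriving them from cache/input lengths.
        for i, image_id in self.encoder_cache_to_free:
            self.encoder_cache[i][image_id] = None
        self.encoder_cache_to_free = []


if __name__ == "__main__":
    batch = SketchBatch(encoder_cache=[{"img-0": torch.zeros(4, 8)}])
    pos = ImagePosition(id="img-0", offset=2, length=4)
    # A prefill chunk covering tokens [2, 6) consumes the whole image span.
    batch.mark_if_fully_consumed(i=0, position=pos, cache_length=0, input_length=6)
    batch.free_encoder_cache()
    assert batch.encoder_cache[0]["img-0"] is None
```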