diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 2b1e01df..0f3183df 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -381,6 +381,130 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.image_grid_thw = None
         return batch
 
+    def prepare_for_prefill(self):
+        super().prepare_for_prefill()
+
+        self.has_image = False
+        self.scheduled_image_input = []
+        scheduled_image_pixel_values = []
+
+        device = self.input_ids.device
+
+        for i, (
+            r,
+            cache_length,
+            input_length,
+            prompt_length,
+            request_prefilling,
+        ) in enumerate(
+            zip(
+                self.requests,
+                self.cache_lengths,
+                self.input_lengths,
+                self.prompt_lengths,
+                self.prefilling_mask,
+            )
+        ):
+            if not request_prefilling or self.image_positions[i] is None:
+                continue
+
+            for j, mm_input in enumerate(self.image_positions[i]):
+                image_id = mm_input.id
+                pixel_values = self.all_pixel_values[i][j].pixel_values
+
+                start_pos = mm_input.offset
+                length = mm_input.length
+
+                if start_pos >= cache_length + input_length:
+                    # No encoder input required at this step
+                    break
+                if start_pos + length <= cache_length:
+                    # The encoder input is already processed
+                    continue
+
+                self.has_image = True
+
+                if image_id not in self.encoder_cache[i]:
+                    self.scheduled_image_input.append((i, mm_input))
+                    scheduled_image_pixel_values.append(pixel_values)
+
+        if scheduled_image_pixel_values:
+            self.pixel_values = torch.cat(scheduled_image_pixel_values, dim=0).to(
+                device
+            )
+        else:
+            self.pixel_values = None
+
+    def update_encoder_cache(self, encoder_outputs):
+        prev = 0
+        for i, mm_input in self.scheduled_image_input:
+            length = mm_input.num_placeholder_tokens
+            output = encoder_outputs[prev : prev + length]
+            self.encoder_cache[i][mm_input.id] = self.scatter_image_embed(
+                output, mm_input.is_embed
+            )
+
+            prev += length
+
+    def get_mm_embeddings(self):
+        device = self.input_ids.device
+        mm_embeds = []
+
+        for i, (
+            r,
+            cache_length,
+            input_length,
+            prompt_length,
+            request_prefilling,
+        ) in enumerate(
+            zip(
+                self.requests,
+                self.cache_lengths,
+                self.input_lengths,
+                self.prompt_lengths,
+                self.prefilling_mask,
+            )
+        ):
+            if not request_prefilling or self.image_positions[i] is None:
+                continue
+
+            for mm_input in self.image_positions[i]:
+                image_id = mm_input.id
+
+                start_pos = mm_input.offset
+                length = mm_input.length
+
+                if start_pos >= cache_length + input_length:
+                    # No encoder input required at this step
+                    break
+                if start_pos + length <= cache_length:
+                    # The encoder input is already processed
+                    continue
+
+                # The encoder output was stored by update_encoder_cache
+                # before the embeddings are gathered.
+                encoder_output = self.encoder_cache[i][image_id]
+
+                # Only the part of the image that overlaps with the
+                # current prefill chunk contributes embeddings.
+                start_idx = max(cache_length - start_pos, 0)
+                end_idx = min(cache_length - start_pos + input_length, length)
+
+                is_embed = mm_input.is_embed
+                if is_embed is not None:
+                    is_embed = is_embed[start_idx:end_idx]
+
+                mm_embeds_item = gather_mm_embeds(
+                    encoder_output[start_idx:end_idx],
+                    is_embed=is_embed,
+                )
+                mm_embeds.append(mm_embeds_item)
+
+        return torch.cat(mm_embeds, dim=0).to(device)
+
 
 class VlmCausalLM(FlashCausalLM):
     def __init__(
@@ -418,6 +542,14 @@ class VlmCausalLM(FlashCausalLM):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return self.batch_class
 
+    def get_mm_embeddings(self, batch):
+        if batch.pixel_values is not None:
+            encoder_outputs = self.model.get_mm_embeddings(batch.pixel_values)
+
+            batch.update_encoder_cache(encoder_outputs)
+
+        return batch.get_mm_embeddings()
+
     def forward(
         self,
         batch: VlmCausalLM,
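
Note on the intended call order (a sketch inferred from this diff, not code in the PR): during one prefill step the batch first schedules the images whose placeholder tokens overlap the current chunk, the model encodes the scheduled pixel values once, update_encoder_cache stores the outputs per image id, and batch.get_mm_embeddings() then returns only the slice covered by the chunk. The helper name embed_prefill_chunk below is hypothetical; only prepare_for_prefill, update_encoder_cache, and the two get_mm_embeddings methods come from the diff.

    def embed_prefill_chunk(model, batch):
        # Hypothetical driver, not part of the diff: shows how the methods
        # added above are expected to be combined for one prefill chunk.
        batch.prepare_for_prefill()

        if not batch.has_image:
            # Text-only chunk: nothing to encode or gather.
            return None

        # Encodes batch.pixel_values if any image still needs the vision
        # encoder, fills batch.encoder_cache, and gathers the per-chunk
        # slices of the cached image embeddings.
        return model.get_mm_embeddings(batch)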