From f923a3fb6823069a984eda07f1e086f83363ce07 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 10 Oct 2024 16:01:18 +0200
Subject: [PATCH] fix mllama

---
 backends/v3/src/queue.rs                      |  2 +-
 .../models/mllama_causal_lm.py                | 21 ++++++++++++-------
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 36fbed87..414045a1 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -546,7 +546,7 @@ mod tests {
             request: ValidGenerateRequest {
                 inputs: vec![],
                 input_ids: Some(Arc::new(vec![])),
-                input_length: 0,
+                input_length: 1,
                 add_special_tokens: true,
                 truncate: 0,
                 decoder_input_details: false,
diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index 83e44039..6399f92c 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -297,6 +297,17 @@ class MllamaCausalLM(VlmCausalLM):
                     max_q=batch.max_input_length,
                     max_k=batch.max_current_length,
                 )
+
+                if batch.pixel_values is not None:
+                    cross_attention_states = self.model.vision_forward(
+                        pixel_values=batch.pixel_values,
+                        aspect_ratio_ids=batch.aspect_ratio_ids,
+                        aspect_ratio_mask=batch.aspect_ratio_mask,
+                    )
+                    batch.cross_attention_states = cross_attention_states
+
+                cross_attention_states = batch.cross_attention_states
+
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
@@ -308,18 +319,14 @@
                     max_s=max_s,
                     prefill_cache_indices=batch.prefill_cache_indices,
                     lm_head_indices=lm_head_indices,
-                    pixel_values=batch.pixel_values,
-                    pixel_attention_mask=batch.pixel_attention_mask,
-                    image_sizes=batch.image_sizes,
+                    cross_attention_states=cross_attention_states,
+                    adapter_data=adapter_data,
+                    image_indices=batch.image_indices[:],
                 )
                 if batch.prefill_cache_indices is not None:
                     batch.prefill_cache_indices = None
                 if batch.pixel_values is not None:
                     batch.pixel_values = None
-                if batch.pixel_attention_mask is not None:
-                    batch.pixel_attention_mask = None
-                if batch.image_sizes is not None:
-                    batch.image_sizes = None
                 return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
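
Note (not part of the patch): a minimal, runnable sketch of the caching pattern the diff introduces — run the vision tower once while pixel_values are set (prefill), stash the result on the batch as cross_attention_states, and reuse the cached states on later decode steps after pixel_values has been cleared. ToyModel and ToyBatch below are hypothetical stand-ins, not the actual text-generation-inference types; only the control flow mirrors the patched MllamaCausalLM.forward.

# Hypothetical stand-ins; only the control flow mirrors the patch.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyBatch:
    pixel_values: Optional[List[float]] = None            # set only on the prefill step
    aspect_ratio_ids: Optional[List[int]] = None
    aspect_ratio_mask: Optional[List[int]] = None
    cross_attention_states: Optional[List[float]] = None  # cached vision output
    image_indices: List[int] = field(default_factory=list)


class ToyModel:
    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
        # Stand-in for the Mllama vision tower: encode images once.
        return [v * 2.0 for v in pixel_values]

    def forward(self, cross_attention_states, image_indices):
        # Stand-in for the language-model forward that consumes the cached states.
        return sum(cross_attention_states or [])


def generate_step(model: ToyModel, batch: ToyBatch) -> float:
    # 1. On prefill (pixel_values present), run the vision tower and cache its output.
    if batch.pixel_values is not None:
        batch.cross_attention_states = model.vision_forward(
            pixel_values=batch.pixel_values,
            aspect_ratio_ids=batch.aspect_ratio_ids,
            aspect_ratio_mask=batch.aspect_ratio_mask,
        )
    # 2. Always read the (possibly cached) states and pass them to the LM forward.
    cross_attention_states = batch.cross_attention_states
    logits = model.forward(
        cross_attention_states=cross_attention_states,
        image_indices=batch.image_indices[:],
    )
    # 3. Drop pixel_values so later decode steps reuse the cache instead of re-encoding.
    if batch.pixel_values is not None:
        batch.pixel_values = None
    return logits


if __name__ == "__main__":
    model = ToyModel()
    batch = ToyBatch(pixel_values=[1.0, 2.0], aspect_ratio_ids=[0], aspect_ratio_mask=[1])
    print(generate_step(model, batch))  # prefill: encodes images, caches states
    print(generate_step(model, batch))  # decode: reuses batch.cross_attention_states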