fix mllama

This commit is contained in:
OlivierDehaene 2024-10-10 16:01:18 +02:00
parent b7a1280f25
commit f923a3fb68
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
2 changed files with 15 additions and 8 deletions

View File

@@ -546,7 +546,7 @@ mod tests {
request: ValidGenerateRequest {
inputs: vec![],
input_ids: Some(Arc::new(vec![])),
-input_length: 0,
+input_length: 1,
add_special_tokens: true,
truncate: 0,
decoder_input_details: false,

View File

@@ -297,6 +297,17 @@ class MllamaCausalLM(VlmCausalLM):
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
)
batch.cross_attention_states = cross_attention_states
cross_attention_states = batch.cross_attention_states
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -308,18 +319,14 @@ class MllamaCausalLM(VlmCausalLM):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
image_indices=batch.image_indices[:],
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph