fix mllama

This commit is contained in:
OlivierDehaene 2024-10-10 16:01:18 +02:00
parent b7a1280f25
commit f923a3fb68
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
2 changed files with 15 additions and 8 deletions

View File

@@ -546,7 +546,7 @@ mod tests {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: vec![], inputs: vec![],
input_ids: Some(Arc::new(vec![])), input_ids: Some(Arc::new(vec![])),
input_length: 0, input_length: 1,
add_special_tokens: true, add_special_tokens: true,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,

View File

@@ -297,6 +297,17 @@ class MllamaCausalLM(VlmCausalLM):
max_q=batch.max_input_length, max_q=batch.max_input_length,
max_k=batch.max_current_length, max_k=batch.max_current_length,
) )
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
)
batch.cross_attention_states = cross_attention_states
cross_attention_states = batch.cross_attention_states
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -308,18 +319,14 @@ class MllamaCausalLM(VlmCausalLM):
max_s=max_s, max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values, cross_attention_states=cross_attention_states,
pixel_attention_mask=batch.pixel_attention_mask, adapter_data=adapter_data,
image_sizes=batch.image_sizes, image_indices=batch.image_indices[:],
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
if batch.pixel_values is not None: if batch.pixel_values is not None:
batch.pixel_values = None batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph