mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix mllama
This commit is contained in:
parent
b7a1280f25
commit
f923a3fb68
@ -546,7 +546,7 @@ mod tests {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
input_length: 1,
|
||||
add_special_tokens: true,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
|
@ -297,6 +297,17 @@ class MllamaCausalLM(VlmCausalLM):
|
||||
max_q=batch.max_input_length,
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
|
||||
if batch.pixel_values is not None:
|
||||
cross_attention_states = self.model.vision_forward(
|
||||
pixel_values=batch.pixel_values,
|
||||
aspect_ratio_ids=batch.aspect_ratio_ids,
|
||||
aspect_ratio_mask=batch.aspect_ratio_mask,
|
||||
)
|
||||
batch.cross_attention_states = cross_attention_states
|
||||
|
||||
cross_attention_states = batch.cross_attention_states
|
||||
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
@ -308,18 +319,14 @@ class MllamaCausalLM(VlmCausalLM):
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
pixel_values=batch.pixel_values,
|
||||
pixel_attention_mask=batch.pixel_attention_mask,
|
||||
image_sizes=batch.image_sizes,
|
||||
cross_attention_states=cross_attention_states,
|
||||
adapter_data=adapter_data,
|
||||
image_indices=batch.image_indices[:],
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
if batch.pixel_values is not None:
|
||||
batch.pixel_values = None
|
||||
if batch.pixel_attention_mask is not None:
|
||||
batch.pixel_attention_mask = None
|
||||
if batch.image_sizes is not None:
|
||||
batch.image_sizes = None
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
|
Loading…
Reference in New Issue
Block a user