Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix mllama
parent b7a1280f25
commit f923a3fb68
@@ -546,7 +546,7 @@ mod tests {
             request: ValidGenerateRequest {
                 inputs: vec![],
                 input_ids: Some(Arc::new(vec![])),
-                input_length: 0,
+                input_length: 1,
                 add_special_tokens: true,
                 truncate: 0,
                 decoder_input_details: false,
@@ -297,6 +297,17 @@ class MllamaCausalLM(VlmCausalLM):
                     max_q=batch.max_input_length,
                     max_k=batch.max_current_length,
                 )
+
+                if batch.pixel_values is not None:
+                    cross_attention_states = self.model.vision_forward(
+                        pixel_values=batch.pixel_values,
+                        aspect_ratio_ids=batch.aspect_ratio_ids,
+                        aspect_ratio_mask=batch.aspect_ratio_mask,
+                    )
+                    batch.cross_attention_states = cross_attention_states
+
+                cross_attention_states = batch.cross_attention_states
+
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
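For readers unfamiliar with the prefill/decode split, the hunk above follows a compute-once-and-cache pattern: the vision encoder only runs while batch.pixel_values is still set (the prefill step), its output is stored on the batch, and every later decode step reads the cached cross_attention_states instead of re-encoding the images. Below is a minimal, self-contained sketch of that pattern; FakeBatch, vision_forward and step are stand-ins for illustration, not the real TGI batch and model classes.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeBatch:
    # Stand-in for the mllama batch: images are only populated for the prefill step.
    pixel_values: Optional[List[str]] = None
    cross_attention_states: Optional[List[str]] = None


def vision_forward(pixel_values: List[str]) -> List[str]:
    # Stand-in for self.model.vision_forward(...): pretend to encode each image.
    return [f"encoded<{p}>" for p in pixel_values]


def step(batch: FakeBatch) -> Optional[List[str]]:
    # Prefill: images are present, so encode them once and cache the result on the batch.
    if batch.pixel_values is not None:
        batch.cross_attention_states = vision_forward(batch.pixel_values)
    # Prefill and decode alike then read the cached states from the batch.
    return batch.cross_attention_states


batch = FakeBatch(pixel_values=["img-0"])
print(step(batch))          # prefill: runs the (fake) vision encoder
batch.pixel_values = None   # mirrors the post-forward cleanup in the next hunk
print(step(batch))          # decode: reuses the cached cross-attention states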
@@ -308,18 +319,14 @@ class MllamaCausalLM(VlmCausalLM):
                     max_s=max_s,
                     prefill_cache_indices=batch.prefill_cache_indices,
                     lm_head_indices=lm_head_indices,
-                    pixel_values=batch.pixel_values,
-                    pixel_attention_mask=batch.pixel_attention_mask,
-                    image_sizes=batch.image_sizes,
+                    cross_attention_states=cross_attention_states,
+                    adapter_data=adapter_data,
+                    image_indices=batch.image_indices[:],
                 )
                 if batch.prefill_cache_indices is not None:
                     batch.prefill_cache_indices = None
                 if batch.pixel_values is not None:
                     batch.pixel_values = None
-                if batch.pixel_attention_mask is not None:
-                    batch.pixel_attention_mask = None
-                if batch.image_sizes is not None:
-                    batch.image_sizes = None
                 return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
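Taken together with the previous hunk, the call site now hands the language model the already-encoded cross_attention_states (plus adapter_data and a copy of image_indices) instead of raw pixel_values / pixel_attention_mask / image_sizes, and the post-forward cleanup only clears the fields that are still used. A rough sketch of the resulting per-step lifecycle follows; model_forward and generate are hypothetical stand-ins, not the real TGI API.

from typing import List, Optional


def model_forward(
    input_ids: List[int],
    cross_attention_states: Optional[List[str]],
    image_indices: List[int],
) -> List[float]:
    # Stand-in for self.model.forward(...): the text model only ever receives
    # pre-computed cross-attention states, never raw pixel values.
    return [0.0] * len(input_ids)


def generate(prompt_ids: List[int], pixel_values: Optional[List[str]], steps: int = 3) -> None:
    cross_attention_states: Optional[List[str]] = None
    image_indices: List[int] = [0] if pixel_values is not None else []

    for i in range(steps):
        if pixel_values is not None:
            # Prefill: encode the images once (see the vision_forward sketch above).
            cross_attention_states = [f"encoded<{p}>" for p in pixel_values]
        logits = model_forward(
            input_ids=prompt_ids if i == 0 else [42],
            cross_attention_states=cross_attention_states,
            image_indices=image_indices[:],  # pass a copy, as the hunk above does
        )
        # Post-forward cleanup, as in the diff: drop the per-prefill inputs so
        # later decode steps reuse the cached states instead of re-encoding.
        pixel_values = None
        print(f"step {i}: {len(logits)} logits, states={cross_attention_states}")


generate(prompt_ids=[1, 2, 3], pixel_values=["img-0"])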