Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-08 19:04:52 +00:00
Multi modality fix (#3283)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 6a2fa83540
commit 5284b5c654
@@ -710,34 +710,41 @@ class MllamaTextCrossAttention(nn.Module):
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,
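The same device split recurs in the two Qwen attention hunks below: on xpu the ipex.llm.functional.varlen_attention call gets contiguous q/k/v tensors and one extra positional argument (None here), while the non-xpu path keeps the previous argument layout. As a rough sketch only, the dispatch could be factored into a helper like the one below. The helper name is hypothetical and not part of TGI, and the positional layouts simply mirror what this diff shows.

# Hypothetical helper (not in TGI) mirroring the xpu/non-xpu dispatch above.
# Assumes ipex.llm.functional.varlen_attention accepts the positional layouts
# shown in this diff, with the xpu build taking one extra positional argument.
import intel_extension_for_pytorch as ipex
import torch


def ipex_varlen_attention(q, k, v, out, cu_q, cu_k, max_q, max_k,
                          softmax_scale, causal):
    if q.device.type == "xpu":
        # xpu kernels receive contiguous tensors plus the extra slot (None).
        ipex.llm.functional.varlen_attention(
            q.contiguous(), k.contiguous(), v.contiguous(), out,
            cu_q, cu_k, None, max_q, max_k,
            0.0, softmax_scale, False, causal, False, None,
        )
    else:
        ipex.llm.functional.varlen_attention(
            q, k, v, out,
            cu_q, cu_k, max_q, max_k,
            0.0, softmax_scale, False, causal, False, None,
        )
    return out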
@@ -460,22 +460,41 @@ class Qwen2_5VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
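For context on the cu_seqlens and max_seqlen arguments passed twice above (the same tensor serves queries and keys in this self-attention), flash-attn-style varlen kernels take cumulative sequence lengths over a packed batch. A minimal, self-contained illustration of that convention follows; it is not TGI code.

# Build cu_seqlens for a packed batch of ragged sequences: sequence i owns
# rows cu_seqlens[i]:cu_seqlens[i + 1] of the flattened
# (total_tokens, num_heads, head_dim) q/k/v tensors.
import torch
import torch.nn.functional as F

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = F.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
max_seqlen = int(seq_lens.max())   # 5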
@@ -130,22 +130,41 @@ class Qwen2VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
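A side note on the .contiguous() calls kept on the xpu path above: they are cheap when a tensor is already contiguous, since PyTorch only copies otherwise. A quick standalone check, for illustration only:

# .contiguous() reuses the same storage for already-contiguous tensors and
# only materializes a copy for non-contiguous views such as a transpose.
import torch

q = torch.randn(4, 8)
assert q.contiguous().data_ptr() == q.data_ptr()      # no copy

qt = q.t()                                            # non-contiguous view
assert not qt.is_contiguous()
assert qt.contiguous().data_ptr() != qt.data_ptr()    # copy happens here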
@@ -59,7 +59,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         assert self.image_indices is not None
-        batch = super().filter(request_ids)
+        batch = super(VlmCausalLMBatch, self).filter(request_ids)
         assert self.image_indices is not None
         indices = []
         for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
                 ]
         else:
             batch.cross_attention_states = None
+        batch.pixel_values = None
         return batch
 
     @classmethod
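The filter change above relies on the two-argument form of super: super(VlmCausalLMBatch, self).filter(request_ids) starts method resolution after VlmCausalLMBatch, so the Mllama batch bypasses the generic VLM filtering and then clears pixel_values and cross_attention_states itself. A tiny standalone sketch of that lookup behavior, with hypothetical class names standing in for the TGI classes:

# Minimal illustration of two-argument super(); class names are made up and
# stand in for MllamaCausalLMBatch -> VlmCausalLMBatch -> their base class.
class Base:
    def filter(self):
        return "Base.filter"


class Vlm(Base):
    def filter(self):
        return "Vlm.filter"


class Mllama(Vlm):
    def filter(self):
        # Resolution starts *after* Vlm in the MRO, so Base.filter runs,
        # analogous to super(VlmCausalLMBatch, self).filter(request_ids).
        return super(Vlm, self).filter()


assert Mllama().filter() == "Base.filter"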