feat: support flash attention 2 in qwen2 vl vision blocks

This commit is contained in:
David Holtz 2024-11-04 16:12:56 +00:00
parent b1f9044d6c
commit 41dff3147d

View File

@ -22,9 +22,11 @@ from torch import nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex": if SYSTEM == "ipex":
pass import intel_extension_for_pytorch as ipex
else: else:
pass import flash_attn_2_cuda
import numpy as np
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
import torch.nn.functional as F import torch.nn.functional as F
@ -66,7 +68,7 @@ def apply_rotary_pos_emb_vision(
return output return output
class Qwen2VLSdpaAttention(nn.Module): class Qwen2VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights): def __init__(self, *, prefix, config, weights):
super().__init__() super().__init__()
self.embed_dim = config.embed_dim // weights.process_group.size() self.embed_dim = config.embed_dim // weights.process_group.size()
@ -88,6 +90,7 @@ class Qwen2VLSdpaAttention(nn.Module):
weights=weights, weights=weights,
bias=True, bias=True,
) )
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
def forward( def forward(
self, self,
@ -117,37 +120,60 @@ class Qwen2VLSdpaAttention(nn.Module):
0 0
) )
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
# TODO: make use of existing RotatoryPositionEmbedding class
# create the attention mask # calc maximum sequence length for any batch
attention_mask = torch.zeros( max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
[1, hidden_state.shape[0], hidden_state.shape[0]], query = query.contiguous()
device=hidden_state.device, key = key.contiguous()
dtype=torch.bool, value = value.contiguous()
) causal = False
# TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it
# apply the cu_seqlens to the attention mask # execute flash attention
for i in range(1, len(cu_seqlens)): if SYSTEM == "ipex":
attention_mask[ attn_output = torch.empty_like(query)
..., ipex.llm.functional.varlen_attention(
cu_seqlens[i - 1] : cu_seqlens[i], (query.contiguous() if query.device.type == "xpu" else query),
cu_seqlens[i - 1] : cu_seqlens[i], (key.contiguous() if key.device.type == "xpu" else key),
] = True (value.contiguous() if value.device.type == "xpu" else value),
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
attn_output = flash_attn_2_cuda.varlen_fwd(
query,
key,
value,
None, # tmp buffer (auto-allocated)
cu_seqlens, # cu_seqlens_q
cu_seqlens, # cu_seqlens_k
None, # max_seqlen_q (auto-computed)
None, # max_seqlen_k (auto-computed)
None, # block_tables
None, # broadcast_mask
max_seqlen, # max_seqlen
max_seqlen, # max_seqlen
0.0, # dropout_p
self.softmax_scale,
False, # zero_tensors
causal, # causal attention within each sequence
-1, # window_size_left
-1, # window_size_right
0.0, # softmax_cap
False, # deterministic
None, # rng_state
)[0]
# transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim) # reshape output to original dimensions
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
# apply attention
attn_output = F.scaled_dot_product_attention(
query, key, value, attention_mask, dropout_p=0.0
)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(hidden_state.shape[0], -1) attn_output = attn_output.reshape(hidden_state.shape[0], -1)
# TODO: prefer flash attention
attn_output = self.proj(attn_output) attn_output = self.proj(attn_output)
return attn_output return attn_output
@ -173,7 +199,7 @@ class Qwen2VLVisionMLP(nn.Module):
class Qwen2VLVisionBlock(nn.Module): class Qwen2VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.attn = Qwen2VLSdpaAttention( self.attn = Qwen2VLAttention(
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
config=config, config=config,
weights=weights, weights=weights,