Merge branch 'main' into sliding_window

Wang, Yi A 2024-10-08 10:00:48 -04:00
commit c2a8819edf
2 changed files with 62 additions and 26 deletions

View File

@@ -26,8 +26,8 @@ class KVCache:
         if (
             dtype == torch.float8_e5m2
-            and ATTENTION != "flashinfer"
-            and SYSTEM != "cuda"
+            and (ATTENTION != "flashinfer"
+            or SYSTEM != "cuda")
         ):
             raise ValueError(
                 "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"

View File

@@ -19,7 +19,12 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import flash_attn_2_cuda
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import flash_attn_2_cuda
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
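
Note (not part of the diff): this hunk picks the attention kernels once, at module import time, based on the SYSTEM value exported by text_generation_server.utils.import_utils ("ipex" on Intel CPU/XPU builds, otherwise the flash-attention v2 CUDA extension is kept). A rough standalone sketch of the same idea, using package probing as a stand-in for SYSTEM since that helper is repo-internal:

# Illustrative only: choose one attention backend at import time so the hot
# path never re-checks which extension is installed.
import importlib.util

if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
    import intel_extension_for_pytorch as ipex  # Intel CPU/XPU fused kernels
    ATTENTION_BACKEND = "ipex"
elif importlib.util.find_spec("flash_attn_2_cuda") is not None:
    import flash_attn_2_cuda  # flash-attention v2 CUDA kernels
    ATTENTION_BACKEND = "flash_attn_2"
else:
    raise ImportError("no supported attention backend is installed")

print(f"attention backend: {ATTENTION_BACKEND}")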
@@ -698,29 +703,60 @@ class MllamaTextCrossAttention(nn.Module):
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
-        attn_output = flash_attn_2_cuda.varlen_fwd(
-            query_states,
-            key_states,
-            value_states,
-            None,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            None,
-            None,  # block_tables
-            None,
-            max_q,
-            max_k,
-            0.0,
-            self.softmax_scale,
-            False,
-            causal,  # Causal
-            -1,  # window_size_left,
-            -1,
-            0.0,  # softcap
-            False,
-            None,
-        )[0]
+        if SYSTEM == "ipex":
+            attn_output = torch.empty_like(query_states)
+            ipex.llm.functional.varlen_attention(
+                (
+                    query_states.contiguous()
+                    if query_states.device.type == "xpu"
+                    else query_states
+                ),
+                (
+                    key_states.contiguous()
+                    if key_states.device.type == "xpu"
+                    else key_states
+                ),
+                (
+                    value_states.contiguous()
+                    if value_states.device.type == "xpu"
+                    else value_states
+                ),
+                attn_output,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+            )
+        else:
+            attn_output = flash_attn_2_cuda.varlen_fwd(
+                query_states,
+                key_states,
+                value_states,
+                None,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables
+                None,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,  # Causal
+                -1,  # window_size_left,
+                -1,
+                0.0,  # softcap
+                False,
+                None,
+            )[0]
         attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
         return attn_output
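
Note (not part of the diff): the two branches follow different kernel conventions. ipex.llm.functional.varlen_attention writes into a caller-allocated output tensor (hence the torch.empty_like) and takes no window or softcap arguments in this call, while flash_attn_2_cuda.varlen_fwd allocates the output itself and returns it as the first element of a tuple, with window_size and softcap parameters both disabled above. A condensed sketch of the dispatch, assuming the imports from the previous hunk; the wrapper name is hypothetical, and the per-tensor .contiguous() handling for XPU devices in the real code is omitted for brevity:

import torch

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
else:
    import flash_attn_2_cuda


def varlen_cross_attention(q, k, v, cu_seqlen_q, cu_seqlen_k, max_q, max_k,
                           softmax_scale, causal):
    # Hypothetical helper condensing the branch above; argument order follows
    # the diff, and comments mark only the parameters the diff itself annotates.
    if SYSTEM == "ipex":
        out = torch.empty_like(q)          # IPEX kernel fills `out` in place
        ipex.llm.functional.varlen_attention(
            q, k, v, out,
            cu_seqlen_q, cu_seqlen_k, max_q, max_k,
            0.0,                           # dropout probability
            softmax_scale,
            False, causal, False, None,
        )
        return out
    return flash_attn_2_cuda.varlen_fwd(
        q, k, v,
        None,                              # out (allocated by the kernel)
        cu_seqlen_q, cu_seqlen_k,
        None, None,
        None,                              # block_tables
        None,
        max_q, max_k,
        0.0, softmax_scale, False,
        causal,                            # Causal
        -1, -1,                            # window_size_left / right (disabled)
        0.0,                               # softcap (disabled)
        False, None,
    )[0]                                   # output is the first tuple element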