mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

Merge branch 'main' into sliding_window
This commit is contained in:
commit c2a8819edf
@@ -26,8 +26,8 @@ class KVCache:
         if (
             dtype == torch.float8_e5m2
-            and ATTENTION != "flashinfer"
-            and SYSTEM != "cuda"
+            and (ATTENTION != "flashinfer"
+            or SYSTEM != "cuda")
         ):
             raise ValueError(
                 "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
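
For context on the KVCache change: float8_e5m2 should be rejected unless both requirements hold (the flashinfer backend and a CUDA system), but the previous `and`-chained form only raised when both were violated at once. A minimal standalone sketch of the corrected logic, with `check_kv_cache_dtype`, `attention_backend`, and `system` as hypothetical stand-ins for the module-level ATTENTION and SYSTEM values:

import torch

def check_kv_cache_dtype(dtype: torch.dtype, attention_backend: str, system: str) -> None:
    # float8_e5m2 KV cache needs BOTH the flashinfer backend AND a CUDA system;
    # reject as soon as either requirement is missing.
    if dtype == torch.float8_e5m2 and (
        attention_backend != "flashinfer" or system != "cuda"
    ):
        raise ValueError(
            "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
        )

# Illustration: a non-flashinfer backend on CUDA now raises, whereas the old
# `and`-chained check silently accepted it.
# check_kv_cache_dtype(torch.float8_e5m2, "paged", "cuda")
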
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import flash_attn_2_cuda
+from text_generation_server.utils.import_utils import SYSTEM
 
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import flash_attn_2_cuda
+
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
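
The import change guards the CUDA-only flash_attn_2_cuda extension behind the platform check, so the module still imports on Intel XPU hosts where that wheel cannot be installed. A generic, runnable sketch of the same guarded-import pattern, using standard-library probing as a simplified stand-in for the SYSTEM value that really comes from text_generation_server.utils.import_utils (the "cpu" fallback below is purely illustrative):

import importlib.util

# Probe for the optional backend packages and import only the one that is
# present; this keeps the module importable on hosts missing the other wheel.
if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
    SYSTEM = "ipex"
    import intel_extension_for_pytorch as ipex  # Intel XPU kernels
elif importlib.util.find_spec("flash_attn_2_cuda") is not None:
    SYSTEM = "cuda"
    import flash_attn_2_cuda  # CUDA-only flash-attention extension
else:
    SYSTEM = "cpu"  # neither backend available; downstream code must handle this
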
@@ -698,29 +703,60 @@ class MllamaTextCrossAttention(nn.Module):
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
-        attn_output = flash_attn_2_cuda.varlen_fwd(
-            query_states,
-            key_states,
-            value_states,
-            None,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            None,
-            None,  # block_tables
-            None,
-            max_q,
-            max_k,
-            0.0,
-            self.softmax_scale,
-            False,
-            causal,  # Causal
-            -1,  # window_size_left,
-            -1,
-            0.0,  # softcap
-            False,
-            None,
-        )[0]
+        if SYSTEM == "ipex":
+            attn_output = torch.empty_like(query_states)
+            ipex.llm.functional.varlen_attention(
+                (
+                    query_states.contiguous()
+                    if query_states.device.type == "xpu"
+                    else query_states
+                ),
+                (
+                    key_states.contiguous()
+                    if key_states.device.type == "xpu"
+                    else key_states
+                ),
+                (
+                    value_states.contiguous()
+                    if value_states.device.type == "xpu"
+                    else value_states
+                ),
+                attn_output,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+            )
+        else:
+            attn_output = flash_attn_2_cuda.varlen_fwd(
+                query_states,
+                key_states,
+                value_states,
+                None,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables
+                None,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,  # Causal
+                -1,  # window_size_left,
+                -1,
+                0.0,  # softcap
+                False,
+                None,
+            )[0]
         attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
         return attn_output
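
The cross-attention call site now dispatches on SYSTEM: the IPEX path pre-allocates attn_output with torch.empty_like and lets ipex.llm.functional.varlen_attention write into it in place (making q/k/v contiguous only on "xpu" devices), while the CUDA path keeps flash_attn_2_cuda.varlen_fwd and takes element [0] of its return value. A tiny self-contained illustration of the two calling conventions, using toy functions rather than the real kernels:

import torch

def toy_inplace_attention(q, k, v, out):
    # ipex-style convention: the caller allocates `out`, the kernel fills it.
    out.copy_(torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v)

def toy_returning_attention(q, k, v):
    # flash-attn-style convention: outputs are returned, the caller takes [0].
    out = torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v
    return (out,)

q = k = v = torch.randn(2, 4, 8)

out_a = torch.empty_like(q)                  # mirrors torch.empty_like(query_states)
toy_inplace_attention(q, k, v, out_a)

out_b = toy_returning_attention(q, k, v)[0]  # mirrors varlen_fwd(...)[0]
assert torch.allclose(out_a, out_b)
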