Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

fix: prefer version check over test op and avoid window_size_left if not flash attn2

This commit is contained in:
parent 5123925101
commit cae28dcbf1

@@ -3,7 +3,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
-import warnings
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5

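
For context, a minimal, self-contained sketch of what these module-level checks compute (assuming a CUDA-enabled PyTorch install): torch.cuda.get_device_capability() returns the compute capability of the default device as a (major, minor) tuple, so Turing/T4 shows up as (7, 5).

    import torch

    if torch.cuda.is_available():
        # Compute capability of the default CUDA device, e.g. (7, 5) for a T4,
        # (8, 0) for an A100, (9, 0) for an H100.
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        print(f"compute capability: sm{major}{minor}, is_sm75={is_sm75}")
    else:
        print("no CUDA device visible")
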
@@ -172,43 +171,11 @@ def paged_attention(
 
 
 try:
+    if major <= 8:
+        raise ImportError("Flash Attention V2 requires CUDA 11.0 or higher")
+
     import flash_attn_2_cuda
 
-    # try forwarding to see if it works with all dummy inputs
-    batch_size = 1
-    num_heads = 1
-    head_dim = 1
-    seqlen = 1
-
-    try:
-        flash_attn_2_cuda.varlen_fwd(
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # q
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # k
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # v
-            None,  # out (optional)
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_q
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_k
-            None,  # alibi_slopes (optional)
-            None,  # q_padded (optional)
-            None,  # k_padded (optional)
-            None,  # v_padded (optional)
-            seqlen,  # max_seqlen_q
-            seqlen,  # max_seqlen_k
-            1.0,  # softmax_scale
-            0.0,  # softmax_lse (default value)
-            False,  # is_causal
-            True,  # return_softmax
-            -1,  # window_size_left
-            -1,  # window_size_right
-            0.0,  # softmax_softcap
-            False,  # deterministic
-            None,  # rng_state (optional)
-        )
-    except RuntimeError as e:
-        raise ImportError(
-            "Flash Attention V2 is not supported on this machine. " f"Error: {e}"
-        ) from e
-
     V2 = True
 except ImportError:
     try:

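
The idea of this hunk, sketched in isolation: instead of probing the kernel with a dummy varlen_fwd call and translating a RuntimeError into an ImportError, gate the import on a cheap hardware check up front. The helper name can_use_flash_attention_v2 and the capability threshold below are illustrative assumptions, not the repository's exact code.

    import importlib.util

    import torch


    def can_use_flash_attention_v2() -> bool:
        """Illustrative gate: check hardware capability and module presence
        instead of executing a throwaway forward pass."""
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        # Flash Attention v2 kernels target Ampere (sm80) and newer; the exact
        # threshold used upstream may differ from this sketch.
        if major < 8:
            return False
        # Only check that the extension is importable; do not run a test op.
        return importlib.util.find_spec("flash_attn_2_cuda") is not None


    V2 = can_use_flash_attention_v2()

The design trade-off is that a version/capability check is fast, allocates nothing on the GPU, and cannot crash at import time, while the old test op could fail for unrelated runtime reasons and masked them as ImportError.
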
@@ -289,10 +256,9 @@ else:
         window_size_left=-1,
         softcap=None,
     ):
-        if window_size_left != -1:
-            warnings.warn(
-                "window_size_left is only available with flash attn v2. It will be ignored.",
-                UserWarning,
+        if window_size_left is not None and window_size_left != -1:
+            raise NotImplementedError(
+                "window_size_left is only available with flash attn v2"
             )
 
         if softcap is not None:

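
A caller-side view of the behavior change above, as a hedged sketch: on the v1 fallback path, passing a real window_size_left now raises instead of being warned about and silently ignored. The function below is a stand-in, not the repository's attention signature.

    def attention_v1_stub(q, k, v, window_size_left=-1, softcap=None):
        """Stand-in for the v1 fallback path: reject options the kernel
        cannot honor instead of warning and silently dropping them."""
        if window_size_left is not None and window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )
        # ... dispatch to the v1 kernel here ...
        return q  # placeholder


    # Silently ignoring a sliding window would change model outputs, so the
    # caller now gets a hard error it can act on:
    try:
        attention_v1_stub(None, None, None, window_size_left=4096)
    except NotImplementedError as e:
        print(f"rejected: {e}")
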
@@ -341,6 +307,3 @@ else:
             0,
             None,
         )
-
-
-SUPPORTS_WINDOWING = True

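
With the unconditional SUPPORTS_WINDOWING = True dropped from the v1 fallback branch, the flag has to track which kernel actually loaded. The hunk does not show where the commit defines it, so the lines below are only a plausible sketch of the intent, with V2 standing in for the result of the import check above.

    V2 = False  # stands in for the flash attn v2 import/version check above

    # Plausible sketch only: sliding-window attention is a v2 feature, so the
    # capability flag should mirror V2 rather than being hard-coded to True.
    SUPPORTS_WINDOWING = V2
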
@@ -484,10 +484,14 @@ def get_model(
         )
     sliding_window = config_dict.get("sliding_window", -1)
 
+    is_max_input_within_sliding_window = (
+        max_input_tokens <= sliding_window if max_input_tokens is not None else False
+    )
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
-        and max_input_tokens > sliding_window
+        and is_max_input_within_sliding_window
     ):
         raise ValueError(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."

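
Read together with the error message, the intent of the guard in get_model is: a model trained with a sliding window cannot be served on a backend without window support unless every request fits inside that window. A self-contained rehearsal with hypothetical numbers follows; the function and variable names are illustrative, not the repository's, and the comparison follows the intent stated in the error message.

    def check_sliding_window(sliding_window, max_input_tokens, supports_windowing):
        """Rehearsal of the guard's intent: refuse to serve a sliding-window
        model on a backend without window support unless every request fits
        inside the window."""
        uses_window = sliding_window is not None and sliding_window != -1
        if not uses_window or supports_windowing:
            return  # no window configured, or the kernel can enforce it
        if max_input_tokens is None or max_input_tokens > sliding_window:
            raise ValueError(
                f"Backend lacks sliding-window support; launch with "
                f"--max-input-tokens <= {sliding_window} (got {max_input_tokens})."
            )


    check_sliding_window(4096, 2048, supports_windowing=False)  # ok: inputs fit in the window
    check_sliding_window(4096, 8192, supports_windowing=False)  # raises ValueError
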
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,

@@ -110,9 +111,7 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):

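
The per-layer change, isolated: max_past is only kept when the attention backend can enforce it, so downstream code can simply test truthiness. A reduced sketch with a tiny stand-in config (the real MistralAttention takes more arguments than shown):

    from dataclasses import dataclass

    import torch

    SUPPORTS_WINDOWING = False  # assume the v1 fallback path for this sketch


    @dataclass
    class TinyConfig:
        sliding_window: int = 4096


    class WindowedAttentionStub(torch.nn.Module):
        """Reduced stand-in for the modeling change: only remember the sliding
        window when the attention kernel can actually apply it."""

        def __init__(self, config: TinyConfig):
            super().__init__()
            # Before: config.sliding_window if it is not None else -1.
            # After: None unless windowing is supported, so later code can
            # just test `if self.max_past:`.
            self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None


    layer = WindowedAttentionStub(TinyConfig())
    print(layer.max_past)  # None on backends without window support
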
@@ -487,10 +486,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 

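
For illustration, this is how a max_past_tensor is typically consumed downstream: per-sequence lengths are clamped to the window so the paged-attention kernel never looks further back than the model can attend. The clamp below is a hedged sketch of that usual pattern with made-up numbers, not a quote of the Mistral modeling file.

    import torch

    # Hypothetical values for illustration only.
    sliding_window = 4096
    max_past_tensor = torch.tensor(sliding_window)         # built in __init__ when max_past is set
    input_lengths = torch.tensor([137, 5000, 4096, 9001])  # current lengths of 4 sequences

    # Cap the visible history at the sliding window before calling the kernel.
    clamped = torch.clamp(input_lengths, max=max_past_tensor)
    print(clamped)  # tensor([ 137, 4096, 4096, 4096])

The same max_past / max_past_tensor change is repeated below for Mixtral, Qwen2, and Starcoder2.
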
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,

@@ -195,9 +196,7 @@ class MixtralAttention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

@@ -615,10 +614,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 

@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,

@@ -56,9 +57,7 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

@@ -349,10 +348,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 

@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,

@@ -161,9 +162,7 @@ class Starcoder2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

@@ -509,10 +508,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 