fix: prefer version check over test op and avoid window_size_left if not flash attn2

author drbh 2024-08-01 15:06:09 +00:00
parent 5123925101
commit cae28dcbf1
6 changed files with 31 additions and 68 deletions
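
Roughly, the commit replaces a runtime probe (calling flash_attn_2_cuda.varlen_fwd on dummy tensors and treating any failure as "unsupported") with a cheap check performed before the import, and it stops pretending that the flash attention v1 fallback supports sliding windows. A minimal sketch of the new detection flow, assuming a CUDA device is visible; everything except the flash_attn_2_cuda module name and the capability check is illustrative:

import torch

major, minor = torch.cuda.get_device_capability()

try:
    # Decide up front from the device capability instead of running a dummy
    # forward pass: the Flash Attention V2 kernels target Ampere (sm80+) GPUs.
    if major < 8:
        raise ImportError("Flash Attention V2 requires compute capability >= 8.0")
    import flash_attn_2_cuda  # noqa: F401

    V2 = True
except ImportError:
    # Fall back to the v1 kernels (or raise a clearer error) rather than
    # inferring support from whether a throwaway varlen_fwd call crashes.
    V2 = False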

(changed file 1 of 6)

@@ -3,7 +3,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
-import warnings

 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -172,43 +171,11 @@ def paged_attention(
 try:
+    if major < 8:
+        raise ImportError("Flash Attention V2 requires Ampere GPUs or newer")
     import flash_attn_2_cuda
-    # try forwarding to see if it works with all dummy inputs
-    batch_size = 1
-    num_heads = 1
-    head_dim = 1
-    seqlen = 1
-    try:
-        flash_attn_2_cuda.varlen_fwd(
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # q
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # k
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # v
-            None,  # out (optional)
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_q
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_k
-            None,  # alibi_slopes (optional)
-            None,  # q_padded (optional)
-            None,  # k_padded (optional)
-            None,  # v_padded (optional)
-            seqlen,  # max_seqlen_q
-            seqlen,  # max_seqlen_k
-            1.0,  # softmax_scale
-            0.0,  # softmax_lse (default value)
-            False,  # is_causal
-            True,  # return_softmax
-            -1,  # window_size_left
-            -1,  # window_size_right
-            0.0,  # softmax_softcap
-            False,  # deterministic
-            None,  # rng_state (optional)
-        )
-    except RuntimeError as e:
-        raise ImportError(
-            "Flash Attention V2 is not supported on this machine. " f"Error: {e}"
-        ) from e
     V2 = True
 except ImportError:
     try:
@@ -289,10 +256,9 @@ else:
         window_size_left=-1,
         softcap=None,
     ):
-        if window_size_left != -1:
-            warnings.warn(
-                "window_size_left is only available with flash attn v2. It will be ignored.",
-                UserWarning,
+        if window_size_left is not None and window_size_left != -1:
+            raise NotImplementedError(
+                "window_size_left is only available with flash attn v2"
             )

         if softcap is not None:
@@ -341,6 +307,3 @@ else:
             0,
             None,
         )
-
-
-    SUPPORTS_WINDOWING = True
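
Since the v1 fallback no longer advertises SUPPORTS_WINDOWING and now rejects a real window instead of warning and ignoring it, the guard's behavior is easy to pin down in isolation. A standalone restatement of just that guard, for illustration only (the helper name is made up, not part of the module):

def check_window(window_size_left):
    # Mirrors the v1 guard above: None and -1 both mean "no sliding window";
    # any other value is rejected because the flash attn v1 kernels cannot
    # honor it, rather than being silently ignored as before.
    if window_size_left is not None and window_size_left != -1:
        raise NotImplementedError(
            "window_size_left is only available with flash attn v2"
        )

check_window(None)  # accepted: windowing disabled
check_window(-1)    # accepted: windowing disabled
# check_window(4096) would raise NotImplementedError on the v1 path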

(changed file 2 of 6)

@@ -484,10 +484,14 @@ def get_model(
         )
     sliding_window = config_dict.get("sliding_window", -1)
+    is_max_input_within_sliding_window = (
+        max_input_tokens <= sliding_window if max_input_tokens is not None else False
+    )
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
-        and max_input_tokens > sliding_window
+        and not is_max_input_within_sliding_window
     ):
         raise ValueError(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )

(changed file 3 of 6)

@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -110,9 +111,7 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):
@@ -487,10 +486,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
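
The max_past / max_past_tensor pair is what lets the model cap sequence lengths during decode when a sliding window is actually in effect; with windowing unsupported both stay None and nothing is capped. A hedged sketch of that downstream use (the helper and tensor names here are illustrative, not a quote of the model's forward pass):

import torch

def cap_input_lengths(input_lengths: torch.Tensor, max_past_tensor):
    # With no active sliding window (max_past_tensor is None) the lengths pass
    # through unchanged; otherwise the KV cache only retains the most recent
    # max_past positions, so the lengths handed to paged attention are capped.
    if max_past_tensor is None:
        return input_lengths
    return torch.clamp(input_lengths, max=max_past_tensor)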

(changed file 4 of 6)

@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -195,9 +196,7 @@ class MixtralAttention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -615,10 +614,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )

(changed file 5 of 6)

@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -56,9 +57,7 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -349,10 +348,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
        )

(changed file 6 of 6)

@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -161,9 +162,7 @@ class Starcoder2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -509,10 +508,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )