Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)

fix: prefer version check over test op and avoid window_size_left if not flash attn2
This commit is contained in:
parent 5123925101
commit cae28dcbf1

@@ -3,7 +3,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
-import warnings
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -172,43 +171,11 @@ def paged_attention(
 
 
 try:
+    if major <= 8:
+        raise ImportError("Flash Attention V2 requires CUDA 11.0 or higher")
+
     import flash_attn_2_cuda
 
-    # try forwarding to see if it works with all dummy inputs
-    batch_size = 1
-    num_heads = 1
-    head_dim = 1
-    seqlen = 1
-
-    try:
-        flash_attn_2_cuda.varlen_fwd(
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # q
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # k
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # v
-            None,  # out (optional)
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_q
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_k
-            None,  # alibi_slopes (optional)
-            None,  # q_padded (optional)
-            None,  # k_padded (optional)
-            None,  # v_padded (optional)
-            seqlen,  # max_seqlen_q
-            seqlen,  # max_seqlen_k
-            1.0,  # softmax_scale
-            0.0,  # softmax_lse (default value)
-            False,  # is_causal
-            True,  # return_softmax
-            -1,  # window_size_left
-            -1,  # window_size_right
-            0.0,  # softmax_softcap
-            False,  # deterministic
-            None,  # rng_state (optional)
-        )
-    except RuntimeError as e:
-        raise ImportError(
-            "Flash Attention V2 is not supported on this machine. " f"Error: {e}"
-        ) from e
-
     V2 = True
 except ImportError:
     try:
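
A minimal sketch (not part of the diff) of the detection flow this hunk switches to: gate on the reported compute capability first, then import the kernel, instead of probing `varlen_fwd` with dummy tensors. The `major <= 8` threshold and the error text are taken verbatim from the hunk above; the `torch.cuda.is_available()` guard is my addition so the snippet also runs on CPU-only machines.

    import torch

    V2 = False
    if torch.cuda.is_available():
        try:
            major, _minor = torch.cuda.get_device_capability()
            if major <= 8:  # threshold as written in the hunk above
                raise ImportError(
                    "Flash Attention V2 requires CUDA 11.0 or higher"
                )
            import flash_attn_2_cuda  # noqa: F401  # only present when flash-attn v2 is installed

            V2 = True
        except ImportError:
            # fall back to the v1 / non-windowing path, as in the module's else branch
            V2 = False
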
@@ -289,10 +256,9 @@ else:
         window_size_left=-1,
         softcap=None,
     ):
-        if window_size_left != -1:
-            warnings.warn(
-                "window_size_left is only available with flash attn v2. It will be ignored.",
-                UserWarning,
+        if window_size_left is not None and window_size_left != -1:
+            raise NotImplementedError(
+                "window_size_left is only available with flash attn v2"
             )
 
         if softcap is not None:
@@ -341,6 +307,3 @@ else:
             0,
             None,
         )
-
-
-SUPPORTS_WINDOWING = True

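The two hunks above change the fallback (non flash-attn v2) path: a real window now fails loudly instead of being warned away, and the unconditional `SUPPORTS_WINDOWING = True` assignment is dropped. A hedged sketch of what callers are expected to do; the `attention` stub and the `SUPPORTS_WINDOWING = False` value on this path are my assumptions, since the diff does not show where the flag is set afterwards.

    SUPPORTS_WINDOWING = False  # assumed value on the fallback path

    def attention(q, k, v, window_size_left=-1, softcap=None):
        # guard copied from the hunk above; the body is elided in this stub
        if window_size_left is not None and window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

    # callers gate the window on SUPPORTS_WINDOWING (see the model hunks below)
    sliding_window = 4096
    window = sliding_window if SUPPORTS_WINDOWING else None
    attention(None, None, None, window_size_left=window)  # None (or -1) does not raise
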
@@ -484,10 +484,14 @@ def get_model(
         )
     sliding_window = config_dict.get("sliding_window", -1)
 
+    is_max_input_within_sliding_window = (
+        max_input_tokens <= sliding_window if max_input_tokens is not None else False
+    )
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
-        and max_input_tokens > sliding_window
+        and is_max_input_within_sliding_window
     ):
         raise ValueError(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."

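A worked example of the guard added here, with illustrative values of my own (not taken from the diff):

    sliding_window, max_input_tokens, SUPPORTS_WINDOWING = 4096, 2048, False

    is_max_input_within_sliding_window = (
        max_input_tokens <= sliding_window if max_input_tokens is not None else False
    )
    should_raise = (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and is_max_input_within_sliding_window
    )
    assert should_raise  # 2048 <= 4096 on a backend without windowing support
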
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -110,9 +111,7 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):
@@ -487,10 +486,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 

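The same `SUPPORTS_WINDOWING`-gated pattern is applied in this hunk and in the Mixtral, Qwen2 and Starcoder2 hunks that follow. A small self-contained sketch; the `SimpleNamespace` config and the CPU device are stand-ins of mine, not the real model objects:

    from types import SimpleNamespace

    import torch

    SUPPORTS_WINDOWING = False  # e.g. running without flash-attn v2
    config = SimpleNamespace(sliding_window=4096)
    device = torch.device("cpu")  # stand-in for weights.device

    max_past = config.sliding_window if SUPPORTS_WINDOWING else None
    max_past_tensor = torch.tensor(max_past, device=device) if max_past else None
    # With windowing unsupported, both stay None, so no window is forwarded to attention.
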
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -195,9 +196,7 @@ class MixtralAttention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -615,10 +614,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 

@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -56,9 +57,7 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -349,10 +348,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 

@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -161,9 +162,7 @@ class Starcoder2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -509,10 +508,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 