From 4fa4da3cb6eebb9ea9afa46676ef5b28db971387 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 15 Oct 2024 16:12:00 +0200
Subject: [PATCH] Fixing non blocked attentions

---
 .../layers/attention/common.py               | 104 +++---
 .../layers/attention/cuda.py                 | 310 ++++++++++++------
 .../models/flash_causal_lm.py                |   8 +-
 .../text_generation_server/models/globals.py |   8 +-
 4 files changed, 265 insertions(+), 165 deletions(-)

diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
index 8f9d93a1..cb051265 100644
--- a/server/text_generation_server/layers/attention/common.py
+++ b/server/text_generation_server/layers/attention/common.py
@@ -5,68 +5,50 @@ import torch
 from typing import Optional
 
-if ATTENTION in {"flashinfer", "flashdecoding"}:
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cache_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+    max_q: int
+    max_k: int
 
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        cache_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
+    def __init__(
+        self,
+        input_lengths,
+        cache_lengths,
+        cu_seqlen_q=None,
+        max_q=None,
+        max_k=None,
+    ):
+        self.input_lengths = input_lengths
+        self.cache_lengths = cache_lengths
+        device = self.input_lengths.device
+        shape = self.input_lengths.shape
+        if cu_seqlen_q is None:
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            max_q = 1
+        else:
+            assert max_q is not None
+            assert max_k is not None
+        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
 
-        def __init__(
-            self,
-            input_lengths,
-            cache_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.cache_lengths = cache_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+        # cuda graphs don't like this and this is necessary to clamp within mistral
+        # Although FA2 might not want the clamping
+        # cu_seqlen_k[0] = 0
+        total = self.input_lengths + self.cache_lengths
+        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
 
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.cache_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
+        self.cu_seqlen_q = cu_seqlen_q
+        self.cu_seqlen_k = cu_seqlen_k
+        self.max_q = max_q
+        self.max_k = max_k
 
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
-
-else:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        cache_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            raise NotImplementedError("Not implemented seqlen for paged")
-            return Seqlen(torch.clamp(self.input_lengths, max=max))
+    def clamp(self, max):
+        # Flash decoding doesn't need to clamp
+        return self
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index cd3ea369..929445d4 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -123,7 +123,7 @@ def paged_attention(
     else:
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
-        input_lengths = seqlen.input_lengths
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
         from vllm._C import ops
 
         out = torch.empty_like(query)
@@ -243,118 +243,230 @@ if ATTENTION == "flashinfer":
             sm_scale=softmax_scale,
             window_left=window_size_left,
         )
+elif ATTENTION == "flashdecoding":
+    if V2:
 
-elif V2:
-
-    def attention(
-        q,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale,
-        window_size_left=-1,
-        causal=True,
-        softcap=0.0,
-    ):
-        out = torch.empty_like(q)
-        if window_size_left <= 0 and window_size_left != -1:
-            raise ValueError("`window_size_left` must be > 0 or -1")
-        return flash_attn_2_cuda.varlen_fwd(
+        def attention(
             q,
-            key_cache,
-            value_cache,
-            out,
-            seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_k,
-            None,
-            None,
-            block_tables,
-            None,
-            seqlen.max_q,
-            seqlen.max_k,
-            0.0,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
             softmax_scale,
-            False,
-            causal,
-            window_size_left,
-            0,
-            softcap,
-            False,
-            None,
-        )[0]
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]
 
-else:
+    else:
 
-    def attention(
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale: float,
-        window_size_left: int = -1,
-        causal: bool = True,
-        softcap=None,
-    ):
-        if window_size_left != -1:
-            raise NotImplementedError(
-                "window_size_left is only available with flash attn v2"
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
             )
-        if softcap is not None:
-            raise NotImplementedError("softcap is only available with flash attn v2")
+            return out
+elif ATTENTION == "paged":
+    if V2:
 
-        # Flash attention v1 requires q, k and v to have the same number of heads
-        if k.shape[1] != q.shape[1]:
-            # MQA expand
-            if k.shape[1] == 1:
-                k = k.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = k.shape
-                k = (
-                    k.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-        if v.shape[1] != q.shape[1]:
-            # MQA expand
-            if v.shape[1] == 1:
-                v = v.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = v.shape
-                v = (
-                    v.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-
-        out = torch.empty_like(q)
-        flash_attn_cuda.fwd(
+        def attention(
             q,
-            k,
-            v,
-            out,
-            seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_q,
-            seqlen.max_q,
-            seqlen.max_k,
-            0.0,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
             softmax_scale,
-            False,
-            causal,
-            False,
-            0,
-            None,
-        )
-        return out
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]
+
+    else:
+
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
+            )
+            return out
+else:
+    raise RuntimeError(f"Unknown attention {ATTENTION}")
 
 # Prefill in the cache with every kind of attention, unless we
 # have a configuration that requires flash-attention v1, which
 # does not support block tables.
-PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
 
 __all__ = [
     "PREFILL_IN_KV_CACHE",
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8222722a..7acc723a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1602,6 +1602,8 @@ class FlashCausalLM(Model):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
+        print(slots)
+
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
@@ -1677,9 +1679,9 @@ class FlashCausalLM(Model):
             cuda_graph["input_lengths"].zero_()
             cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
             cuda_graph["cache_lengths"].zero_()
-            cuda_graph["cache_lengths"][
-                : cache_lengths_tensor.shape[0]
-            ] = cache_lengths_tensor
+            cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+                cache_lengths_tensor
+            )
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 6bf8d3ff..0b60549a 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -5,9 +5,13 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
-log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer")
+default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {
+    "1",
+    "true",
+}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
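
Below is a quick illustrative sketch, not part of the patch, of how the unified `Seqlen` from `common.py` fills in its cumulative-length tensors on the decode path, where `cu_seqlen_q` is left as `None`. The lengths are made-up example values, and the import assumes the patched `text_generation_server` package is importable.

import torch

from text_generation_server.layers.attention.common import Seqlen

# Example decode batch of three sequences: one new token each,
# with some tokens already sitting in the KV cache.
input_lengths = torch.tensor([1, 1, 1], dtype=torch.int32)
cache_lengths = torch.tensor([7, 12, 3], dtype=torch.int32)

seqlen = Seqlen(input_lengths=input_lengths, cache_lengths=cache_lengths)

# cu_seqlen_q falls back to arange(batch_size + 1) and max_q to 1.
print(seqlen.cu_seqlen_q)  # tensor([0, 1, 2, 3], dtype=torch.int32)

# cu_seqlen_k is the zero-prefixed cumulative sum of
# input_lengths + cache_lengths: totals [8, 13, 4] -> [0, 8, 21, 25].
print(seqlen.cu_seqlen_k)  # tensor([ 0,  8, 21, 25], dtype=torch.int32)
print(seqlen.max_q)        # 1

For prefill, callers are expected to pass `cu_seqlen_q`, `max_q` and `max_k` explicitly, which is what the asserts in `__init__` enforce.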
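
A second sketch, also not part of the patch, showing how the reworked flags resolve per backend: `PREFILL_IN_KV_CACHE` in `cuda.py` and the new prefix-caching default in `globals.py`. The helper names are made up; the boolean expressions mirror the ones the patch introduces, with `v2` standing for "flash-attention v2 kernels are available".

def prefill_in_kv_cache(attention: str, v2: bool) -> bool:
    # Mirrors: ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
    return attention == "flashinfer" or (attention == "flashdecoding" and v2)


def default_prefix_caching(attention: str) -> str:
    # Mirrors: "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
    return "1" if attention in {"flashinfer", "flashdecoding"} else "0"


for attention in ("flashinfer", "flashdecoding", "paged"):
    for v2 in (True, False):
        print(attention, v2, prefill_in_kv_cache(attention, v2), default_prefix_caching(attention))

Only flashinfer, and flashdecoding when flash-attention v2 is available, prefill directly in the KV cache; paged attention and the flash-attention v1 fallback do not.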