From fc41f0784a667588d8ae4ae3d21b055076e8f13c Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 15 Oct 2024 18:46:56 +0200
Subject: [PATCH] lint fix.

---
 .../layers/attention/common.py               | 52 -------------------
 .../layers/attention/cuda.py                 |  3 ++
 .../models/flash_causal_lm.py                |  6 +--
 3 files changed, 6 insertions(+), 55 deletions(-)

diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
index cb7c8465..a3b919ee 100644
--- a/server/text_generation_server/layers/attention/common.py
+++ b/server/text_generation_server/layers/attention/common.py
@@ -1,6 +1,4 @@
 from dataclasses import dataclass
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 import torch
 from typing import Optional
 
@@ -52,53 +50,3 @@ class Seqlen:
     def clamp(self, max):
         # Flash decoding doesn't need to clamp
         return self
-=======
-if ATTENTION in {"flashinfer", "flashdecoding"}:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
-
-        def __init__(
-            self,
-            input_lengths,
-            prefix_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
\ No newline at end of file
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 929445d4..265a8ae4 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -243,6 +243,7 @@ if ATTENTION == "flashinfer":
             sm_scale=softmax_scale,
             window_left=window_size_left,
         )
+
 elif ATTENTION == "flashdecoding":
     if V2:
 
@@ -351,6 +352,7 @@ elif ATTENTION == "flashdecoding":
             None,
         )
         return out
+
 elif ATTENTION == "paged":
     if V2:
 
@@ -459,6 +461,7 @@ elif ATTENTION == "paged":
             None,
         )
         return out
+
 else:
     raise RuntimeError(f"Unknwon attention {ATTENTION}")
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 1e0e9176..c9b7decd 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1679,9 +1679,9 @@ class FlashCausalLM(Model):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
         cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
-            cache_lengths_tensor
-        )
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],