From 265c76d328edc537db601d2f6b83cc402241f233 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 16 May 2024 14:46:47 +0000
Subject: [PATCH] black

---
 .github/workflows/build.yaml                  |  2 +-
 docs/source/installation_amd.md               |  2 +-
 .../exllama_kernels/hip_compat.cuh            |  2 +-
 .../text_generation_server/layers/linear.py   |  1 +
 server/text_generation_server/models/bloom.py |  2 +-
 .../custom_modeling/flash_llama_modeling.py   | 10 ++-
 .../custom_modeling/flash_mistral_modeling.py | 10 ++-
 .../custom_modeling/idefics_modeling.py       |  1 +
 .../models/flash_causal_lm.py                 | 34 +++++---
 .../models/flash_cohere.py                    |  2 +-
 .../models/flash_dbrx.py                      |  2 +-
 .../models/flash_gemma.py                     |  2 +-
 .../models/flash_llama.py                     |  2 +-
 .../models/flash_mistral.py                   |  2 +-
 .../models/flash_neox.py                      |  2 +-
 .../models/flash_phi.py                       |  2 +-
 .../models/flash_qwen2.py                     |  2 +-
 .../text_generation_server/models/flash_rw.py |  2 +-
 .../models/flash_santacoder.py                |  2 +-
 .../models/flash_starcoder2.py                |  2 +-
 .../models/galactica.py                       |  2 +-
 .../text_generation_server/models/gpt_neox.py |  2 +-
 .../text_generation_server/models/idefics.py  |  2 +-
 server/text_generation_server/models/mamba.py |  2 +-
 server/text_generation_server/models/mpt.py   |  2 +-
 server/text_generation_server/models/opt.py   |  2 +-
 server/text_generation_server/models/phi.py   |  2 +-
 .../models/santacoder.py                      |  2 +-
 server/text_generation_server/models/t5.py    |  2 +-
 server/text_generation_server/server.py       |  1 +
 .../utils/flash_attn.py                       | 14 +++-
 .../utils/flash_attn_triton.py                | 84 +++++++++----------
 .../utils/paged_attention.py                  |  4 +-
 33 files changed, 123 insertions(+), 84 deletions(-)

diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 747a8cc6..cc9dea9a 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -425,4 +425,4 @@ jobs:
       - name: Run tests
         run: |
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
\ No newline at end of file
+          pytest -s -vv integration-tests
diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md
index 279b1e6e..5111bab5 100644
--- a/docs/source/installation_amd.md
+++ b/docs/source/installation_amd.md
@@ -35,4 +35,4 @@ By default, as its performances have experimentally been better, Triton implemen
 The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
 * Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
-* Kernel for sliding window attention (Mistral)
\ No newline at end of file
+* Kernel for sliding window attention (Mistral)
diff --git a/server/exllama_kernels/exllama_kernels/hip_compat.cuh b/server/exllama_kernels/exllama_kernels/hip_compat.cuh
index d8cbcc49..f2a3dcad 100644
--- a/server/exllama_kernels/exllama_kernels/hip_compat.cuh
+++ b/server/exllama_kernels/exllama_kernels/hip_compat.cuh
@@ -11,7 +11,7 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
 
 __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
     return _Float16_2{
-        _Float16_2{static_cast<_Float16>(1.0f),
+        _Float16_2{static_cast<_Float16>(1.0f),
                    static_cast<_Float16>(1.0f)} / x.data};
 }
diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index 27ec9aca..987b6a7b 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -8,6 +8,7 @@ if SYSTEM == "rocm":
     except Exception as e:
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
+
 class FastLinear(torch.nn.Module):
     def __init__(
         self,
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index b3907768..6e7f2f1e 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -47,7 +47,7 @@ class BLOOMSharded(CausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index df8d1f52..47758d30 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -197,7 +197,9 @@ class LlamaMLP(nn.Module):
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh"
+                    if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                    else "none"
                 ),
             )
         )
@@ -229,7 +231,11 @@ class LlamaMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if (
+            SYSTEM == "rocm"
+            and self.hidden_act == "silu"
+            and hidden_states.shape[0] == 1
+        ):
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 1532757f..21edc79e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -266,7 +266,9 @@ class MistralMLP(nn.Module):
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh"
+                    if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                    else "none"
                 ),
             )
         )
@@ -289,7 +291,11 @@ class MistralMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if (
+            SYSTEM == "rocm"
+            and self.hidden_act == "silu"
+            and hidden_states.shape[0] == 1
+        ):
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 3b06f737..d0c84308 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -64,6 +64,7 @@ elif SYSTEM == "rocm":
 else:
     raise RuntimeError(f"Unsupported system {SYSTEM}")
 
+
 @dataclass
 class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
     image_hidden_states: Optional[torch.FloatTensor] = None
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 14f2d8c5..92d8aa5c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -831,28 +831,40 @@ class FlashCausalLM(Model):
             torch.cuda.tunable.tuning_enable(True)
 
             if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
-                tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
+                tuning_sequences = [
+                    int(val)
+                    for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
+                ]
             else:
                 tuning_sequences = [1, 2, 4, 8, 16, 32]
-
-            tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
-            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
+
+            tunableop_filepath = os.path.join(
+                "/data",
+                f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+            )
+
+            logger.info(
+                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}."
+            )
 
             if os.path.isfile(tunableop_filepath):
-                logger.info(f"The file {tunableop_filepath} already exists and will be reused.")
+                logger.info(
+                    f"The file {tunableop_filepath} already exists and will be reused."
+                )
                 torch.cuda.tunable.read_file(tunableop_filepath)
-
+
             os.makedirs("/data", exist_ok=True)
-
+
             for seqlen in tuning_sequences:
                 logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
             torch.cuda.tunable.write_file(tunableop_filepath)
             torch.cuda.tunable.tuning_enable(False)
         else:
-            logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.")
-
+            logger.info(
+                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+            )
+
         if CUDA_GRAPHS:
             try:
                 logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -877,7 +889,9 @@ class FlashCausalLM(Model):
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
+            cu_seqlen_prefill=torch.tensor(
+                [0, seqlen], device=self.device, dtype=torch.int32
+            ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
             input_lengths=None,
diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index 5da84eed..d955fca7 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -29,7 +29,7 @@ class FlashCohere(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py
index ed0cdf3d..60c032ad 100644
--- a/server/text_generation_server/models/flash_dbrx.py
+++ b/server/text_generation_server/models/flash_dbrx.py
@@ -31,7 +31,7 @@ class FlashDbrx(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index db38f50f..8783a259 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -29,7 +29,7 @@ class FlashGemma(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 10c443de..5c552379 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index c5e92116..018ddd10 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -319,7 +319,7 @@ class BaseFlashMistral(FlashCausalLM):
         tokenizer_class=AutoTokenizer,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 17de8114..eda1d658 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -30,7 +30,7 @@ class FlashNeoXSharded(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index 8111d142..9dd576bc 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -30,7 +30,7 @@ class FlashPhi(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py
index 16d31e04..482d5fd6 100644
--- a/server/text_generation_server/models/flash_qwen2.py
+++ b/server/text_generation_server/models/flash_qwen2.py
@@ -35,7 +35,7 @@ class FlashQwen2(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 57fee67a..68c01e5a 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -31,7 +31,7 @@ class FlashRWSharded(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 134925d8..9b48a5e6 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -34,7 +34,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py
index 91055d79..d498725a 100644
--- a/server/text_generation_server/models/flash_starcoder2.py
+++ b/server/text_generation_server/models/flash_starcoder2.py
@@ -34,7 +34,7 @@ class FlashStarcoder2(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 16ad1ae5..c164703c 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -172,7 +172,7 @@ class GalacticaSharded(CausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index ed3eb61c..83007609 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -29,7 +29,7 @@ class GPTNeoxSharded(CausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index f4ffdafb..a3000cfa 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -36,7 +36,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         trust_remote_code: bool = False,
    ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 80e56f2a..6a681aa1 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -413,7 +413,7 @@ class Mamba(Model):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, _rank, world_size = initialize_torch_distributed()
         if world_size > 1:
             raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 753b3ba9..4525c417 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -48,7 +48,7 @@ class MPTSharded(CausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 62ace3f9..f706c7dc 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -27,7 +27,7 @@ class OPTSharded(CausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py
index cea9165d..dff76084 100644
--- a/server/text_generation_server/models/phi.py
+++ b/server/text_generation_server/models/phi.py
@@ -27,7 +27,7 @@ class Phi(CausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, _rank, _world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device("cuda")
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 17251113..e8ca2235 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -24,7 +24,7 @@ class SantaCoder(CausalLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 568a08b8..dc49d7bd 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -30,7 +30,7 @@ class T5Sharded(Seq2SeqLM):
         trust_remote_code: bool = False,
     ):
         self.model_id = model_id
-
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index ec5aa51f..92126fe6 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -187,6 +187,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             total_ns=time.time_ns() - start,
         )
 
+
 def serve(
     model_id: str,
     revision: Optional[str],
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 94705c41..c5fd7830 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -64,12 +64,17 @@ if SYSTEM in {"cuda", "rocm"}:
     is_sm94 = major == 9 and minor == 4
 
     if SYSTEM == "rocm":
-        if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1":
+        if (
+            os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
+            or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
+        ):
             ROCM_USE_FLASH_ATTN_V2_TRITON = True
             logger.info("ROCm: using Flash Attention 2 Triton implementation.")
         else:
             ROCM_USE_FLASH_ATTN_V2_CK = True
-            logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+            logger.info(
+                "ROCm: using Flash Attention 2 Composable Kernel implementation."
+            )
 
     try:
         try:
@@ -158,6 +163,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
         )
 
 elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
+
     def attention(
         q,
         k,
@@ -192,8 +198,9 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
             False,
             None,
         )
+
 elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-
+
     def attention(
         q,
         k,
@@ -217,6 +224,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
             softmax_scale,
         )
         return output
+
 elif HAS_FLASH_ATTN:
 
     def attention(
diff --git a/server/text_generation_server/utils/flash_attn_triton.py b/server/text_generation_server/utils/flash_attn_triton.py
index 9167b1f4..3fe32231 100644
--- a/server/text_generation_server/utils/flash_attn_triton.py
+++ b/server/text_generation_server/utils/flash_attn_triton.py
@@ -46,16 +46,16 @@ def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
 
 @triton.jit
 def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
-    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
-                                  stride).to(tl.uint32)
+    rng_offsets = dropout_offsets(
+        philox_seed, philox_offset, dropout_p, m, n, stride
+    ).to(tl.uint32)
     # TODO: use tl.randint for better performance
     return tl.rand(philox_seed, rng_offsets)
 
 
 @triton.jit
 def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
-    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
-                             stride)
+    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
     rng_keep = rng_output > dropout_p
     return rng_keep
 
@@ -65,9 +65,9 @@ def load_fn(block_ptr, first, second, pad):
     if first and second:
         tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
     elif first:
-        tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
+        tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
     elif second:
-        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
+        tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
     else:
         tensor = tl.load(block_ptr)
     return tensor
@@ -133,9 +133,7 @@ def _attn_fwd_inner(
             # if not is_modulo_mn. last step might get wasted but that is okay.
            # check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
-                boundary_m = tl.full([BLOCK_M],
-                                     actual_seqlen_k,
-                                     dtype=tl.int32)
+                boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
                 size_n = start_n + OFFS_N[None, :]
                 mask = size_n < boundary_m[:, None]
                 qk = tl.where(mask, qk, float("-inf"))
@@ -146,8 +144,9 @@ def _attn_fwd_inner(
         # -- compute qk ----
         qk += tl.dot(q, k)
         if bias_ptr is not None:
-            bias = load_fn(bias_ptr, False, MASK_STEPS
-                           and (n_extra_tokens != 0), "zero")
+            bias = load_fn(
+                bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
+            )
             # While bias is added after multiplying qk with sm_scale, our
             # optimization to use 2^x instead of e^x results in an additional
             # scale factor of log2(e) which we must also multiply the bias with.
@@ -159,9 +158,12 @@ def _attn_fwd_inner(
         # CAVEAT: Must update l_ij before applying dropout
         l_ij = tl.sum(p, 1)
         if ENABLE_DROPOUT:
-            philox_offset = (batch_philox_offset +
-                             start_m * BLOCK_M * actual_seqlen_k + start_n -
-                             BLOCK_N)
+            philox_offset = (
+                batch_philox_offset
+                + start_m * BLOCK_M * actual_seqlen_k
+                + start_n
+                - BLOCK_N
+            )
             keep = dropout_mask(
                 philox_seed,
                 philox_offset,
@@ -173,8 +175,7 @@ def _attn_fwd_inner(
             if RETURN_ENCODED_SOFTMAX:
                 tl.store(
                     encoded_softmax_block_ptr,
-                    tl.where(keep, p,
-                             -p).to(encoded_softmax_block_ptr.type.element_ty),
+                    tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
                 )
             p = tl.where(keep, p, 0.0)
         elif RETURN_ENCODED_SOFTMAX:
@@ -202,8 +203,9 @@ def _attn_fwd_inner(
         if bias_ptr is not None:
             bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
         if RETURN_ENCODED_SOFTMAX:
-            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
-                                                   (0, BLOCK_N))
+            encoded_softmax_block_ptr = tl.advance(
+                encoded_softmax_block_ptr, (0, BLOCK_N)
+            )
     return acc, l_i, m_i
 
 
@@ -341,7 +343,7 @@ def attn_fwd(
     philox_offset_base,
    encoded_softmax,
     HQ: tl.constexpr,
-    HK:tl.constexpr,
+    HK: tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
     MAX_SEQLENS_Q: tl.constexpr,
     MAX_SEQLENS_K: tl.constexpr,
@@ -392,15 +394,17 @@ def attn_fwd(
         # This captures the decrease in n_blocks if we have a rectangular attn
         # matrix
         n_blocks_seqlen = cdiv_fn(
-            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
+            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
+        )
         # This is what adjusts the block_max for the current WG, only
         # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
         n_blocks = min(n_blocks, n_blocks_seqlen)
         # If we have no blocks after adjusting for seqlen deltas, this WG is
         # part of the blocks that are all 0. We exit early.
         if n_blocks <= 0:
-            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
-                        off_h_q * stride_oh)
+            o_offset = (
+                off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
+            )
             O_block_ptr = tl.make_block_ptr(
                 base=Out + o_offset,
                 shape=(seqlen_q, BLOCK_DMODEL),
@@ -436,11 +440,10 @@ def attn_fwd(
         n_extra_tokens = BLOCK_N - seqlen_k
     elif seqlen_k % BLOCK_N:
         n_extra_tokens = seqlen_k % BLOCK_N
-    PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
+    PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
 
     # Compute pointers for all the tensors used in this kernel.
-    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
-                cu_seqlens_q_start * stride_qm)
+    q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
     Q_block_ptr = tl.make_block_ptr(
         base=Q + q_offset,
         shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
@@ -449,8 +452,7 @@ def attn_fwd(
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0),
     )
-    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
-                cu_seqlens_k_start * stride_kn)
+    k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
     K_block_ptr = tl.make_block_ptr(
         base=K + k_offset,
         shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
@@ -459,8 +461,7 @@ def attn_fwd(
         block_shape=(BLOCK_DMODEL, BLOCK_N),
         order=(0, 1),
     )
-    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
-                cu_seqlens_k_start * stride_vk)
+    v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
     V_block_ptr = tl.make_block_ptr(
         base=V + v_offset,
         shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
@@ -481,9 +482,9 @@ def attn_fwd(
     else:
         bias_ptr = None
     if ENABLE_DROPOUT:
-        batch_philox_offset = philox_offset_base \
-            + (off_z * HQ + off_h_q) \
-            * seqlen_q * seqlen_k
+        batch_philox_offset = (
+            philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
+        )
     else:
         batch_philox_offset = 0
     # We can ask to return the dropout mask without actually doing any dropout.
@@ -578,8 +579,9 @@ def attn_fwd(
         if bias_ptr is not None:
             bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
         if RETURN_ENCODED_SOFTMAX:
-            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
-                                                   (0, n_full_blocks))
+            encoded_softmax_block_ptr = tl.advance(
+                encoded_softmax_block_ptr, (0, n_full_blocks)
+            )
         acc, l_i, m_i = _attn_fwd_inner(
             acc,
             l_i,
@@ -626,12 +628,11 @@ def attn_fwd(
     acc = acc.to(Out.type.element_ty)
     if IS_CAUSAL:  # noqa: SIM102
         if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
-            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
-                                        causal_start_idx,
-                                        dtype=tl.int32)
+            out_mask_boundary = tl.full(
+                (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
+            )
             mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
-            out_ptrs_mask = (mask_m_offsets[:, None] >=
-                             out_mask_boundary[None, :])
+            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
             z = 0.0
             acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
     # write back LSE
@@ -649,8 +650,7 @@ def attn_fwd(
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i))
 
     # write back O
-    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
-                off_h_q * stride_oh)
+    o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
     O_block_ptr = tl.make_block_ptr(
         base=Out + o_offset,
         shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
@@ -813,4 +813,4 @@ class _attention(torch.autograd.Function):
 
         return o, encoded_softmax
 
-triton_attention = _attention.apply
\ No newline at end of file
+triton_attention = _attention.apply
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index 87cc4a83..6cc30e6d 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -10,7 +10,9 @@ else:
         from vllm._C import cache_ops
         from vllm._C import ops
     except Exception as e:
-        raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
+        raise ImportError(
+            f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+        )
 
 
 def reshape_and_cache(