From 34b9289c3d04cacb62b9d03ae1548929efb0a415 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 6 May 2024 19:16:04 +0200 Subject: [PATCH] Moving to `SYSTEM` enum. --- .../layers/gptq/__init__.py | 4 +- .../layers/layernorm.py | 83 ++++++-- .../text_generation_server/layers/linear.py | 3 +- .../text_generation_server/layers/rotary.py | 17 +- .../models/cache_manager.py | 4 +- .../custom_modeling/flash_cohere_modeling.py | 16 +- .../custom_modeling/flash_dbrx_modeling.py | 13 +- .../custom_modeling/flash_mixtral_modeling.py | 12 +- .../custom_modeling/idefics_modeling.py | 12 +- .../models/flash_causal_lm.py | 39 +--- .../models/flash_llama.py | 4 +- .../models/flash_mistral.py | 4 +- .../models/flash_neox.py | 4 +- .../text_generation_server/models/flash_rw.py | 4 +- .../models/flash_santacoder.py | 4 +- .../utils/flash_attn.py | 197 ++++++++++-------- .../utils/import_utils.py | 32 ++- .../utils/paged_attention.py | 24 +-- 18 files changed, 286 insertions(+), 190 deletions(-) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 078dafc4..d3b10665 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,7 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import ( - IS_ROCM_SYSTEM, + SYSTEM, ) try: @@ -10,7 +10,7 @@ except Exception: major = 1 HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM +CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" if os.getenv("DISABLE_EXLLAMA") == "True": HAS_EXLLAMA = False diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py index 035070b1..8865be25 100644 --- a/server/text_generation_server/layers/layernorm.py +++ b/server/text_generation_server/layers/layernorm.py @@ -2,9 +2,7 @@ import torch from torch import nn from accelerate import init_empty_weights from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, + SYSTEM, ) @@ -35,19 +33,60 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps): torch.nn.LayerNorm.load = load_layer_norm torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias -if IS_CUDA_SYSTEM: +if SYSTEM == "cuda": import dropout_layer_norm -elif IS_ROCM_SYSTEM: + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": from vllm import layernorm_ops -elif IS_XPU_SYSTEM: + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + +elif SYSTEM == "xpu": import intel_extension_for_pytorch as ipex -else: - dropout_layer_norm = None - -class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if 
IS_XPU_SYSTEM: + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): res_out = hidden_states out = ipex.llm.functional.add_layer_norm( residual, hidden_states, self.weight, self.bias, self.eps, True @@ -55,7 +94,19 @@ class FastLayerNorm(nn.LayerNorm): if residual is not None: res_out = residual return out, res_out - elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + + +class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if SYSTEM == "xpu": + res_out = hidden_states + out = ipex.llm.functional.add_layer_norm( + residual, hidden_states, self.weight, self.bias, self.eps, True + ) + if residual is not None: + res_out = residual + return out, res_out + elif hidden_states.shape[-1] > 8192 or SYSTEM == "rocm": if residual is not None: hidden_states += residual residual = hidden_states @@ -102,7 +153,7 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if IS_XPU_SYSTEM: + if SYSTEM == "xpu": residual_out = hidden_states out = ipex.llm.functional.add_rms_norm( residual, @@ -131,7 +182,7 @@ class FastRMSNorm(nn.Module): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states, residual - elif IS_CUDA_SYSTEM: + elif SYSTEM == "cuda": # faster post attention rms norm ( normed_hidden_states, @@ -158,7 +209,7 @@ class FastRMSNorm(nn.Module): res = hidden_states return normed_hidden_states, res - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. if residual is not None: hidden_states += residual diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index dab7094f..822d43f3 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,4 +1,5 @@ import torch +from text_generation_server.utils.import_utils import SYSTEM class FastLinear(torch.nn.Module): @@ -126,7 +127,7 @@ def get_linear(weight, bias, quantize): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." ) - if IS_ROCM_SYSTEM: + if SYSTEM == "rocm": raise NotImplementedError( "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " "to use Exllama/GPTQ kernels for AWQ inference." diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 3ca40894..5fc8d87c 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -1,15 +1,12 @@ import torch from torch import nn -from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, -) +from text_generation_server.utils.import_utils import SYSTEM -if IS_CUDA_SYSTEM: +if SYSTEM == "cuda": from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb -elif IS_ROCM_SYSTEM: +elif SYSTEM == "rocm": from vllm import pos_encoding_ops @@ -50,7 +47,7 @@ class PositionRotaryEmbedding(nn.Module): sin: torch.Tensor, ): # Such controlflows may add some overhead. 
- if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": rotary_dim = cos.shape[-1] q1 = query[..., :rotary_dim] q2 = query[..., rotary_dim : 2 * rotary_dim] @@ -61,7 +58,7 @@ class PositionRotaryEmbedding(nn.Module): k2 = key[..., rotary_dim : 2 * rotary_dim] rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 @@ -69,7 +66,7 @@ class PositionRotaryEmbedding(nn.Module): # Inplace operation, updating query and key. pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": ipex.llm.functional.rotary_embedding( query, key, sin, cos, query.size(-1), True ) @@ -223,7 +220,7 @@ class PositionRotaryEmbedding(nn.Module): """ Return cos and sin for the asked position ids """ - if IS_ROCM_SYSTEM: + if SYSTEM == "rocm": # For RoCm, we always use float cos/sin to avoid a cast. # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index 4c65e2dd..c7705fe8 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -2,7 +2,7 @@ import math import torch from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM BLOCK_SIZE: int = 16 # Will be set in warmup @@ -25,7 +25,7 @@ class CacheManager: self.repeat_slots = repeat_slots element_size = torch.tensor([], dtype=dtype).element_size() - if IS_XPU_SYSTEM: + if SYSTEM == "xpu": x = 1 else: x = self.block_size // element_size diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index d8774257..8c423eaf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -26,18 +26,22 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, +) +from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) -if IS_CUDA_SYSTEM: +if SYSTEM == "cuda": import dropout_layer_norm else: 
dropout_layer_norm = None @@ -52,7 +56,7 @@ class CohereRotary(PositionRotaryEmbedding): sin: torch.Tensor, ): # Such controlflows may add some overhead. - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": import rotary_emb q1 = query[..., ::2] @@ -64,7 +68,7 @@ class CohereRotary(PositionRotaryEmbedding): k2 = key[..., 1::2] rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import pos_encoding_ops # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. @@ -90,7 +94,7 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm": hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index bb08693a..29e83876 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -21,21 +21,26 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from loguru import logger -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM -if not IS_XPU_SYSTEM: +if SYSTEM != "xpu": from vllm.model_executor.layers.fused_moe import fused_moe + from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.layers import ( FastLinear, - FastLayerNorm, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, ) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) from text_generation_server.utils.log import log_once diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 1ba8bbd1..be2d6c45 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -24,9 +24,9 @@ import torch.distributed import numpy as np from torch import nn -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM -if not IS_XPU_SYSTEM: +if SYSTEM != "xpu": from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig @@ -36,14 +36,18 @@ from loguru import logger from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.layers import ( FastLinear, - FastRMSNorm, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - PositionRotaryEmbedding, SpeculativeHead, get_linear, ) +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) class MixtralConfig(PretrainedConfig): diff 
--git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 3adcd869..b0a7f04f 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -55,12 +55,14 @@ from text_generation_server.layers import ( PositionRotaryEmbedding, FastLinear, ) -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM -if IS_CUDA_SYSTEM: +if SYSTEM == "cuda": import dropout_layer_norm -elif IS_ROCM_SYSTEM: +elif SYSTEM == "rocm": from vllm import layernorm_ops +else: + raise RuntimeError(f"Unsupported system {SYSTEM}") @dataclass @@ -373,7 +375,7 @@ class IdeficsRMSNorm(nn.Module): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states - elif IS_CUDA_SYSTEM: + elif SYSTEM == "cuda": # faster post attention rms norm unwrap = False if len(hidden_states.shape) > 2: @@ -405,7 +407,7 @@ class IdeficsRMSNorm(nn.Module): normed_hidden_states = normed_hidden_states.view(*shape) return normed_hidden_states - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. if residual is not None: hidden_states += residual diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a6d0204f..62df1f59 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -12,10 +12,6 @@ from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Dict - -from text_generation_server.models import Model -from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.utils.speculate import get_speculate from text_generation_server.models.types import ( Batch, Tokens, @@ -32,13 +28,14 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION -tracer = trace.get_tracer(__name__) from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, + empty_cache, + synchronize, + get_free_memory, ) +tracer = trace.get_tracer(__name__) + @dataclass class FlashCausalLMBatch(Batch): @@ -757,10 +754,8 @@ class FlashCausalLM(Model): def warmup(self, batch: FlashCausalLMBatch): # The warmup batch is the biggest batch we could ever receive - if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - torch.cuda.empty_cache() - elif IS_XPU_SYSTEM: - torch.xpu.empty_cache() + empty_cache() + try: cache_manager = set_cache_manager( batch.blocks, @@ -780,10 +775,7 @@ class FlashCausalLM(Model): f"You need to decrease `--max-batch-prefill-tokens`" ) from e - if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - torch.cuda.synchronize(self.device) - elif IS_XPU_SYSTEM: - torch.xpu.synchronize(self.device) + synchronize(self.device) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory @@ -791,20 +783,7 @@ class FlashCausalLM(Model): cache_block_size = BLOCK_SIZE * self.num_kv_heads * 
self.head_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size - if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - total_free_memory, _ = torch.cuda.mem_get_info(self.device) - total_gpu_memory = torch.cuda.get_device_properties( - self.device - ).total_memory - - free_memory = max( - 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory - ) - elif IS_XPU_SYSTEM: - total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory - free_memory = int(total_gpu_memory * 0.5) - else: - raise NotImplementedError("FlashModel is only available on GPU") + free_memory = get_free_memory(self.device, MEMORY_FRACTION) num_blocks = ( # Leave 5% for some wiggle room diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 609a188d..8ea70713 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -18,7 +18,7 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM class FlashLlama(FlashCausalLM): @@ -35,7 +35,7 @@ class FlashLlama(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 85e93543..48304ad8 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -33,7 +33,7 @@ tracer = trace.get_tracer(__name__) # Will be set in init SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None @@ -322,7 +322,7 @@ class BaseFlashMistral(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index f82e27db..1119bdae 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -14,7 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -33,7 +33,7 @@ class FlashNeoXSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index ccf38a0c..33298e1a 100644 --- a/server/text_generation_server/models/flash_rw.py +++ 
b/server/text_generation_server/models/flash_rw.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -34,7 +34,7 @@ class FlashRWSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index e66f1bf8..66698a3a 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -18,7 +18,7 @@ from text_generation_server.utils import ( Weights, ) -from text_generation_server.utils.import_utils import IS_XPU_SYSTEM +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -37,7 +37,7 @@ class FlashSantacoderSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype else: diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 583a8f91..0830656d 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -2,13 +2,8 @@ import os import torch from loguru import logger -import math -from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, -) +from text_generation_server.utils.import_utils import SYSTEM if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") @@ -16,83 +11,22 @@ HAS_FLASH_ATTN = True HAS_FLASH_ATTN_V2_CUDA = False HAS_FLASH_ATTN_V2_ROCM = False -if IS_XPU_SYSTEM: +if SYSTEM == "xpu": import intel_extension_for_pytorch as ipex -if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: - if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") - major, minor = torch.cuda.get_device_capability() - is_sm75 = major == 7 and minor == 5 - is_sm8x = major == 8 and minor >= 0 - is_sm90 = major == 9 and minor == 0 - - HAS_FLASH_ATTN = False - HAS_FLASH_ATTN_V2_CUDA = False - HAS_FLASH_ATTN_V2_ROCM = False - try: - try: - import flash_attn_2_cuda - except ImportError: - architecture_suffix = "" - if IS_CUDA_SYSTEM: - architecture_suffix = "-cuda" - elif IS_ROCM_SYSTEM: - architecture_suffix = "-rocm" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - if not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM - 
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM - except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - elif IS_ROCM_SYSTEM: - for idx in range(torch.cuda.device_count()): - if "MI210" not in torch.cuda.get_device_name( - idx - ) and "MI250" not in torch.cuda.get_device_name(idx): - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = True - - -def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, -): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - - if IS_XPU_SYSTEM: if window_size_left != -1: raise ValueError( f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." @@ -114,7 +48,77 @@ def attention( None, ) - if HAS_FLASH_ATTN_V2_CUDA: + +if SYSTEM in {"cuda", "rocm"}: + if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + + major, minor = torch.cuda.get_device_capability() + is_sm75 = major == 7 and minor == 5 + is_sm8x = major == 8 and minor >= 0 + is_sm90 = major == 9 and minor == 0 + + HAS_FLASH_ATTN = False + HAS_FLASH_ATTN_V2_CUDA = False + HAS_FLASH_ATTN_V2_ROCM = False + try: + try: + import flash_attn_2_cuda + except ImportError: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda" + HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm" + except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + elif SYSTEM == "rocm": + for idx in range(torch.cuda.device_count()): + if "MI210" not in torch.cuda.get_device_name( + idx + ) and "MI250" not in torch.cuda.get_device_name(idx): + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +if HAS_FLASH_ATTN_V2_CUDA: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( q, k, @@ -136,7 
+140,21 @@ def attention( False, None, ) - elif HAS_FLASH_ATTN_V2_ROCM: + +elif HAS_FLASH_ATTN_V2_ROCM: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") if window_size_left != -1: raise ValueError( f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." @@ -159,7 +177,19 @@ def attention( False, None, ) - elif HAS_FLASH_ATTN: + +elif HAS_FLASH_ATTN: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): if window_size_left != -1: raise NotImplementedError( "window_size_left is only available with flash attn v2" @@ -209,4 +239,5 @@ def attention( None, ) +else: raise NotImplementedError("flash attention is not installed") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index db205f4d..b9a41e15 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -10,6 +10,32 @@ def is_xpu_available(): return hasattr(torch, "xpu") and torch.xpu.is_available() -IS_ROCM_SYSTEM = torch.version.hip is not None -IS_CUDA_SYSTEM = torch.version.cuda is not None -IS_XPU_SYSTEM = is_xpu_available() +def get_cuda_free_memory(device, memory_fraction): + total_free_memory, _ = torch.cuda.mem_get_info(device) + total_gpu_memory = torch.cuda.get_device_properties(device).total_memory + free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) + return free_memory + + +def get_xpu_free_memory(device): + total_gpu_memory = torch.xpu.get_device_properties(device).total_memory + free_memory = int(total_gpu_memory * 0.5) + return free_memory + + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif torch.version.cuda is not None: + SYSTEM = "cuda" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif is_xpu_available(): + SYSTEM = "xpu" + empty_cache = torch.xpu.empty_cache + synchronize = torch.xpu.synchronize + get_free_memory = get_xpu_free_memory diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 62c0c893..1b31f7e7 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -1,13 +1,9 @@ import torch -from text_generation_server.utils.import_utils import ( - IS_CUDA_SYSTEM, - IS_ROCM_SYSTEM, - IS_XPU_SYSTEM, -) +from text_generation_server.utils.import_utils import SYSTEM _PARTITION_SIZE = 512 -if IS_XPU_SYSTEM: +if SYSTEM == "xpu": import intel_extension_for_pytorch as ipex @@ -18,17 +14,17 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": from vllm._C import cache_ops cache_ops.reshape_and_cache( key, value, key_cache, value_cache, slots, "auto", 1.0 ) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import cache_ops cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) - elif IS_XPU_SYSTEM: + elif SYSTEM == "xpu": ipex.llm.modules.PagedAttention.reshape_and_cache( key, value, key_cache, value_cache, 
slots ) @@ -68,7 +64,7 @@ def attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - if IS_XPU_SYSTEM: + if SYSTEM == "xpu": query = query.contiguous() return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, @@ -91,7 +87,7 @@ def attention( # to parallelize. use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": from vllm._C import ops ops.paged_attention_v1( @@ -109,7 +105,7 @@ def attention( "auto", 1.0, ) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import attention_ops attention_ops.paged_attention_v1( @@ -143,7 +139,7 @@ def attention( ) max_logits = torch.empty_like(exp_sums) - if IS_CUDA_SYSTEM: + if SYSTEM == "cuda": from vllm._C import ops ops.paged_attention_v2( @@ -164,7 +160,7 @@ def attention( "auto", 1.0, ) - elif IS_ROCM_SYSTEM: + elif SYSTEM == "rocm": from vllm import attention_ops attention_ops.paged_attention_v2(
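
For reference, the core of the change is the new detection block in server/text_generation_server/utils/import_utils.py. Below is a minimal, runnable sketch of that pattern; it mirrors the hunks above, except that the XPU helper is given the same signature as the CUDA one (an assumption, for parity with the warmup call site, noted in the comment), and as in the patch no helpers are bound when no supported accelerator is detected.

    import torch


    def get_cuda_free_memory(device, memory_fraction):
        # Free memory minus the share reserved by (1 - memory_fraction).
        total_free_memory, _ = torch.cuda.mem_get_info(device)
        total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
        return max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)


    def get_xpu_free_memory(device, memory_fraction):
        # memory_fraction is accepted only for signature parity with the CUDA
        # helper (the warmup call site passes MEMORY_FRACTION); the patch's XPU
        # helper takes just `device` and simply uses half of total memory.
        total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
        return int(total_gpu_memory * 0.5)


    SYSTEM = None
    if torch.version.hip is not None:
        SYSTEM = "rocm"
        empty_cache, synchronize = torch.cuda.empty_cache, torch.cuda.synchronize
        get_free_memory = get_cuda_free_memory
    elif torch.version.cuda is not None:
        SYSTEM = "cuda"
        empty_cache, synchronize = torch.cuda.empty_cache, torch.cuda.synchronize
        get_free_memory = get_cuda_free_memory
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        SYSTEM = "xpu"
        empty_cache, synchronize = torch.xpu.empty_cache, torch.xpu.synchronize
        get_free_memory = get_xpu_free_memory

Call sites then either compare the string (e.g. `if SYSTEM == "xpu": ...`) or use the bound helpers (`empty_cache()`, `synchronize(device)`, `get_free_memory(device, MEMORY_FRACTION)`) without knowing which backend is active, which is exactly what the flash_causal_lm.py warmup hunk switches to.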
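
The warmup hunk in flash_causal_lm.py now obtains free memory through get_free_memory() and sizes the KV cache from it. The arithmetic below is an illustrative worked example of that sizing, not the verbatim formula (the num_blocks expression is truncated in the hunk above); the model dimensions are hypothetical and the 5% margin follows the "wiggle room" comment.

    BLOCK_SIZE = 16          # tokens per cache block (from cache_manager.py)
    num_layers = 32          # hypothetical, roughly 7B-class model
    num_kv_heads = 32
    head_size = 128
    dtype_size = 2           # float16 bytes

    # Pretend get_free_memory(device, MEMORY_FRACTION) returned 20 GiB.
    free_memory = 20 * 1024**3

    # One block holds BLOCK_SIZE tokens per KV head; key and value double it,
    # and every layer keeps its own cache.
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size

    # Leave 5% for some wiggle room, as in the patch's comment.
    num_blocks = int(free_memory * 0.95) // total_cache_size
    print(num_blocks)  # 2432 blocks with these example numbers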
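
The use_v1 heuristic in utils/paged_attention.py is untouched by this patch and appears in the hunk above; wrapped as a small function it can be exercised directly. The example inputs are hypothetical.

    _PARTITION_SIZE = 512


    def use_paged_attention_v1(max_s, num_seqs, num_heads):
        # v1 for short sequences, unless the work does not already parallelize
        # well (few partitions or many sequence/head pairs); otherwise v2.
        max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)


    print(use_paged_attention_v1(max_s=512, num_seqs=1, num_heads=32))    # True: fits in one partition
    print(use_paged_attention_v1(max_s=16384, num_seqs=1, num_heads=32))  # False: max_s > 8192, use v2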