From 6c385626eb9e1de47c9e1d0acef1b449d9712bcd Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 2 May 2024 15:44:38 +0000
Subject: [PATCH] more cleaning

---
 server/text_generation_server/server.py      |  16 ---
 .../utils/flash_attn.py                      | 132 +++++++++---------
 2 files changed, 68 insertions(+), 80 deletions(-)

diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index e71fcdf0..a43c06cb 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -182,22 +182,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             total_ns=time.time_ns() - start,
         )
 
-import signal
-
-class SignalHandler:
-    KEEP_PROCESSING = True
-
-    def __init__(self):
-        signal.signal(signal.SIGINT, self.exit_gracefully)
-        signal.signal(signal.SIGTERM, self.exit_gracefully)
-
-    def exit_gracefully(self, signum, frame):
-        print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
-
-
-signal_handler = SignalHandler()
-
 def serve(
     model_id: str,
     revision: Optional[str],
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 32adc6de..a15cbbbc 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -14,15 +14,6 @@ from text_generation_server.utils.flash_attn_triton import triton_attention
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
-if not torch.cuda.is_available():
-    raise ImportError("CUDA is not available")
-
-major, minor = torch.cuda.get_device_capability()
-is_sm75 = major == 7 and minor == 5
-is_sm8x = major == 8 and minor >= 0
-is_sm90 = major == 9 and minor == 0
-is_sm94 = major == 9 and minor == 4
-
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
@@ -30,65 +21,78 @@ HAS_FLASH_ATTN_V2_ROCM = False
 ROCM_USE_FLASH_ATTN_V2_CK = False
 ROCM_USE_FLASH_ATTN_V2_TRITON = False
 
-if IS_ROCM_SYSTEM:
-    if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
-        ROCM_USE_FLASH_ATTN_V2_TRITON = True
-        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
-    else:
-        ROCM_USE_FLASH_ATTN_V2_CK = True
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
+if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+    is_sm94 = major == 9 and minor == 4
+
+    if IS_ROCM_SYSTEM:
+        if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
+            ROCM_USE_FLASH_ATTN_V2_TRITON = True
+            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
+        else:
+            ROCM_USE_FLASH_ATTN_V2_CK = True
+            logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
 
-try:
     try:
-        import flash_attn_2_cuda
-    except ImportError:
-        architecture_suffix = ""
-        if IS_CUDA_SYSTEM:
-            architecture_suffix = "-cuda"
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            architecture_suffix = ""
+            if IS_CUDA_SYSTEM:
+                architecture_suffix = "-cuda"
+            elif IS_ROCM_SYSTEM:
+                architecture_suffix = "-rocm"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
+            raise ImportError(
+                f"AMD GPU with compute capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
+        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
         elif IS_ROCM_SYSTEM:
-            architecture_suffix = "-rocm"
-        raise ImportError(
-            "Flash Attention V2 is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-        )
-    if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
-        raise ImportError(
-            f"AMD GPU with compute capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
-except ImportError as e:
-    try:
-        import flash_attn_cuda
-    except ImportError:
-        raise ImportError(
-            "Flash Attention is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            "or install flash attention with `cd server && make install install-flash-attention`"
-        ) from e
+            for idx in range(torch.cuda.device_count()):
+                if "MI210" not in torch.cuda.get_device_name(
+                    idx
+                ) and "MI250" not in torch.cuda.get_device_name(idx):
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
 
-    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported"
-        ) from e
-    elif IS_ROCM_SYSTEM:
-        for idx in range(torch.cuda.device_count()):
-            if "MI210" not in torch.cuda.get_device_name(
-                idx
-            ) and "MI250" not in torch.cuda.get_device_name(idx):
-                raise ImportError(
-                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                )
-
-    logger.warning(f"Unable to use Flash Attention V2: {e}")
-    HAS_FLASH_ATTN = True
+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True
 
 
 def attention(