more cleaning

fxmarty 2024-05-02 15:44:38 +00:00
parent c70742654b
commit 6c385626eb
2 changed files with 68 additions and 80 deletions


@@ -182,22 +182,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             total_ns=time.time_ns() - start,
         )
-
-import signal
-
-
-class SignalHandler:
-    KEEP_PROCESSING = True
-
-    def __init__(self):
-        signal.signal(signal.SIGINT, self.exit_gracefully)
-        signal.signal(signal.SIGTERM, self.exit_gracefully)
-
-    def exit_gracefully(self, signum, frame):
-        print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
-
-signal_handler = SignalHandler()
 
 def serve(
     model_id: str,
     revision: Optional[str],
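The removed helper is self-contained, so its intended use is easy to reconstruct. A minimal sketch of how such a handler typically drives a serve loop (the polling loop and the __main__ guard below are illustrative assumptions, not TGI's actual server code):

import signal
import time


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        # Install the same handler for Ctrl+C (SIGINT) and `kill <pid>` (SIGTERM).
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


if __name__ == "__main__":
    signal_handler = SignalHandler()
    # Illustrative polling loop: a real server would instead wait for its
    # gRPC/asyncio server to shut down before exiting.
    while signal_handler.KEEP_PROCESSING:
        time.sleep(0.5)
    print("Shutdown complete.")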


@@ -14,15 +14,6 @@ from text_generation_server.utils.flash_attn_triton import triton_attention
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
-if not torch.cuda.is_available():
-    raise ImportError("CUDA is not available")
-
-major, minor = torch.cuda.get_device_capability()
-is_sm75 = major == 7 and minor == 5
-is_sm8x = major == 8 and minor >= 0
-is_sm90 = major == 9 and minor == 0
-is_sm94 = major == 9 and minor == 4
-
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
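For context on the is_sm flags that move into the next hunk: torch.cuda.get_device_capability() returns a (major, minor) compute-capability tuple, and ROCm builds of PyTorch report an analogous pair for AMD GPUs. A quick illustrative mapping (the device examples are assumptions for illustration, not taken from this diff):

import torch

# Examples of (major, minor) on NVIDIA hardware: (7, 5) Turing/T4, (8, 0) A100,
# (8, 6) RTX 30xx, (9, 0) H100. ROCm builds report a similar pair for AMD GPUs,
# which is what the is_sm90/is_sm94 checks below rely on.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    is_sm94 = major == 9 and minor == 4
    print(f"compute capability: {major}.{minor}")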
@@ -30,65 +21,78 @@ HAS_FLASH_ATTN_V2_ROCM = False
 ROCM_USE_FLASH_ATTN_V2_CK = False
 ROCM_USE_FLASH_ATTN_V2_TRITON = False
 
-if IS_ROCM_SYSTEM:
-    if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
-        ROCM_USE_FLASH_ATTN_V2_TRITON = True
-        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
-    else:
-        ROCM_USE_FLASH_ATTN_V2_CK = True
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
-
-try:
-    try:
-        import flash_attn_2_cuda
-    except ImportError:
-        architecture_suffix = ""
-        if IS_CUDA_SYSTEM:
-            architecture_suffix = "-cuda"
-        elif IS_ROCM_SYSTEM:
-            architecture_suffix = "-rocm"
-        raise ImportError(
-            "Flash Attention V2 is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-        )
-    if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
-        raise ImportError(
-            f"AMD GPU with compute capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
-except ImportError as e:
-    try:
-        import flash_attn_cuda
-    except ImportError:
-        raise ImportError(
-            "Flash Attention is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            "or install flash attention with `cd server && make install install-flash-attention`"
-        ) from e
-
-    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported"
-        ) from e
-    elif IS_ROCM_SYSTEM:
-        for idx in range(torch.cuda.device_count()):
-            if "MI210" not in torch.cuda.get_device_name(
-                idx
-            ) and "MI250" not in torch.cuda.get_device_name(idx):
-                raise ImportError(
-                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                )
-
-    logger.warning(f"Unable to use Flash Attention V2: {e}")
-    HAS_FLASH_ATTN = True
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
+if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+    is_sm94 = major == 9 and minor == 4
+
+    if IS_ROCM_SYSTEM:
+        if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
+            ROCM_USE_FLASH_ATTN_V2_TRITON = True
+            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
+        else:
+            ROCM_USE_FLASH_ATTN_V2_CK = True
+            logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+
+    try:
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            architecture_suffix = ""
+            if IS_CUDA_SYSTEM:
+                architecture_suffix = "-cuda"
+            elif IS_ROCM_SYSTEM:
+                architecture_suffix = "-rocm"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
+            raise ImportError(
+                f"AMD GPU with compute capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
+        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
+        elif IS_ROCM_SYSTEM:
+            for idx in range(torch.cuda.device_count()):
+                if "MI210" not in torch.cuda.get_device_name(
+                    idx
+                ) and "MI250" not in torch.cuda.get_device_name(idx):
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
+
+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True
 
 
 def attention(
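Downstream callers select a backend from the flags this module exports. A minimal sketch of such a call-site check (the helper function and the exact import path are assumptions for illustration, not code from this diff):

from text_generation_server.utils import flash_attn


def describe_attention_backend() -> str:
    # Flag names mirror the module above; precedence follows the detection
    # order in the diff (V2 first, then the V1 fallback).
    if flash_attn.HAS_FLASH_ATTN_V2_CUDA:
        return "Flash Attention V2 (CUDA)"
    if flash_attn.HAS_FLASH_ATTN_V2_ROCM:
        if flash_attn.ROCM_USE_FLASH_ATTN_V2_TRITON:
            return "Flash Attention V2 (ROCm, Triton kernels)"
        return "Flash Attention V2 (ROCm, Composable Kernel)"
    if flash_attn.HAS_FLASH_ATTN:
        return "Flash Attention V1"
    return "no flash attention available"


print(describe_attention_backend())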