mirror of https://github.com/huggingface/text-generation-inference.git
more cleaning
commit 6c385626eb (parent c70742654b)
@@ -182,22 +182,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             total_ns=time.time_ns() - start,
         )
 
-import signal
-
-class SignalHandler:
-    KEEP_PROCESSING = True
-
-    def __init__(self):
-        signal.signal(signal.SIGINT, self.exit_gracefully)
-        signal.signal(signal.SIGTERM, self.exit_gracefully)
-
-    def exit_gracefully(self, signum, frame):
-        print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
-
-
-signal_handler = SignalHandler()
-
 def serve(
     model_id: str,
     revision: Optional[str],
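Note: the block removed above is a plain SIGINT/SIGTERM handler. As a rough illustration only (not part of this commit), such a handler is usually polled by the serve loop; in the sketch below the function name, the `server` object (an already-started grpc.aio server), and the sleep interval are all assumptions.

import asyncio

async def wait_for_shutdown(server, signal_handler):
    # Illustrative sketch: poll the flag that exit_gracefully() flips,
    # then stop the (assumed) grpc.aio server. Not code from this diff.
    while signal_handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)
    await server.stop(grace=None)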
@@ -14,15 +14,6 @@ from text_generation_server.utils.flash_attn_triton import triton_attention
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
-if not torch.cuda.is_available():
-    raise ImportError("CUDA is not available")
-
-major, minor = torch.cuda.get_device_capability()
-is_sm75 = major == 7 and minor == 5
-is_sm8x = major == 8 and minor >= 0
-is_sm90 = major == 9 and minor == 0
-is_sm94 = major == 9 and minor == 4
-
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
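The capability probe deleted above is not lost: it reappears, indented under a CUDA/ROCm guard, in the next hunk. As a worked example of what those flags mean (illustrative values, not taken from the diff): an NVIDIA A100 reports compute capability (8, 0), so is_sm8x is true and the Flash Attention V2 path is allowed, while a T4 reports (7, 5) and only qualifies for the V1 fallback.

import torch

# Illustrative only; requires a CUDA (or ROCm) device to run.
# A100 returns (8, 0) here, H100 (9, 0), T4 (7, 5).
major, minor = torch.cuda.get_device_capability()
is_sm8x = major == 8 and minor >= 0  # True on A100
is_sm90 = major == 9 and minor == 0  # True on H100
print("Flash Attention V2 eligible:", is_sm8x or is_sm90)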
@@ -30,65 +21,78 @@ HAS_FLASH_ATTN_V2_ROCM = False
 ROCM_USE_FLASH_ATTN_V2_CK = False
 ROCM_USE_FLASH_ATTN_V2_TRITON = False
 
-if IS_ROCM_SYSTEM:
-    if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
-        ROCM_USE_FLASH_ATTN_V2_TRITON = True
-        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
-    else:
-        ROCM_USE_FLASH_ATTN_V2_CK = True
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
-
-try:
-    try:
-        import flash_attn_2_cuda
-    except ImportError:
-        architecture_suffix = ""
-        if IS_CUDA_SYSTEM:
-            architecture_suffix = "-cuda"
-        elif IS_ROCM_SYSTEM:
-            architecture_suffix = "-rocm"
-        raise ImportError(
-            "Flash Attention V2 is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-        )
-    if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
-        raise ImportError(
-            f"AMD GPU with compute capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
-except ImportError as e:
-    try:
-        import flash_attn_cuda
-    except ImportError:
-        raise ImportError(
-            "Flash Attention is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            "or install flash attention with `cd server && make install install-flash-attention`"
-        ) from e
-
-    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported"
-        ) from e
-    elif IS_ROCM_SYSTEM:
-        for idx in range(torch.cuda.device_count()):
-            if "MI210" not in torch.cuda.get_device_name(
-                idx
-            ) and "MI250" not in torch.cuda.get_device_name(idx):
-                raise ImportError(
-                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                )
-
-    logger.warning(f"Unable to use Flash Attention V2: {e}")
-    HAS_FLASH_ATTN = True
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
+if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+    is_sm94 = major == 9 and minor == 4
+
+    if IS_ROCM_SYSTEM:
+        if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
+            ROCM_USE_FLASH_ATTN_V2_TRITON = True
+            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
+        else:
+            ROCM_USE_FLASH_ATTN_V2_CK = True
+            logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+
+    try:
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            architecture_suffix = ""
+            if IS_CUDA_SYSTEM:
+                architecture_suffix = "-cuda"
+            elif IS_ROCM_SYSTEM:
+                architecture_suffix = "-rocm"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
+            raise ImportError(
+                f"AMD GPU with compute capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
+        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
+        elif IS_ROCM_SYSTEM:
+            for idx in range(torch.cuda.device_count()):
+                if "MI210" not in torch.cuda.get_device_name(
+                    idx
+                ) and "MI250" not in torch.cuda.get_device_name(idx):
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
+
+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True
 
 
 def attention(
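For orientation, a hedged sketch of how the module-level flags set by this file are typically inspected after import; the module path is an assumption based on the text_generation_server.utils package seen in the hunk header, and the printout is purely illustrative.

from text_generation_server.utils import flash_attn

# Illustrative probe (assumed module path): report which backend the
# import-time checks above ended up enabling.
if flash_attn.HAS_FLASH_ATTN_V2_CUDA:
    print("Flash Attention V2 (CUDA) enabled")
elif flash_attn.HAS_FLASH_ATTN_V2_ROCM:
    kernel = "Triton" if flash_attn.ROCM_USE_FLASH_ATTN_V2_TRITON else "Composable Kernel"
    print(f"Flash Attention V2 (ROCm, {kernel}) enabled")
elif flash_attn.HAS_FLASH_ATTN:
    print("Falling back to Flash Attention V1")
else:
    print("Flash Attention not available")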