From 17f5c3078be92f271fc43d04555087e903587b57 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 29 Apr 2024 14:55:59 +0000
Subject: [PATCH] working & cached tunableop

---
 launcher/src/main.rs                     |  4 +---
 .../models/flash_causal_lm.py            |  9 +++++---
 server/text_generation_server/server.py  | 22 ++++++++++++++-----
 .../utils/paged_attention.py             | 20 +++++++----------
 4 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 40e7364f2..1a70e9913 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -672,9 +672,7 @@ fn shard_manager(
 
         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            p.kill().unwrap();
-            let _ = p.wait();
-            tracing::info!("Shard terminated");
+            terminate("shard", p, Duration::from_secs(90)).unwrap();
             return;
         }
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 31d965e0d..6e14de814 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -766,6 +766,9 @@ class FlashCausalLM(Model):
             )
             max_bt = batch.max_blocks
             max_s = max_bt * get_cache_manager().block_size
+
+            if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+                logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -820,10 +823,10 @@ class FlashCausalLM(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        # TODO: fix
-        if IS_ROCM_SYSTEM and False:
-            total_seqlens = list(range(16))
+        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+            total_seqlens = list(range(2))
             for seqlen in total_seqlens:
+                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen, max_s, max_bt)
 
         return int(num_blocks * BLOCK_SIZE)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 495c2c0cd..ef05099b2 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -166,6 +166,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             total_ns=time.time_ns() - start,
         )
 
+import signal
+
+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
 
 def serve(
     model_id: str,
@@ -231,11 +246,8 @@ def serve(
 
         logger.info("Server started at {}".format(local_url))
 
-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)
 
     asyncio.run(
         serve_inner(
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index 487a3a72e..d47e0821d 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -4,6 +4,11 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SY
 
 _PARTITION_SIZE = 512
 
+try:
+    from vllm._C import cache_ops
+except Exception as e:
+    raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
+
 
 def reshape_and_cache(
     key: torch.Tensor,
@@ -12,18 +17,9 @@
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if IS_CUDA_SYSTEM:
-        from vllm._C import cache_ops
-
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-    elif IS_ROCM_SYSTEM:
-        from vllm import cache_ops
-
-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-    else:
-        raise ValueError("vllm is not supported on your system")
+    cache_ops.reshape_and_cache(
+        key, value, key_cache, value_cache, slots, "auto", 1.0
+    )
 
 
 def attention(
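
The server.py hunk above replaces the blocking await server.wait_for_termination() / KeyboardInterrupt path with a module-level SignalHandler whose flag is polled from the event loop, so a SIGTERM from the launcher is handled the same way as Ctrl+C. Below is a minimal, self-contained sketch of that pattern, under stated assumptions: DummyServer and serve_inner here are illustrative stand-ins, not the real gRPC server objects from text-generation-inference, and the explicit stop() at the end is added for a clean exit rather than taken from the patch.

import asyncio
import signal


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        # Register the same handler for Ctrl+C (SIGINT) and for SIGTERM.
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


class DummyServer:
    # Illustrative stand-in for the grpc.aio server used by the real code.
    async def start(self):
        print("server started")

    async def stop(self, grace):
        print("server stopped")


async def serve_inner():
    signal_handler = SignalHandler()
    server = DummyServer()
    await server.start()
    # Poll the flag instead of blocking in wait_for_termination(); the short
    # sleep keeps the event loop free to run other tasks between checks.
    while signal_handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)
    # Not part of the patch: an explicit stop so this sketch shuts down cleanly.
    await server.stop(0)


if __name__ == "__main__":
    asyncio.run(serve_inner())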