From 460e83044493f5efed7112b10fa0bbef3dff02e9 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 7 Oct 2024 14:45:52 +0200
Subject: [PATCH] fix benchmarker

---
 benchmark/src/main.rs                         | 13 +++++++++++++
 server/text_generation_server/interceptor.py |  5 ++++-
 .../models/flash_causal_lm.py                |  3 ++-
 server/text_generation_server/server.py      | 10 +++++++---
 4 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs
index 2ee3d7c5..0bb5dc0c 100644
--- a/benchmark/src/main.rs
+++ b/benchmark/src/main.rs
@@ -178,6 +178,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             .clear_cache(None)
             .await
             .expect("Unable to clear cache");
+
+        // Warmup shard
+        let max_batch_size = batch_size.iter().max().unwrap();
+        sharded_client
+            .warmup(
+                sequence_length,
+                sequence_length * max_batch_size,
+                (sequence_length + decode_length) * max_batch_size,
+                Some(*max_batch_size as usize),
+            )
+            .await
+            .expect("Unable to warmup");
+
         tracing::info!("Connected");
 
         // Run app
diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py
index 57df1725..a5c023e4 100644
--- a/server/text_generation_server/interceptor.py
+++ b/server/text_generation_server/interceptor.py
@@ -9,6 +9,9 @@ from typing import Callable, Any
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
+    def __init__(self, shutdown_callback):
+        self.shutdown_callback = shutdown_callback
+
     async def intercept(
         self,
         method: Callable,
@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             # Runtime Error cannot be recovered from
             if isinstance(err, RuntimeError):
-                exit(1)
+                self.shutdown_callback()
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 4c3024f4..6e8a0097 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1383,6 +1383,7 @@ class FlashCausalLM(Model):
 
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
+        self.kv_cache = []
         empty_cache()
 
         try:
@@ -1402,7 +1403,7 @@ class FlashCausalLM(Model):
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
-                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index cc7979d4..da85d19d 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -47,9 +47,12 @@ class SignalHandler:
         signal.signal(signal.SIGINT, self.exit_gracefully)
         signal.signal(signal.SIGTERM, self.exit_gracefully)
 
+    def set_keep_processing(self, value: bool):
+        self.KEEP_PROCESSING = value
+
     def exit_gracefully(self, signum, frame):
         print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
+        self.set_keep_processing(False)
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -268,10 +271,12 @@ def serve(
             logger.exception("Error when initializing model")
             raise
 
+        signal_handler = SignalHandler()
+
         set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
-                ExceptionInterceptor(),
+                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                 UDSOpenTelemetryAioServerInterceptor(),
             ],
             options=[
@@ -292,7 +297,6 @@
         await server.start()
 
         logger.info("Server started at {}".format(local_url))
 
-        signal_handler = SignalHandler()
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)