fix benchmarker

OlivierDehaene 2024-10-07 14:45:52 +02:00
parent 4ddea01c6e
commit 460e830444
4 changed files with 26 additions and 5 deletions


@@ -178,6 +178,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .clear_cache(None)
         .await
         .expect("Unable to clear cache");
+
+    // Warmup shard
+    let max_batch_size = batch_size.iter().max().unwrap();
+    sharded_client
+        .warmup(
+            sequence_length,
+            sequence_length * max_batch_size,
+            (sequence_length + decode_length) * max_batch_size,
+            Some(*max_batch_size as usize),
+        )
+        .await
+        .expect("Unable to warmup");
+
     tracing::info!("Connected");

     // Run app
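
For orientation, the four arguments the benchmarker now passes to warmup appear to correspond to the shard's max input length, max prefill tokens, max total tokens, and max batch size budgets. A small sketch of that arithmetic in Python (the parameter names here are illustrative assumptions, not the actual gRPC field names):

# Hypothetical helper mirroring the arithmetic in the Rust hunk above; the
# names are assumptions for illustration, not the real warmup fields.
def warmup_budgets(sequence_length: int, decode_length: int, batch_sizes: list):
    max_batch_size = max(batch_sizes)
    max_input_length = sequence_length
    max_prefill_tokens = sequence_length * max_batch_size
    max_total_tokens = (sequence_length + decode_length) * max_batch_size
    return max_input_length, max_prefill_tokens, max_total_tokens, max_batch_size


# Example: a 512-token prompt, 64 decode steps, benchmarked batch sizes 1..8.
print(warmup_budgets(512, 64, [1, 2, 4, 8]))  # -> (512, 4096, 4608, 8)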


@@ -9,6 +9,9 @@ from typing import Callable, Any


 class ExceptionInterceptor(AsyncServerInterceptor):
+    def __init__(self, shutdown_callback):
+        self.shutdown_callback = shutdown_callback
+
     async def intercept(
         self,
         method: Callable,
@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):

             # Runtime Error cannot be recovered from
             if isinstance(err, RuntimeError):
-                exit(1)
+                self.shutdown_callback()

             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
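
The interceptor change replaces a hard exit(1) inside the gRPC handler with a callback supplied by the caller, so shutdown goes through the normal server loop instead of killing the process mid-request. A minimal sketch of that pattern (simplified, without the real gRPC plumbing; names other than shutdown_callback are illustrative):

class InterceptorSketch:
    """Simplified stand-in for ExceptionInterceptor: on an unrecoverable
    error it asks the server to stop instead of calling exit(1)."""

    def __init__(self, shutdown_callback):
        self.shutdown_callback = shutdown_callback

    def handle_error(self, err: Exception):
        # Runtime errors cannot be recovered from, so request a graceful shutdown.
        if isinstance(err, RuntimeError):
            self.shutdown_callback()


stopped = []
interceptor = InterceptorSketch(lambda: stopped.append(True))
interceptor.handle_error(RuntimeError("unrecoverable"))
assert stopped  # the callback fired; the process itself keeps running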


@@ -1383,6 +1383,7 @@ class FlashCausalLM(Model):

     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
+        self.kv_cache = []
         empty_cache()

         try:
@@ -1402,7 +1403,7 @@ class FlashCausalLM(Model):
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
-                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

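
Resetting self.kv_cache to an empty list before empty_cache() drops the Python references to any previously allocated cache tensors, so the allocator can actually release that memory before warmup measures what is free. A rough sketch of the idea (assumes CUDA and a kv_cache attribute holding tensors; not the real FlashCausalLM code):

import torch


def reset_cache_before_warmup(model) -> None:
    # Drop references to the old KV-cache tensors first; without this, the
    # tensors stay alive and empty_cache() cannot return their blocks.
    model.kv_cache = []
    if torch.cuda.is_available():
        # Return cached allocator blocks to the device so the subsequent
        # free-memory check reflects what is truly available for a new cache.
        torch.cuda.empty_cache()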


@@ -47,9 +47,12 @@ class SignalHandler:
         signal.signal(signal.SIGINT, self.exit_gracefully)
         signal.signal(signal.SIGTERM, self.exit_gracefully)

+    def set_keep_processing(self, value: bool):
+        self.KEEP_PROCESSING = value
+
     def exit_gracefully(self, signum, frame):
         print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
+        self.set_keep_processing(False)


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -268,10 +271,12 @@ def serve(
             logger.exception("Error when initializing model")
             raise

+        signal_handler = SignalHandler()
+
         set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
-                ExceptionInterceptor(),
+                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                 UDSOpenTelemetryAioServerInterceptor(),
             ],
             options=[
@@ -292,7 +297,6 @@ def serve(
         await server.start()

         logger.info("Server started at {}".format(local_url))

-        signal_handler = SignalHandler()
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
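
Taken together, the server changes create the SignalHandler before the gRPC server is built, so SIGINT/SIGTERM and the exception interceptor all funnel into the same KEEP_PROCESSING flag that the serve loop polls. A condensed, self-contained sketch of that wiring (the gRPC server construction itself is omitted):

import asyncio
import signal


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.set_keep_processing(False)


async def serve_loop():
    # Created before the server so the exception interceptor can request
    # shutdown through the same flag the signal handlers use.
    signal_handler = SignalHandler()
    shutdown = lambda: signal_handler.set_keep_processing(False)
    # ... build and start the gRPC server here, passing `shutdown` to the
    # exception interceptor ...

    while signal_handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)


# asyncio.run(serve_loop())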