Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
fix benchmarker

Commit: 460e830444
Parent: 4ddea01c6e
```diff
@@ -178,6 +178,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .clear_cache(None)
         .await
         .expect("Unable to clear cache");
+
+    // Warmup shard
+    let max_batch_size = batch_size.iter().max().unwrap();
+    sharded_client
+        .warmup(
+            sequence_length,
+            sequence_length * max_batch_size,
+            (sequence_length + decode_length) * max_batch_size,
+            Some(*max_batch_size as usize),
+        )
+        .await
+        .expect("Unable to warmup");
+
     tracing::info!("Connected");
 
     // Run app
```
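
The three arguments passed to `warmup` are simple functions of the benchmark parameters: the per-request input length, the prefill budget for the largest batch in the sweep, and the total token budget once `decode_length` tokens have been decoded for every sequence in that batch (the meaning of each position is an assumption here; the diff only shows the expressions). A minimal sketch of the same arithmetic in Python, with made-up benchmark values:

```python
# Illustrative only: mirrors the arithmetic of the warmup call above
# with hypothetical benchmark parameters.
sequence_length = 512          # tokens per input sequence
decode_length = 64             # tokens generated per sequence
batch_size = [1, 2, 4, 8, 16]  # batch sizes the benchmark will sweep

max_batch_size = max(batch_size)

max_input_length = sequence_length                                      # 512
max_prefill_tokens = sequence_length * max_batch_size                   # 8192
max_total_tokens = (sequence_length + decode_length) * max_batch_size   # 9216

print(max_input_length, max_prefill_tokens, max_total_tokens)
```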
```diff
@@ -9,6 +9,9 @@ from typing import Callable, Any
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
+    def __init__(self, shutdown_callback):
+        self.shutdown_callback = shutdown_callback
+
     async def intercept(
         self,
         method: Callable,
```
```diff
@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             # Runtime Error cannot be recovered from
             if isinstance(err, RuntimeError):
-                exit(1)
+                self.shutdown_callback()
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
```
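
The behavioural change in this hunk: a `RuntimeError` raised while handling a request used to kill the Python server outright with `exit(1)`; the interceptor now receives a callback at construction time and invokes it instead, leaving the actual shutdown to the serve loop. A minimal, self-contained sketch of that pattern (the `Handler` and `Interceptor` names are hypothetical stand-ins, not the TGI classes):

```python
import asyncio


class Handler:
    """Stand-in for SignalHandler: holds the flag the serve loop polls."""

    def __init__(self):
        self.KEEP_PROCESSING = True

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value


class Interceptor:
    """Stand-in for ExceptionInterceptor: stores a shutdown callback."""

    def __init__(self, shutdown_callback):
        self.shutdown_callback = shutdown_callback

    async def intercept(self, method):
        try:
            return await method()
        except RuntimeError:
            # Unrecoverable error: ask the server to stop instead of exit(1)
            self.shutdown_callback()
            raise


async def main():
    handler = Handler()
    interceptor = Interceptor(lambda: handler.set_keep_processing(False))

    async def failing_rpc():
        raise RuntimeError("not enough memory during warmup")

    try:
        await interceptor.intercept(failing_rpc)
    except RuntimeError:
        pass

    # The serve loop would now observe the flag and exit on its own.
    print(handler.KEEP_PROCESSING)  # False


asyncio.run(main())
```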
```diff
@@ -1383,6 +1383,7 @@ class FlashCausalLM(Model):
 
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
+        self.kv_cache = []
         empty_cache()
 
         try:
```
```diff
@@ -1402,7 +1403,7 @@ class FlashCausalLM(Model):
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
-                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
```
```diff
@@ -47,9 +47,12 @@ class SignalHandler:
         signal.signal(signal.SIGINT, self.exit_gracefully)
         signal.signal(signal.SIGTERM, self.exit_gracefully)
 
+    def set_keep_processing(self, value: bool):
+        self.KEEP_PROCESSING = value
+
     def exit_gracefully(self, signum, frame):
         print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
+        self.set_keep_processing(False)
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
```
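
`exit_gracefully` now funnels through the new `set_keep_processing` setter, so the same code path can be triggered either by SIGINT/SIGTERM or by the interceptor's shutdown callback shown earlier. A runnable approximation of the class on its own (the `__main__` section that sends SIGTERM to the current process is illustrative and assumes a Unix system; it is not part of the commit):

```python
import os
import signal
import time


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.set_keep_processing(False)


if __name__ == "__main__":
    handler = SignalHandler()
    os.kill(os.getpid(), signal.SIGTERM)  # simulate Ctrl-C / docker stop
    time.sleep(0.1)                       # give the handler a chance to run
    print(handler.KEEP_PROCESSING)        # False
```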
```diff
@@ -268,10 +271,12 @@ def serve(
             logger.exception("Error when initializing model")
             raise
 
+        signal_handler = SignalHandler()
+
         set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
-                ExceptionInterceptor(),
+                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                 UDSOpenTelemetryAioServerInterceptor(),
             ],
             options=[
```
```diff
@@ -292,7 +297,6 @@ def serve(
         await server.start()
 
         logger.info("Server started at {}".format(local_url))
-        signal_handler = SignalHandler()
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
 
```
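
Together with the previous hunk, the `SignalHandler` is now constructed once, before the gRPC server, so the interceptor's lambda and the polling loop at the end of serving share the same instance, and the redundant second construction after `server.start()` is dropped. A condensed sketch of that ordering (the `serve_loop` coroutine and the simulated fatal error are hypothetical, not the TGI code):

```python
import asyncio


class SignalHandler:
    KEEP_PROCESSING = True

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value


async def serve_loop():
    # 1. Create the handler first so the interceptor callback can reference it.
    signal_handler = SignalHandler()

    # 2. The server would be built with
    #    ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False));
    #    here we just simulate that callback firing after one second.
    def on_fatal_error():
        signal_handler.set_keep_processing(False)

    asyncio.get_running_loop().call_later(1.0, on_fatal_error)

    # 3. Poll the flag instead of blocking forever, exactly like the diff.
    while signal_handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)
    print("server loop exited cleanly")


asyncio.run(serve_loop())
```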