Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
fix benchmarker
This commit is contained in:
parent 4ddea01c6e
commit 460e830444
@@ -178,6 +178,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .clear_cache(None)
         .await
         .expect("Unable to clear cache");
 
+    // Warmup shard
+    let max_batch_size = batch_size.iter().max().unwrap();
+    sharded_client
+        .warmup(
+            sequence_length,
+            sequence_length * max_batch_size,
+            (sequence_length + decode_length) * max_batch_size,
+            Some(*max_batch_size as usize),
+        )
+        .await
+        .expect("Unable to warmup");
+
     tracing::info!("Connected");
 
     // Run app
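
The warmup call above derives its token budgets from the benchmark settings: the per-request input length, the prefill budget for the largest batch, and the total budget once decode_length extra tokens are generated per request. A small illustrative Python sketch of that arithmetic (the helper and the budget names are assumptions for readability; the diff only shows positional arguments):

    def warmup_budgets(sequence_length: int, decode_length: int, batch_sizes: list[int]) -> dict:
        # Largest batch the benchmark will ever submit.
        max_batch_size = max(batch_sizes)
        return {
            "max_input_length": sequence_length,
            "max_prefill_tokens": sequence_length * max_batch_size,
            "max_total_tokens": (sequence_length + decode_length) * max_batch_size,
            "max_batch_size": max_batch_size,
        }

    # Example: sequence_length=512, decode_length=128, batch sizes [1, 2, 4, 8]
    # give a prefill budget of 4096 tokens and a total budget of 5120 tokens.
    print(warmup_budgets(512, 128, [1, 2, 4, 8]))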
@@ -9,6 +9,9 @@ from typing import Callable, Any
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
+    def __init__(self, shutdown_callback):
+        self.shutdown_callback = shutdown_callback
+
     async def intercept(
         self,
         method: Callable,
@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             # Runtime Error cannot be recovered from
             if isinstance(err, RuntimeError):
-                exit(1)
+                self.shutdown_callback()
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
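
Taken together, the two interceptor hunks let the caller decide how to shut down when a request hits an unrecoverable RuntimeError, instead of the worker calling exit(1) from inside a gRPC handler. A minimal sketch of what the class looks like after the change, assuming the grpc-interceptor AsyncServerInterceptor base class and a simplified error path (the logging and abort details here are illustrative, not taken from the repository):

    from typing import Any, Callable

    import grpc
    import torch
    from grpc_interceptor.server import AsyncServerInterceptor
    from loguru import logger


    class ExceptionInterceptor(AsyncServerInterceptor):
        def __init__(self, shutdown_callback):
            # Invoked instead of exit(1) so the serving loop can stop cleanly.
            self.shutdown_callback = shutdown_callback

        async def intercept(
            self,
            method: Callable,
            request_or_iterator: Any,
            context: grpc.ServicerContext,
            method_name: str,
        ):
            try:
                return await method(request_or_iterator, context)
            except Exception as err:
                logger.exception(f"Method {method_name} encountered an error.")

                # Runtime Error cannot be recovered from
                if isinstance(err, RuntimeError):
                    self.shutdown_callback()

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                await context.abort(grpc.StatusCode.INTERNAL, str(err))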
@@ -1383,6 +1383,7 @@ class FlashCausalLM(Model):
 
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
+        self.kv_cache = []
         empty_cache()
 
         try:
@@ -1402,7 +1403,7 @@ class FlashCausalLM(Model):
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
-                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
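
The two FlashCausalLM hunks make warmup start from a clean slate (dropping any previously allocated KV cache before measuring memory) and report the prefill token count taken from the protobuf batch rather than the flattened input_ids. A consolidated view of the resulting method, with comments added; only the lines shown in the diff come from the repository, and the code between the two hunks is elided:

    def warmup(self, batch: FlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive.
        # Drop any existing allocation so the memory measurement is not skewed.
        self.kv_cache = []
        empty_cache()

        try:
            # ... (lines between the two hunks are elided here) ...
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e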
@@ -47,9 +47,12 @@ class SignalHandler:
         signal.signal(signal.SIGINT, self.exit_gracefully)
         signal.signal(signal.SIGTERM, self.exit_gracefully)
 
+    def set_keep_processing(self, value: bool):
+        self.KEEP_PROCESSING = value
+
     def exit_gracefully(self, signum, frame):
         print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
+        self.set_keep_processing(False)
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
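
For context, the signal handler that gains the new setter looks roughly like this, reconstructed from the lines visible in the diff; the class-level KEEP_PROCESSING default and the __init__ wrapper are assumptions:

    import signal


    class SignalHandler:
        KEEP_PROCESSING = True  # assumed default; polled by the serve loop

        def __init__(self):
            # Route SIGINT/SIGTERM through exit_gracefully so the loop can drain.
            signal.signal(signal.SIGINT, self.exit_gracefully)
            signal.signal(signal.SIGTERM, self.exit_gracefully)

        def set_keep_processing(self, value: bool):
            self.KEEP_PROCESSING = value

        def exit_gracefully(self, signum, frame):
            print(f"Exiting gracefully: Signal {signum}")
            self.set_keep_processing(False)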
@@ -268,10 +271,12 @@ def serve(
             logger.exception("Error when initializing model")
             raise
 
+        signal_handler = SignalHandler()
+
         set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
-                ExceptionInterceptor(),
+                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                 UDSOpenTelemetryAioServerInterceptor(),
             ],
             options=[
@@ -292,7 +297,6 @@ def serve(
         await server.start()
 
         logger.info("Server started at {}".format(local_url))
-        signal_handler = SignalHandler()
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
 
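
The last two hunks move the SignalHandler construction ahead of the server setup so the interceptor's lambda can flip KEEP_PROCESSING when a RuntimeError makes a shard unusable; the polling loop then exits and the server shuts down in an orderly way rather than the worker dying inside a request. A stripped-down sketch of that control flow, reusing the classes sketched above (the serve_inner structure, the servicer registration step, and the stop grace period are assumptions):

    import asyncio

    from grpc import aio


    async def serve_inner(local_url: str):
        # Created before the server so the interceptor below can reach it.
        signal_handler = SignalHandler()

        server = aio.server(
            interceptors=[
                # On an unrecoverable RuntimeError the interceptor now flips the
                # flag instead of calling exit(1) inside a request handler.
                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
            ],
        )
        # ... register servicers and bind local_url here ...
        await server.start()

        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

        # Give in-flight RPCs up to 30 seconds to finish, then tear down.
        await server.stop(30)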