2022-10-18 13:19:03 +00:00
|
|
|
use crate::{
|
|
|
|
Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
|
|
|
|
};
|
2022-10-11 16:14:39 +00:00
|
|
|
use axum::extract::Extension;
|
2022-10-14 13:56:21 +00:00
|
|
|
use axum::http::StatusCode;
|
2022-10-15 18:21:50 +00:00
|
|
|
use axum::routing::{get, post};
|
2022-10-14 13:56:21 +00:00
|
|
|
use axum::{Json, Router};
|
2022-10-17 12:59:00 +00:00
|
|
|
use bloom_inference_client::ShardedClient;
|
2022-10-14 13:56:21 +00:00
|
|
|
use std::net::SocketAddr;
|
2022-10-18 13:19:03 +00:00
|
|
|
use std::sync::Arc;
|
|
|
|
use std::time::Duration;
|
2022-10-11 14:50:54 +00:00
|
|
|
use tokenizers::Tokenizer;
|
2022-10-18 13:19:03 +00:00
|
|
|
use tokio::signal;
|
|
|
|
use tokio::sync::Semaphore;
|
2022-10-11 08:36:51 +00:00
|
|
|
use tokio::time::Instant;
|
|
|
|
use tracing::instrument;
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Server shared state

/// State shared by all request handlers, injected via axum's `Extension`
/// layer. `Clone` is required by the `Extension` layer; the semaphore is
/// behind an `Arc` so clones share the same permit pool.
#[derive(Clone)]
struct ServerState {
    // Validates incoming requests before inference (see `generate`);
    // `validate` returns the input length plus a validated request.
    validation: Validation,
    // Entry point into the inference pipeline; `infer` is called by both
    // the `/health` and `/generate` handlers.
    batcher: Batcher,
    // Caps in-flight requests: handlers `try_acquire` a permit and answer
    // 429 "Model is overloaded" when none is available.
    limit_concurrent_requests: Arc<Semaphore>,
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Health check method
|
2022-10-14 13:56:21 +00:00
|
|
|
#[instrument(skip(state), fields(time, time_per_token))]
|
2022-10-18 13:19:03 +00:00
|
|
|
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
|
|
|
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
|
|
|
// be a bit too slow for a health check.
|
|
|
|
// What we should do instead if check if the gRPC channels are still healthy.
|
|
|
|
|
|
|
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
|
|
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
|
|
|
(
|
|
|
|
StatusCode::TOO_MANY_REQUESTS,
|
|
|
|
"Model is overloaded".to_string(),
|
|
|
|
)
|
|
|
|
})?;
|
|
|
|
|
|
|
|
// Send a small inference request
|
2022-10-17 12:59:00 +00:00
|
|
|
state
|
2022-10-17 16:27:33 +00:00
|
|
|
.batcher
|
2022-10-14 13:56:21 +00:00
|
|
|
.infer(
|
|
|
|
1,
|
|
|
|
GenerateRequest {
|
|
|
|
inputs: "liveness".to_string(),
|
|
|
|
parameters: GenerateParameters {
|
|
|
|
temperature: 1.0,
|
|
|
|
top_k: 0,
|
|
|
|
top_p: 1.0,
|
|
|
|
do_sample: false,
|
|
|
|
max_new_tokens: 1,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
)
|
2022-10-17 12:59:00 +00:00
|
|
|
.await?;
|
|
|
|
Ok(())
|
2022-10-14 13:56:21 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Generate method
|
2022-10-11 16:14:39 +00:00
|
|
|
#[instrument(skip(state), fields(time, time_per_token))]
|
2022-10-11 08:36:51 +00:00
|
|
|
async fn generate(
|
2022-10-11 16:14:39 +00:00
|
|
|
state: Extension<ServerState>,
|
2022-10-11 08:36:51 +00:00
|
|
|
req: Json<GenerateRequest>,
|
2022-10-18 13:19:03 +00:00
|
|
|
) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
|
2022-10-11 08:36:51 +00:00
|
|
|
let start = Instant::now();
|
2022-10-18 13:19:03 +00:00
|
|
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
|
|
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
|
|
|
(
|
|
|
|
StatusCode::TOO_MANY_REQUESTS,
|
|
|
|
"Model is overloaded".to_string(),
|
|
|
|
)
|
|
|
|
})?;
|
2022-10-11 08:36:51 +00:00
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Validate request
|
2022-10-17 12:59:00 +00:00
|
|
|
let (input_length, validated_request) = state
|
2022-10-14 13:56:21 +00:00
|
|
|
.validation
|
2022-10-18 13:19:03 +00:00
|
|
|
// FIXME: can't we get rid of the cloning here??
|
2022-10-11 14:50:54 +00:00
|
|
|
.validate(GenerateRequest {
|
2022-10-11 08:36:51 +00:00
|
|
|
inputs: req.inputs.clone(),
|
|
|
|
parameters: req.parameters.clone(),
|
|
|
|
})
|
2022-10-17 12:59:00 +00:00
|
|
|
.await?;
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Inference
|
2022-10-17 16:27:33 +00:00
|
|
|
let generated_text = state.batcher.infer(input_length, validated_request).await?;
|
2022-10-17 12:59:00 +00:00
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Tracing metadata
|
2022-10-17 12:59:00 +00:00
|
|
|
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
|
|
|
|
tracing::Span::current().record(
|
|
|
|
"time_per_token",
|
|
|
|
format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
|
|
|
|
);
|
|
|
|
tracing::info!("response: {}", generated_text);
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Send response
|
|
|
|
let response = vec![GeneratedText { generated_text }];
|
|
|
|
Ok(Json(response))
|
2022-10-11 08:36:51 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Serving method
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
|
|
pub async fn run(
|
|
|
|
max_concurrent_requests: usize,
|
|
|
|
max_input_length: usize,
|
|
|
|
max_batch_size: usize,
|
|
|
|
max_waiting_time: Duration,
|
|
|
|
client: ShardedClient,
|
|
|
|
tokenizer: Tokenizer,
|
|
|
|
validation_workers: usize,
|
|
|
|
addr: SocketAddr,
|
|
|
|
) {
|
|
|
|
// Create state
|
|
|
|
let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
|
|
|
|
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
|
|
|
let shared_state = ServerState {
|
|
|
|
validation,
|
|
|
|
batcher,
|
|
|
|
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
|
|
|
|
};
|
|
|
|
|
|
|
|
// Create router
|
2022-10-14 13:56:21 +00:00
|
|
|
let app = Router::new()
|
|
|
|
.route("/generate", post(generate))
|
|
|
|
.layer(Extension(shared_state.clone()))
|
2022-10-18 13:19:03 +00:00
|
|
|
.route("/health", get(health))
|
2022-10-14 13:56:21 +00:00
|
|
|
.layer(Extension(shared_state.clone()));
|
2022-10-11 08:36:51 +00:00
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Run server
|
2022-10-11 16:14:39 +00:00
|
|
|
axum::Server::bind(&addr)
|
2022-10-14 13:56:21 +00:00
|
|
|
.serve(app.into_make_service())
|
2022-10-18 13:19:03 +00:00
|
|
|
// Wait until all requests are finished to shut down
|
|
|
|
.with_graceful_shutdown(shutdown_signal())
|
2022-10-14 13:56:21 +00:00
|
|
|
.await
|
|
|
|
.unwrap();
|
2022-10-11 14:50:54 +00:00
|
|
|
}
|
2022-10-18 13:19:03 +00:00
|
|
|
|
|
|
|
/// Shutdown signal handler
|
|
|
|
async fn shutdown_signal() {
|
|
|
|
let ctrl_c = async {
|
|
|
|
signal::ctrl_c()
|
|
|
|
.await
|
|
|
|
.expect("failed to install Ctrl+C handler");
|
|
|
|
};
|
|
|
|
|
|
|
|
#[cfg(unix)]
|
|
|
|
let terminate = async {
|
|
|
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
|
|
|
.expect("failed to install signal handler")
|
|
|
|
.recv()
|
|
|
|
.await;
|
|
|
|
};
|
|
|
|
|
|
|
|
#[cfg(not(unix))]
|
|
|
|
let terminate = std::future::pending::<()>();
|
|
|
|
|
|
|
|
tokio::select! {
|
|
|
|
_ = ctrl_c => {},
|
|
|
|
_ = terminate => {},
|
|
|
|
}
|
|
|
|
|
|
|
|
tracing::info!("signal received, starting graceful shutdown");
|
|
|
|
}
|