// text-generation-inference/router/src/server.rs

use crate::{
    Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use std::net::SocketAddr;
use std::sync::Arc;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::Semaphore;
use tokio::time::Instant;
use tracing::instrument;

// Server shared state
#[derive(Clone)]
struct ServerState {
    validation: Validation,
    batcher: Batcher,
    limit_concurrent_requests: Arc<Semaphore>,
}
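
// The shared state wires the per-request resources together: `validation` checks incoming
// requests (it is built with the `Tokenizer`, so it can measure input length), `batcher`
// forwards validated requests to the sharded model client, and `limit_concurrent_requests`
// caps how many requests are in flight at once; both handlers below take a permit from it
// before doing any work.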

/// Health check method
#[instrument(skip(state), fields(time, time_per_token))]
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    // be a bit too slow for a health check.
    // What we should do instead is check whether the gRPC channels are still healthy.

    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        (
            StatusCode::TOO_MANY_REQUESTS,
            Json(ErrorResponse {
                error: "Model is overloaded".to_string(),
            }),
        )
    })?;

    // Send a small inference request
    state
        .batcher
        .infer(
            1,
            GenerateRequest {
                inputs: "liveness".to_string(),
                parameters: GenerateParameters {
                    temperature: 1.0,
                    top_k: 0,
                    top_p: 1.0,
                    do_sample: false,
                    max_new_tokens: 1,
                    stop: vec![],
                    details: false,
                },
            },
        )
        .await?;
    Ok(())
}

/// Generate method
#[instrument(
    skip(state),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token
    )
)]
async fn generate(
    state: Extension<ServerState>,
    req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let start_time = Instant::now();
    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        tracing::error!("Model is overloaded");
        (
            StatusCode::TOO_MANY_REQUESTS,
            Json(ErrorResponse {
                error: "Model is overloaded".to_string(),
            }),
        )
    })?;
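
    // The permit stays bound to `_permit` and is released when it is dropped at the end of this
    // function, so the semaphore bounds the number of requests being processed at any one time.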

    // Validate request
    let details = req.0.parameters.details;
    let (input_length, validated_request) =
        state.validation.validate(req.0).await.map_err(|err| {
            tracing::error!("{}", err.to_string());
            err
        })?;

    // Inference
    let response = state
        .batcher
        .infer(input_length, validated_request)
        .await
        .map_err(|err| {
            tracing::error!("{}", err.to_string());
            err
        })?;

    // Token details
    let details = match details {
        true => {
            let tokens = response
                .token_ids
                .into_iter()
                .zip(response.tokens.into_iter())
                .zip(response.logprobs.into_iter())
                .map(|((id, text), logprob)| (id, text, logprob))
                .collect();
            Some(Details {
                finish_reason: response.finish_reason,
                generated_tokens: response.generated_tokens,
                tokens,
            })
        }
        false => None,
    };
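    // Each entry in `tokens` is a (token_id, token_text, logprob) triple; this block is only
    // assembled when the client asked for `details: true` in the request parameters.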

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = response.end - response.start;
    let time_per_token = inference_time / response.generated_tokens;
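    // validation_time covers the span from receiving the request until it entered the queue,
    // queue_time is the wait inside the batching queue, inference_time is the time the model
    // spent generating, and time_per_token averages that over the number of generated tokens.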

    // Headers
    let mut headers = HeaderMap::new();
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    tracing::Span::current().record("total_time", format!("{:?}", total_time));
    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));

    tracing::info!("Output: {}", response.output_text);

    // Send response
    let response = vec![GeneratedText {
        generated_text: response.output_text,
        details,
    }];
    Ok((headers, Json(response)))
}
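
// A client would typically exercise this handler along these lines (illustrative only: the JSON
// field names are assumed to match the `GenerateRequest` / `GenerateParameters` struct fields
// used in `health` above, and the exact spelling depends on the serde definitions in the crate
// root):
//
//   POST /generate
//   {
//     "inputs": "Hello world",
//     "parameters": {
//       "temperature": 1.0,
//       "top_k": 0,
//       "top_p": 1.0,
//       "do_sample": false,
//       "max_new_tokens": 20,
//       "stop": [],
//       "details": false
//     }
//   }
//
// The handler replies with a JSON array of `GeneratedText` objects, e.g.
// `[{"generated_text": "...", "details": null}]` (the `details` field is populated only when
// requested), together with the `x-total-time`, `x-validation-time`, `x-queue-time`,
// `x-inference-time` and `x-time-per-token` headers set above.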

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    max_concurrent_requests: usize,
    max_input_length: usize,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
) {
    // Create state
    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
    let shared_state = ServerState {
        validation,
        batcher,
        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
    };
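
    // NOTE: in axum, `.layer(...)` only wraps the routes registered before it, which is why the
    // shared-state `Extension` layer below is attached twice: once after `/generate` and once
    // after the health routes.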

    // Create router
    let app = Router::new()
        .route("/generate", post(generate))
        .layer(Extension(shared_state.clone()))
        .route("/", get(health))
        .route("/health", get(health))
        .layer(Extension(shared_state.clone()));

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
}
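
// NOTE: once `shutdown_signal` resolves, `with_graceful_shutdown` stops the server from
// accepting new connections and lets in-flight requests finish before `serve` returns in
// `run` above.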