2022-10-27 12:25:29 +00:00
|
|
|
use crate::{
|
2022-12-15 16:03:56 +00:00
|
|
|
Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
|
2022-10-27 12:25:29 +00:00
|
|
|
};
|
2022-10-11 16:14:39 +00:00
|
|
|
use axum::extract::Extension;
|
2022-10-21 14:40:05 +00:00
|
|
|
use axum::http::{HeaderMap, StatusCode};
|
|
|
|
use axum::response::IntoResponse;
|
2022-10-15 18:21:50 +00:00
|
|
|
use axum::routing::{get, post};
|
2022-10-14 13:56:21 +00:00
|
|
|
use axum::{Json, Router};
|
|
|
|
use std::net::SocketAddr;
|
2022-10-18 13:19:03 +00:00
|
|
|
use std::sync::Arc;
|
2022-10-28 17:24:00 +00:00
|
|
|
use text_generation_client::ShardedClient;
|
2022-10-11 14:50:54 +00:00
|
|
|
use tokenizers::Tokenizer;
|
2022-10-18 13:19:03 +00:00
|
|
|
use tokio::signal;
|
|
|
|
use tokio::sync::Semaphore;
|
2022-10-11 08:36:51 +00:00
|
|
|
use tokio::time::Instant;
|
|
|
|
use tracing::instrument;
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Server shared state
//
// Cloned into each axum handler via `Extension`; all fields are cheap to
// clone (`Validation`/`Batcher` are presumably handle types — TODO confirm
// in their definitions — and the semaphore is behind an `Arc`).
#[derive(Clone)]
struct ServerState {
    // Request validation front-end (tokenization / input-length checks).
    validation: Validation,
    // Batching inference client; `infer` submits a request to the shard(s).
    batcher: Batcher,
    // Caps the number of in-flight requests; shared across all clones so the
    // limit is global to the server, not per-handler.
    limit_concurrent_requests: Arc<Semaphore>,
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Health check method
|
2022-10-14 13:56:21 +00:00
|
|
|
#[instrument(skip(state), fields(time, time_per_token))]
|
2022-10-27 12:25:29 +00:00
|
|
|
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
2022-10-18 13:19:03 +00:00
|
|
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
|
|
|
// be a bit too slow for a health check.
|
|
|
|
// What we should do instead if check if the gRPC channels are still healthy.
|
|
|
|
|
|
|
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
|
|
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
|
|
|
(
|
|
|
|
StatusCode::TOO_MANY_REQUESTS,
|
2022-10-27 12:25:29 +00:00
|
|
|
Json(ErrorResponse {
|
|
|
|
error: "Model is overloaded".to_string(),
|
|
|
|
}),
|
2022-10-18 13:19:03 +00:00
|
|
|
)
|
|
|
|
})?;
|
|
|
|
|
|
|
|
// Send a small inference request
|
2022-10-17 12:59:00 +00:00
|
|
|
state
|
2022-10-17 16:27:33 +00:00
|
|
|
.batcher
|
2022-10-14 13:56:21 +00:00
|
|
|
.infer(
|
|
|
|
1,
|
|
|
|
GenerateRequest {
|
|
|
|
inputs: "liveness".to_string(),
|
|
|
|
parameters: GenerateParameters {
|
|
|
|
temperature: 1.0,
|
|
|
|
top_k: 0,
|
|
|
|
top_p: 1.0,
|
|
|
|
do_sample: false,
|
|
|
|
max_new_tokens: 1,
|
2022-12-12 17:25:22 +00:00
|
|
|
stop: vec![],
|
2022-12-15 16:03:56 +00:00
|
|
|
details: false,
|
2022-10-14 13:56:21 +00:00
|
|
|
},
|
|
|
|
},
|
|
|
|
)
|
2022-10-17 12:59:00 +00:00
|
|
|
.await?;
|
|
|
|
Ok(())
|
2022-10-14 13:56:21 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Generate method
///
/// Main text-generation endpoint: validates the request, runs batched
/// inference, and returns the generated text plus timing headers.
/// Rejects with 429 when the global concurrency limit is saturated.
#[instrument(
    skip(state),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token
    )
)]
async fn generate(
    state: Extension<ServerState>,
    req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let start_time = Instant::now();
    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        tracing::error!("Model is overloaded");
        (
            StatusCode::TOO_MANY_REQUESTS,
            Json(ErrorResponse {
                error: "Model is overloaded".to_string(),
            }),
        )
    })?;

    // Validate request
    // Save the `details` flag before `req.0` is moved into `validate`.
    let details = req.0.parameters.details;
    let (input_length, validated_request) =
        state.validation.validate(req.0).await.map_err(|err| {
            tracing::error!("{}", err.to_string());
            err
        })?;

    // Inference
    let response = state
        .batcher
        .infer(input_length, validated_request)
        .await
        .map_err(|err| {
            tracing::error!("{}", err.to_string());
            err
        })?;

    // Token details
    // When requested, zip per-token ids / texts / logprobs into
    // (id, text, logprob) triples. Assumes the three vectors have equal
    // length — TODO confirm the batcher guarantees this; `zip` silently
    // truncates to the shortest otherwise.
    let details = match details {
        true => {
            let tokens = response
                .token_ids
                .into_iter()
                .zip(response.tokens.into_iter())
                .zip(response.logprobs.into_iter())
                .map(|((id, text), logprob)| (id, text, logprob))
                .collect();
            Some(Details {
                finish_reason: response.finish_reason,
                generated_tokens: response.generated_tokens,
                tokens,
            })
        }
        false => None,
    };

    // Timings
    // Derived from the batcher's recorded Instants:
    //   start_time -> queued -> start -> end
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = response.end - response.start;
    // NOTE(review): division panics if `generated_tokens` is 0 — presumably
    // validation enforces `max_new_tokens >= 1`; verify.
    let time_per_token = inference_time / response.generated_tokens;

    // Headers
    // Millisecond timings exposed as custom `x-*` headers; `parse().unwrap()`
    // cannot fail since the value is an ASCII decimal string.
    let mut headers = HeaderMap::new();
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    // Fills the span fields declared in the #[instrument] attribute above.
    tracing::Span::current().record("total_time", format!("{:?}", total_time));
    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
    tracing::info!("Output: {}", response.output_text);

    // Send response
    // Wrapped in a Vec: the API returns a JSON array of generations.
    let response = vec![GeneratedText {
        generated_text: response.output_text,
        details,
    }];
    Ok((headers, Json(response)))
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Serving method
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
|
|
pub async fn run(
|
|
|
|
max_concurrent_requests: usize,
|
|
|
|
max_input_length: usize,
|
|
|
|
max_batch_size: usize,
|
2022-10-21 14:40:05 +00:00
|
|
|
max_waiting_tokens: usize,
|
2022-10-18 13:19:03 +00:00
|
|
|
client: ShardedClient,
|
|
|
|
tokenizer: Tokenizer,
|
|
|
|
validation_workers: usize,
|
|
|
|
addr: SocketAddr,
|
|
|
|
) {
|
|
|
|
// Create state
|
2022-10-21 14:40:05 +00:00
|
|
|
let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
|
2022-10-18 13:19:03 +00:00
|
|
|
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
|
|
|
let shared_state = ServerState {
|
|
|
|
validation,
|
|
|
|
batcher,
|
|
|
|
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
|
|
|
|
};
|
|
|
|
|
|
|
|
// Create router
|
2022-10-14 13:56:21 +00:00
|
|
|
let app = Router::new()
|
|
|
|
.route("/generate", post(generate))
|
|
|
|
.layer(Extension(shared_state.clone()))
|
2023-01-23 16:11:27 +00:00
|
|
|
.route("/", get(health))
|
2022-10-18 13:19:03 +00:00
|
|
|
.route("/health", get(health))
|
2022-10-14 13:56:21 +00:00
|
|
|
.layer(Extension(shared_state.clone()));
|
2022-10-11 08:36:51 +00:00
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Run server
|
2022-10-11 16:14:39 +00:00
|
|
|
axum::Server::bind(&addr)
|
2022-10-14 13:56:21 +00:00
|
|
|
.serve(app.into_make_service())
|
2022-10-18 13:19:03 +00:00
|
|
|
// Wait until all requests are finished to shut down
|
|
|
|
.with_graceful_shutdown(shutdown_signal())
|
2022-10-14 13:56:21 +00:00
|
|
|
.await
|
|
|
|
.unwrap();
|
2022-10-11 14:50:54 +00:00
|
|
|
}
|
2022-10-18 13:19:03 +00:00
|
|
|
|
|
|
|
/// Shutdown signal handler
|
|
|
|
async fn shutdown_signal() {
|
|
|
|
let ctrl_c = async {
|
|
|
|
signal::ctrl_c()
|
|
|
|
.await
|
|
|
|
.expect("failed to install Ctrl+C handler");
|
|
|
|
};
|
|
|
|
|
|
|
|
#[cfg(unix)]
|
|
|
|
let terminate = async {
|
|
|
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
|
|
|
.expect("failed to install signal handler")
|
|
|
|
.recv()
|
|
|
|
.await;
|
|
|
|
};
|
|
|
|
|
|
|
|
#[cfg(not(unix))]
|
|
|
|
let terminate = std::future::pending::<()>();
|
|
|
|
|
|
|
|
tokio::select! {
|
|
|
|
_ = ctrl_c => {},
|
|
|
|
_ = terminate => {},
|
|
|
|
}
|
|
|
|
|
|
|
|
tracing::info!("signal received, starting graceful shutdown");
|
|
|
|
}
|