/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
    StreamResponse, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use futures::Stream;
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::instrument;

/// Health check method
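///
/// Sends a one-token `generate` request so the whole inference pipeline is
/// exercised. A usage sketch (the port depends on the `addr` passed to `run`;
/// `3000` here is only an illustration):
///
/// ```text
/// curl -f http://localhost:3000/health
/// ```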
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //  be a bit too slow for a health check.
    //  What we should do instead is check whether the gRPC channels are still healthy.

    // Send a small inference request
    infer
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
                temperature: 1.0,
                repetition_penalty: 1.0,
                top_k: 0,
                top_p: 1.0,
                do_sample: false,
                max_new_tokens: 1,
                stop: vec![],
                details: false,
                seed: None,
            },
        })
        .await?;
    Ok(())
}

/// Generate method
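///
/// Handles `POST /generate`. An illustrative request body (the accepted
/// fields follow `GenerateRequest`/`GenerateParameters`; whether fields may
/// be omitted depends on their serde defaults, defined elsewhere):
///
/// ```text
/// {"inputs": "Hello", "parameters": {"max_new_tokens": 20, "details": true}}
/// ```
///
/// The response also carries timing headers (`x-total-time`,
/// `x-validation-time`, `x-queue-time`, `x-inference-time`,
/// `x-time-per-token`), set below.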
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed
    )
)]
async fn generate(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    // Inference
    let details = req.0.parameters.details;
    let response = infer.generate(req.0).await.map_err(|err| {
        tracing::error!("{}", err.to_string());
        err
    })?;

    // Token details
    let details = match details {
        true => Some(Details {
            finish_reason: response.generated_text.finish_reason,
            generated_tokens: response.generated_text.generated_tokens,
            prefill: Some(response.prefill),
            tokens: Some(response.tokens),
            seed: response.generated_text.seed,
        }),
        false => None,
    };

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Headers
    let mut headers = HeaderMap::new();
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    span.record("total_time", format!("{:?}", total_time));
    span.record("validation_time", format!("{:?}", validation_time));
    span.record("queue_time", format!("{:?}", queue_time));
    span.record("inference_time", format!("{:?}", inference_time));
    span.record("time_per_token", format!("{:?}", time_per_token));
    span.record("seed", format!("{:?}", response.generated_text.seed));
    tracing::info!("Output: {}", response.generated_text.text);

    // Send response
    let response = vec![GenerateResponse {
        generated_text: response.generated_text.text,
        details,
    }];
    Ok((headers, Json(response)))
}

/// Generate stream method
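///
/// Handles `POST /generate_stream` and replies with Server-Sent Events: one
/// `StreamResponse` JSON payload per token, with `generated_text` and
/// `details` only populated on the final event. An illustrative exchange
/// (token payloads elided, as their serialized shape is defined elsewhere):
///
/// ```text
/// data: {"token": ..., "generated_text": null, "details": null}
/// data: {"token": ..., "generated_text": "full text", "details": ...}
/// ```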
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token
    )
)]
async fn generate_stream(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let stream = async_stream::stream! {
        // Inference
        let mut end_reached = false;
        let mut error = false;
        let details = req.0.parameters.details;

        match infer.generate_stream(req.0).await {
            Ok(mut response_stream) => {
                // Server-Sent Events stream
                while let Some(response) = response_stream.next().await {
                    match response {
                        Ok(response) => {
                            match response {
                                // Prefill is ignored
                                InferStreamResponse::Prefill(_) => {}
                                // Yield event for every new token
                                InferStreamResponse::Token(token) => {
                                    // StreamResponse
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: None,
                                        details: None,
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                                // Yield event for last token and compute timings
                                InferStreamResponse::End {
                                    token,
                                    generated_text,
                                    start,
                                    queued,
                                } => {
                                    // Token details
                                    let details = match details {
                                        true => Some(Details {
                                            finish_reason: generated_text.finish_reason,
                                            generated_tokens: generated_text.generated_tokens,
                                            prefill: None,
                                            tokens: None,
                                            seed: generated_text.seed,
                                        }),
                                        false => None,
                                    };

                                    // Timings
                                    let total_time = start_time.elapsed();
                                    let validation_time = queued - start_time;
                                    let queue_time = start - queued;
                                    let inference_time = Instant::now() - start;
                                    let time_per_token = inference_time / generated_text.generated_tokens;

                                    // Tracing metadata
                                    span.record("total_time", format!("{:?}", total_time));
                                    span.record("validation_time", format!("{:?}", validation_time));
                                    span.record("queue_time", format!("{:?}", queue_time));
                                    span.record("inference_time", format!("{:?}", inference_time));
                                    span.record("time_per_token", format!("{:?}", time_per_token));
                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                    // StreamResponse
                                    end_reached = true;
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: Some(generated_text.text),
                                        details,
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                            }
                        }
                        // Trace and yield error
                        Err(err) => {
                            error = true;
                            tracing::error!("{}", err.to_string());
                            yield Ok(Event::from(err))
                        }
                    }
                }
            },
            // Trace and yield error
            Err(err) => {
                error = true;
                tracing::error!("{}", err.to_string());
                yield Ok(Event::from(err))
            }
        }
        // Check if generation reached the end
        // Skip if we already sent an error
        if !end_reached && !error {
            let err = InferError::IncompleteGeneration;
            tracing::error!("{}", err.to_string());
            yield Ok(Event::from(err))
        }
    };

    Sse::new(stream).keep_alive(KeepAlive::default())
}

/// Serving method
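///
/// A sketch of how a binary entrypoint might call this; every value below is
/// an illustrative assumption, not a default defined here:
///
/// ```text
/// run(128, 1000, 32, 20, sharded_client, tokenizer, 2,
///     "0.0.0.0:3000".parse().unwrap()).await;
/// ```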
#[allow(clippy::too_many_arguments)]
pub async fn run(
    max_concurrent_requests: usize,
    max_input_length: usize,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
) {
    // Create state
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
    let infer = Infer::new(
        client,
        validation,
        max_batch_size,
        max_waiting_tokens,
        max_concurrent_requests,
    );

    // Create router
    let app = Router::new()
        .route("/", post(generate))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .layer(Extension(infer));

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

/// Shutdown signal handler
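///
/// Resolves once SIGINT (Ctrl+C) or, on Unix, SIGTERM is received. `run`
/// hands this future to `with_graceful_shutdown` above so in-flight requests
/// can finish before the process exits.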
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
}

/// Convert to Axum supported formats
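///
/// Status codes used below: `GenerationError` → 424 Failed Dependency,
/// `Overloaded` → 429 Too Many Requests, `ValidationError` → 422
/// Unprocessable Entity, `IncompleteGeneration` → 500 Internal Server Error.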
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
            }),
        )
    }
}

impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
            })
            .unwrap()
    }
}