text-generation-inference/router/src/server.rs

141 lines
3.6 KiB
Rust
Raw Normal View History

2022-10-15 18:21:50 +00:00
use crate::{Batcher, Validation};
2022-10-11 16:14:39 +00:00
use axum::extract::Extension;
2022-10-14 13:56:21 +00:00
use axum::http::StatusCode;
2022-10-15 18:21:50 +00:00
use axum::routing::{get, post};
2022-10-14 13:56:21 +00:00
use axum::{Json, Router};
2022-10-17 12:59:00 +00:00
use bloom_inference_client::ShardedClient;
use serde::Deserialize;
2022-10-14 13:56:21 +00:00
use std::net::SocketAddr;
use tokenizers::Tokenizer;
2022-10-11 08:36:51 +00:00
use tokio::time::Instant;
use tracing::instrument;
/// Decoding options carried in the `parameters` object of a generation request.
///
/// Every field is optional in the incoming JSON; a missing field falls back to
/// the matching `default_*` function named in its `serde(default = ...)` attribute.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct GenerateParameters {
    /// Sampling temperature (serde default: `default_temperature`, i.e. `1.0`).
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Top-k value (serde default: `default_top_k`, i.e. `0`).
    #[serde(default = "default_top_k")]
    pub top_k: i32,
    /// Top-p value (serde default: `default_top_p`, i.e. `1.0`).
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    /// Whether to sample (serde default: `default_do_sample`, i.e. `false`).
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    /// Maximum number of tokens to generate (serde default: `default_max_new_tokens`, i.e. `20`).
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
}
/// Serde default for `GenerateParameters::temperature`: `1.0`.
fn default_temperature() -> f32 {
    1.0_f32
}
2022-10-17 12:59:00 +00:00
/// Serde default for `GenerateParameters::top_k`: `0`.
///
/// NOTE(review): presumably `0` means "top-k filtering disabled" downstream — confirm.
fn default_top_k() -> i32 {
    0_i32
}
/// Serde default for `GenerateParameters::top_p`: `1.0`.
fn default_top_p() -> f32 {
    1.0_f32
}
/// Serde default for `GenerateParameters::do_sample`: `false`.
fn default_do_sample() -> bool {
    bool::default()
}
/// Serde default for `GenerateParameters::max_new_tokens`: `20`.
fn default_max_new_tokens() -> u32 {
    20_u32
}
/// Build a `GenerateParameters` in which every field holds its serde default.
///
/// Serves as the serde default for `GenerateRequest::parameters`, i.e. when the
/// client omits the whole `parameters` object.
fn default_parameters() -> GenerateParameters {
    let temperature = default_temperature();
    let top_k = default_top_k();
    let top_p = default_top_p();
    let do_sample = default_do_sample();
    let max_new_tokens = default_max_new_tokens();

    GenerateParameters {
        temperature,
        top_k,
        top_p,
        do_sample,
        max_new_tokens,
    }
}
/// JSON body accepted by the `/generate` endpoint.
///
/// The health check also builds this type internally.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct GenerateRequest {
    /// Prompt text to generate from.
    pub inputs: String,
    /// Decoding options; if the object is omitted entirely, `default_parameters()` is used.
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}
2022-10-14 13:56:21 +00:00
#[instrument(skip(state), fields(time, time_per_token))]
2022-10-17 12:59:00 +00:00
async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
state
2022-10-14 13:56:21 +00:00
.infer
.infer(
1,
GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
do_sample: false,
max_new_tokens: 1,
},
},
)
2022-10-17 12:59:00 +00:00
.await?;
Ok(())
2022-10-14 13:56:21 +00:00
}
2022-10-11 16:14:39 +00:00
#[instrument(skip(state), fields(time, time_per_token))]
2022-10-11 08:36:51 +00:00
async fn generate(
2022-10-11 16:14:39 +00:00
state: Extension<ServerState>,
2022-10-11 08:36:51 +00:00
req: Json<GenerateRequest>,
2022-10-17 12:59:00 +00:00
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
2022-10-11 08:36:51 +00:00
let start = Instant::now();
2022-10-17 12:59:00 +00:00
let (input_length, validated_request) = state
2022-10-14 13:56:21 +00:00
.validation
.validate(GenerateRequest {
2022-10-11 08:36:51 +00:00
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
2022-10-17 12:59:00 +00:00
.await?;
let generated_text = state.infer.infer(input_length, validated_request).await?;
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
tracing::Span::current().record(
"time_per_token",
format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
);
tracing::info!("response: {}", generated_text);
Ok(Json(serde_json::json!({
"generated_text": generated_text,
})))
2022-10-11 08:36:51 +00:00
}
2022-10-11 16:14:39 +00:00
/// State shared with every request handler through an axum `Extension` layer.
///
/// `Clone` is required because the `Extension` extractor hands each handler its own copy.
#[derive(Clone)]
struct ServerState {
    /// Request validator, built from the tokenizer in `run`.
    validation: Validation,
    /// Batching front-end over the sharded inference client.
    infer: Batcher,
}
2022-10-14 13:56:21 +00:00
pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
client.clear_cache().await.expect("Unable to clear cache");
2022-10-11 08:36:51 +00:00
tracing::info!("Connected");
let infer = Batcher::new(client);
let validation = Validation::new(tokenizer);
2022-10-14 13:56:21 +00:00
let shared_state = ServerState { validation, infer };
2022-10-11 16:14:39 +00:00
2022-10-14 13:56:21 +00:00
let app = Router::new()
.route("/generate", post(generate))
.layer(Extension(shared_state.clone()))
2022-10-15 18:21:50 +00:00
.route("/health", get(liveness))
2022-10-14 13:56:21 +00:00
.layer(Extension(shared_state.clone()));
2022-10-11 08:36:51 +00:00
2022-10-11 16:14:39 +00:00
axum::Server::bind(&addr)
2022-10-14 13:56:21 +00:00
.serve(app.into_make_service())
.await
.unwrap();
}