Improved version

OlivierDehaene 2023-01-30 10:55:54 +01:00
parent 122c137b56
commit 429155a26a
6 changed files with 166 additions and 83 deletions

View File

@@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
 [dev-dependencies]
 float_eq = "1.0.1"
 reqwest = { version = "0.11.13", features = ["blocking", "json"] }
-serde = "1.0.150"
+serde = { version = "1.0.150", features = ["derive"] }

View File

@@ -8,7 +8,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v1::{
     Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters,
+    StoppingCriteriaParameters, PrefillTokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

View File

@@ -10,6 +10,7 @@ use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use tokio::sync::mpsc::UnboundedSender;
+use tokio::sync::OwnedSemaphorePermit;
 use tokio::time::Instant;

 /// Database entry
@@ -25,6 +26,8 @@ pub(crate) struct Entry {
     pub time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
+    /// Permit
+    pub _permit: OwnedSemaphorePermit,
 }

 /// Request Database

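Note on the `_permit` field added above: holding a Tokio `OwnedSemaphorePermit` inside the database entry means the concurrency slot is released only when the entry itself is dropped. The standalone sketch below (illustrative names such as `Job` and `try_start_job`, not part of this commit) shows the pattern:

```rust
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};

// Illustrative only: holding an `OwnedSemaphorePermit` in a struct keeps one
// concurrency slot reserved until the struct is dropped.
struct Job {
    _permit: OwnedSemaphorePermit,
}

fn try_start_job(limit: &Arc<Semaphore>) -> Result<Job, TryAcquireError> {
    // `try_acquire_owned` consumes an `Arc<Semaphore>`, so the permit keeps the
    // semaphore alive instead of borrowing it from the caller.
    let permit = limit.clone().try_acquire_owned()?;
    Ok(Job { _permit: permit })
}

fn main() {
    let limit = Arc::new(Semaphore::new(1));
    let first = try_start_job(&limit).expect("first job gets the permit");
    assert!(try_start_job(&limit).is_err()); // limit reached
    drop(first); // dropping the holder releases the permit
    assert!(try_start_job(&limit).is_ok());
}
```

Because the permit owns a reference to the semaphore, it can travel with the struct and outlive the caller's borrow, which is what lets the router tie one request slot to one queue entry.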
View File

@@ -5,7 +5,7 @@ use crate::{Db, Entry, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
+use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -22,12 +22,12 @@ pub struct Infer {
     db: Db,
     /// Shared state
     shared: Arc<Shared>,
+    /// Inference limit
+    limit_concurrent_requests: Arc<Semaphore>,
 }

 /// Infer shared state
 struct Shared {
-    /// Inference limit
-    limit_concurrent_requests: Semaphore,
     /// Batching background Tokio task notifier
     batching_task: Notify,
 }
@@ -43,7 +43,6 @@ impl Infer {
         // Infer shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
-            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
             batching_task: Notify::new(),
         });

@@ -56,10 +55,14 @@ impl Infer {
             shared.clone(),
         ));

+        // Inference limit with a semaphore
+        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
+
         Self {
             validation,
             db,
             shared,
+            limit_concurrent_requests: semaphore,
         }
     }
@@ -69,7 +72,8 @@ impl Infer {
         request: GenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
-        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
+        // This permit will live as long as Entry
+        let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;

         // Validate request
         let (input_length, validated_request) = self.validation.validate(request).await?;
@@ -84,6 +88,7 @@ impl Infer {
             input_length,
             time: Instant::now(),
             batch_time: None,
+            _permit: permit
         });

         // Notify the background task that we have a new entry in the database that needs
@@ -113,30 +118,48 @@ impl Infer {
             match response? {
                 // Add prefill tokens
                 InferStreamResponse::Prefill(prefill_tokens) => {
-                    result_tokens.extend(prefill_tokens)
+                    // Create Token objects
+                    // We do that here instead of in the Python code as Rust for loops are faster
+                    let prefill_tokens = prefill_tokens
+                        .ids
+                        .into_iter()
+                        .zip(prefill_tokens.logprobs.into_iter())
+                        .zip(prefill_tokens.texts.into_iter())
+                        .map(|((id, logprob), text)| Token(id, text, logprob))
+                        .collect();
+                    result_tokens = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
+                    token,
                     generated_text,
                     start,
                     queued,
                 } => {
+                    result_tokens.push(token);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
                 }
             }
         }

-        // Unwrap is safe here
-        Ok(InferResponse {
-            tokens: result_tokens,
-            generated_text: result_generated_text.unwrap(),
-            queued: result_queued.unwrap(),
-            start: result_start.unwrap(),
-        })
+        // Check that we received a `InferStreamResponse::End` message
+        if let (Some(generated_text), Some(queued), Some(start)) =
+            (result_generated_text, result_queued, result_start)
+        {
+            Ok(InferResponse {
+                tokens: result_tokens,
+                generated_text,
+                queued,
+                start,
+            })
+        } else {
+            Err(InferError::IncompleteGeneration)
+        }
     }
 }
@@ -210,7 +233,7 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the responses to Infer
 async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
+    future: impl Future<Output=Result<(Vec<Generation>, Option<Batch>), ClientError>>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
@@ -247,20 +270,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
             .expect("ID not found in entries. This is a bug.");

         if let Some(prefill_tokens) = generation.prefill_tokens {
-            // Create Token objects
-            // We do that here instead of in the Python code as Rust for loops are faster
-            let tokens = prefill_tokens
-                .ids
-                .into_iter()
-                .zip(prefill_tokens.logprobs.into_iter())
-                .zip(prefill_tokens.texts.into_iter())
-                .map(|((id, logprob), text)| Token(id, text, logprob))
-                .collect();
-
             // Send message
             // unwrap_or is valid here as we don't care if the receiver is gone.
             entry
                 .response_tx
-                .send(Ok(InferStreamResponse::Prefill(tokens)))
+                .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
                 .unwrap_or(());
         }
@@ -271,13 +285,6 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
             generation.token_logprob,
         );

-        // Send message
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Token(token)))
-            .unwrap_or(());
-
         if let Some(generated_text) = generation.generated_text {
             // Remove entry as this is the last message
             // We can `expect` here as the request id should always be in the entries
@@ -290,11 +297,19 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
             entry
                 .response_tx
                 .send(Ok(InferStreamResponse::End {
+                    token,
                     generated_text,
                     queued: entry.time,
                     start: entry.batch_time.unwrap(),
                 }))
                 .unwrap_or(());
+        } else {
+            // Send message
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry
+                .response_tx
+                .send(Ok(InferStreamResponse::Token(token)))
+                .unwrap_or(());
         }
     });
 }
@@ -302,11 +317,12 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(Vec<Token>),
+    Prefill(PrefillTokens),
     // Intermediate messages
     Token(Token),
     // Last message
     End {
+        token: Token,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
@@ -330,4 +346,6 @@ pub enum InferError {
     Overloaded(#[from] TryAcquireError),
     #[error("Input validation error: {0}")]
     ValidationError(#[from] ValidationError),
+    #[error("Incomplete generation")]
+    IncompleteGeneration,
 }

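The prefill handling above moves token assembly from Python into Rust by zipping the parallel `ids`, `logprobs`, and `texts` vectors into `Token` values. A minimal, self-contained sketch of that zip/zip/map step (the field and type definitions here are simplified stand-ins for the generated protobuf `PrefillTokens` and the router's `Token`):

```rust
// Simplified stand-ins for illustration; the real types come from the
// generated protobuf definitions and the router crate.
struct PrefillTokens {
    ids: Vec<u32>,
    logprobs: Vec<f32>,
    texts: Vec<String>,
}

#[derive(Debug)]
struct Token(u32, String, f32);

fn to_tokens(prefill: PrefillTokens) -> Vec<Token> {
    // Zip the three parallel vectors together, then map each
    // ((id, logprob), text) triple into a single Token value.
    prefill
        .ids
        .into_iter()
        .zip(prefill.logprobs.into_iter())
        .zip(prefill.texts.into_iter())
        .map(|((id, logprob), text)| Token(id, text, logprob))
        .collect()
}

fn main() {
    let prefill = PrefillTokens {
        ids: vec![17, 42],
        logprobs: vec![-0.3, -1.2],
        texts: vec!["Hello".to_string(), " world".to_string()],
    };
    println!("{:?}", to_tokens(prefill));
}
```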
View File

@@ -87,6 +87,14 @@ pub(crate) struct GeneratedText {
     pub details: Option<Details>,
 }

+#[derive(Serialize)]
+pub(crate) struct StreamToken {
+    pub token: Token,
+    pub end: bool,
+    pub finish_reason: Option<String>,
+    pub generated_text: Option<String>,
+}
+
 #[derive(Serialize)]
 pub(crate) struct ErrorResponse {
     pub error: String,

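For reference, a `StreamToken` like the one added above serializes to a small JSON object per Server-Sent Event. A hedged sketch of that shape, assuming `serde` with the `derive` feature and `serde_json` are available, and using a simplified tuple-struct `Token` in place of the router's type:

```rust
use serde::Serialize;

// Simplified stand-in: the router's Token is a tuple struct over id, text, logprob.
#[derive(Serialize)]
struct Token(u32, String, f32);

#[derive(Serialize)]
struct StreamToken {
    token: Token,
    end: bool,
    finish_reason: Option<String>,
    generated_text: Option<String>,
}

fn main() {
    let intermediate = StreamToken {
        token: Token(42, " world".to_string(), -0.7),
        end: false,
        finish_reason: None,
        generated_text: None,
    };
    // e.g. {"token":[42," world",-0.7],"end":false,"finish_reason":null,"generated_text":null}
    println!("{}", serde_json::to_string(&intermediate).unwrap());
}
```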
View File

@@ -1,7 +1,8 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, Validation,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
+    Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -10,6 +11,7 @@ use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use futures::Stream;
+use std::convert::Infallible;
 use std::net::SocketAddr;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
@@ -60,6 +62,7 @@ async fn generate(
     infer: Extension<Infer>,
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     let start_time = Instant::now();

     // Inference
@@ -111,12 +114,12 @@ async fn generate(
     );

     // Tracing metadata
-    tracing::Span::current().record("total_time", format!("{:?}", total_time));
-    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
-    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
-    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
-    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
-    tracing::Span::current().record("seed", format!("{:?}", response.seed));
+    span.record("total_time", format!("{:?}", total_time));
+    span.record("validation_time", format!("{:?}", validation_time));
+    span.record("queue_time", format!("{:?}", queue_time));
+    span.record("inference_time", format!("{:?}", inference_time));
+    span.record("time_per_token", format!("{:?}", time_per_token));
+    span.record("seed", format!("{:?}", response.seed));
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
@@ -141,57 +144,97 @@ async fn generate(
 async fn generate_stream(
     infer: Extension<Infer>,
     req: Json<GenerateRequest>,
-) -> Sse<impl Stream<Item = Result<Event, InferError>>> {
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
+
     let stream = async_stream::stream! {
-        let start_time = Instant::now();
-
         // Inference
-        let mut response_stream = infer.generate_stream(req.0).await?;
-
-        // Server Side Event stream
-        while let Some(response) = response_stream.next().await {
-            match response {
-                Ok(response) => {
-                    match response {
-                        // Prefill is ignored
-                        InferStreamResponse::Prefill(_) => {}
-                        // Yield event for every new token
-                        InferStreamResponse::Token(token) => {
-                            yield Ok(Event::default().json_data(token).unwrap())
-                        }
-                        // End is used for timings metadata and logging
-                        InferStreamResponse::End {
-                            generated_text,
-                            start,
-                            queued,
-                        } => {
-                            // Timings
-                            let total_time = start_time.elapsed();
-                            let validation_time = queued - start_time;
-                            let queue_time = start - queued;
-                            let inference_time = Instant::now() - start;
-                            let time_per_token = inference_time / generated_text.generated_tokens;
-
-                            // Tracing metadata
-                            tracing::Span::current().record("total_time", format!("{:?}", total_time));
-                            tracing::Span::current()
-                                .record("validation_time", format!("{:?}", validation_time));
-                            tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
-                            tracing::Span::current()
-                                .record("inference_time", format!("{:?}", inference_time));
-                            tracing::Span::current()
-                                .record("time_per_token", format!("{:?}", time_per_token));
-                            tracing::info!("Output: {}", generated_text.text);
-                        }
-                    }
-                }
-                // Trace and yield error
-                Err(err) => {
-                    tracing::error!("{}", err.to_string());
-                    yield Err(err);
-                }
-            }
-        }
+        let mut end_reached = false;
+        let mut error = false;
+
+        match infer.generate_stream(req.0).await {
+            Ok(mut response_stream) => {
+                // Server Side Event stream
+                while let Some(response) = response_stream.next().await {
+                    match response {
+                        Ok(response) => {
+                            match response {
+                                // Prefill is ignored
+                                InferStreamResponse::Prefill(_) => {}
+                                // Yield event for every new token
+                                InferStreamResponse::Token(token) => {
+                                    // StreamToken
+                                    let stream_token = StreamToken {
+                                        token,
+                                        end: end_reached,
+                                        finish_reason: None,
+                                        generated_text: None,
+                                    };
+
+                                    yield Ok(Event::default().json_data(stream_token).unwrap())
+                                }
+                                // End is used for timings metadata and logging
+                                InferStreamResponse::End {
+                                    token,
+                                    generated_text,
+                                    start,
+                                    queued,
+                                } => {
+                                    // Timings
+                                    let total_time = start_time.elapsed();
+                                    let validation_time = queued - start_time;
+                                    let queue_time = start - queued;
+                                    let inference_time = Instant::now() - start;
+                                    let time_per_token = inference_time / generated_text.generated_tokens;
+
+                                    // Tracing metadata
+                                    span.record("total_time", format!("{:?}", total_time));
+                                    span
+                                        .record("validation_time", format!("{:?}", validation_time));
+                                    span.record("queue_time", format!("{:?}", queue_time));
+                                    span
+                                        .record("inference_time", format!("{:?}", inference_time));
+                                    span
+                                        .record("time_per_token", format!("{:?}", time_per_token));
+                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);
+
+                                    // StreamToken
+                                    end_reached = true;
+                                    let stream_token = StreamToken {
+                                        token,
+                                        end: end_reached,
+                                        finish_reason: Some(generated_text.finish_reason),
+                                        generated_text: Some(generated_text.text),
+                                    };
+
+                                    yield Ok(Event::default().json_data(stream_token).unwrap())
+                                }
+                            }
+                        }
+                        // Trace and yield error
+                        Err(err) => {
+                            error = true;
+                            tracing::error!("{}", err.to_string());
+                            yield Ok(Event::from(err))
+                        }
+                    }
+                }
+            },
+            // Trace and yield error
+            Err(err) => {
+                error = true;
+                tracing::error!("{}", err.to_string());
+                yield Ok(Event::from(err))
+            }
+        }
+        // Check if generation reached the end
+        // Skip if we already sent an error
+        if !end_reached && !error {
+            let err = InferError::IncompleteGeneration;
+            tracing::error!("{}", err.to_string());
+            yield Ok(Event::from(err))
+        }
     };

     Sse::new(stream).keep_alive(KeepAlive::default())
@@ -264,13 +307,14 @@ async fn shutdown_signal() {
     tracing::info!("signal received, starting graceful shutdown");
 }

-/// Convert to Axum supported format
+/// Convert to Axum supported formats
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
         let status_code = match err {
             InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
         };

         (
@@ -281,3 +325,13 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
         )
     }
 }
+
+impl From<InferError> for Event {
+    fn from(err: InferError) -> Self {
+        Event::default()
+            .json_data(ErrorResponse {
+                error: err.to_string(),
+            })
+            .unwrap()
+    }
+}
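
A note on the stream item type: switching `generate_stream` to `Result<Event, Infallible>` means the SSE response itself can never fail; errors are converted into ordinary data events through the `From<InferError> for Event` impl above. A compile-level sketch of that pattern, assuming axum's SSE types and the `async-stream` crate as used in this commit (the handler and helper names here are illustrative, not part of the router):

```rust
use std::convert::Infallible;

use axum::response::sse::{Event, KeepAlive, Sse};
use futures::Stream;
use serde::Serialize;

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}

// Same shape as the `From<InferError> for Event` impl above: wrap the error
// text in a JSON body instead of failing the stream.
fn error_event(message: &str) -> Event {
    Event::default()
        .json_data(ErrorResponse {
            error: message.to_string(),
        })
        .unwrap()
}

async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = async_stream::stream! {
        yield Ok(Event::default().data("token"));
        // A failure is reported as a regular event rather than an Err item,
        // so clients always receive a well-formed SSE stream.
        yield Ok(error_event("Incomplete generation"));
    };
    Sse::new(stream).keep_alive(KeepAlive::default())
}
```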