Improved version

OlivierDehaene 2023-01-30 10:55:54 +01:00
parent 0b34905557
commit 046801278e
6 changed files with 166 additions and 83 deletions

View File

@@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
serde = "1.0.150"
serde = { version = "1.0.150", features = ["derive"] }
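The newly enabled `derive` feature is what lets test code put `#[derive(Serialize, Deserialize)]` on fixture types instead of implementing the traits by hand. A minimal sketch of what that unlocks (not part of this commit; the fixture type and the use of `serde_json` are illustrative assumptions):

```rust
// Minimal sketch: with serde's `derive` feature, a test fixture can derive
// (de)serialization directly. `GenerateRequestFixture` is a hypothetical type.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct GenerateRequestFixture {
    inputs: String,
}

fn main() {
    let fixture = GenerateRequestFixture {
        inputs: "Hello".to_string(),
    };
    // Round-trip through JSON (assumes serde_json is available).
    let json = serde_json::to_string(&fixture).unwrap();
    let parsed: GenerateRequestFixture = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed, fixture);
}
```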

View File

@@ -8,7 +8,7 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
StoppingCriteriaParameters, PrefillTokens
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@@ -10,6 +10,7 @@ use text_generation_client::{
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::OwnedSemaphorePermit;
use tokio::time::Instant;
/// Database entry
@@ -25,6 +26,8 @@ pub(crate) struct Entry {
pub time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
/// Permit
pub _permit: OwnedSemaphorePermit,
}
/// Request Database
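Storing an `OwnedSemaphorePermit` inside `Entry` ties the concurrency slot to the entry's lifetime: whenever the entry is dropped, normally or on error, the slot is released with no explicit bookkeeping. A minimal sketch of that pattern, assuming tokio with the `sync` feature (`DemoEntry` is a stand-in, not the router's `Entry`):

```rust
// Sketch: a permit held inside a struct is released when the struct is dropped.
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

struct DemoEntry {
    _permit: OwnedSemaphorePermit,
}

fn main() {
    let semaphore = Arc::new(Semaphore::new(1));

    // Acquire the single slot and move the permit into the entry.
    let permit = semaphore.clone().try_acquire_owned().unwrap();
    let entry = DemoEntry { _permit: permit };

    // While the entry is alive, the limit is reached.
    assert!(semaphore.clone().try_acquire_owned().is_err());

    // Dropping the entry releases the slot automatically.
    drop(entry);
    assert!(semaphore.clone().try_acquire_owned().is_ok());
}
```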

View File

@@ -5,7 +5,7 @@ use crate::{Db, Entry, Token};
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
use thiserror::Error;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
@@ -22,12 +22,12 @@ pub struct Infer {
db: Db,
/// Shared state
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
}
/// Infer shared state
struct Shared {
/// Inference limit
limit_concurrent_requests: Semaphore,
/// Batching background Tokio task notifier
batching_task: Notify,
}
@@ -43,7 +43,6 @@ impl Infer {
// Infer shared state
let db = Db::new();
let shared = Arc::new(Shared {
limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
batching_task: Notify::new(),
});
@@ -56,10 +55,14 @@ impl Infer {
shared.clone(),
));
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
Self {
validation,
db,
shared,
limit_concurrent_requests: semaphore,
}
}
@@ -69,7 +72,8 @@ impl Infer {
request: GenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
// This permit will live as long as Entry
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
// Validate request
let (input_length, validated_request) = self.validation.validate(request).await?;
@@ -84,6 +88,7 @@ impl Infer {
input_length,
time: Instant::now(),
batch_time: None,
_permit: permit
});
// Notify the background task that we have a new entry in the database that needs
@@ -113,30 +118,48 @@ impl Infer {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(prefill_tokens) => {
result_tokens.extend(prefill_tokens)
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs.into_iter())
.zip(prefill_tokens.texts.into_iter())
.map(|((id, logprob), text)| Token(id, text, logprob))
.collect();
result_tokens = prefill_tokens;
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
// Final message
// Set return values
InferStreamResponse::End {
token,
generated_text,
start,
queued,
} => {
result_tokens.push(token);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
}
}
}
// Unwrap is safe here
Ok(InferResponse {
tokens: result_tokens,
generated_text: result_generated_text.unwrap(),
queued: result_queued.unwrap(),
start: result_start.unwrap(),
})
// Check that we received an `InferStreamResponse::End` message
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
{
Ok(InferResponse {
tokens: result_tokens,
generated_text,
queued,
start,
})
} else {
Err(InferError::IncompleteGeneration)
}
}
}
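With `Prefill` now carrying the raw `PrefillTokens`, the zip of the parallel `ids`, `logprobs` and `texts` vectors into `Token` values happens on the consumer side, as shown above. A standalone sketch of that zip-and-map step, using stand-in types rather than the real `PrefillTokens` and `Token`:

```rust
// Sketch: combining parallel id/logprob/text vectors into one token list,
// mirroring the fold done when a Prefill message is added to result_tokens.
// PrefillTokensDemo and TokenDemo are stand-ins for the real protobuf/router types.
struct PrefillTokensDemo {
    ids: Vec<u32>,
    logprobs: Vec<f32>,
    texts: Vec<String>,
}

#[derive(Debug)]
struct TokenDemo(u32, String, f32);

fn to_tokens(prefill: PrefillTokensDemo) -> Vec<TokenDemo> {
    prefill
        .ids
        .into_iter()
        .zip(prefill.logprobs)
        .zip(prefill.texts)
        .map(|((id, logprob), text)| TokenDemo(id, text, logprob))
        .collect()
}

fn main() {
    let tokens = to_tokens(PrefillTokensDemo {
        ids: vec![17, 42],
        logprobs: vec![-0.5, -1.25],
        texts: vec!["Hello".to_string(), " world".to_string()],
    });
    println!("{:?}", tokens);
}
```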
@@ -210,7 +233,7 @@ async fn batching_task(
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
future: impl Future<Output=Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
match future.await {
@@ -247,20 +270,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.expect("ID not found in entries. This is a bug.");
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs.into_iter())
.zip(prefill_tokens.texts.into_iter())
.map(|((id, logprob), text)| Token(id, text, logprob))
.collect();
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(tokens)))
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
.unwrap_or(());
}
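`send_generations` now forwards the `PrefillTokens` message as-is; the only responsibility kept here is pushing onto the unbounded channel while ignoring a receiver that has already gone away. A small sketch of that `.unwrap_or(())` send pattern, assuming tokio with the `sync`, `rt` and `macros` features:

```rust
// Sketch: sends on an unbounded channel never block, and a closed receiver is
// deliberately ignored instead of treated as an error.
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (response_tx, mut response_rx) = mpsc::unbounded_channel();

    // Normal case: the receiver is still listening.
    response_tx.send("prefill").unwrap_or(());
    assert_eq!(response_rx.recv().await, Some("prefill"));

    // Client went away: the send fails, but we do not care about the result.
    drop(response_rx);
    response_tx.send("late token").unwrap_or(());
}
```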
@@ -271,13 +285,6 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
generation.token_logprob,
);
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))
.unwrap_or(());
if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message
// We can `expect` here as the request id should always be in the entries
@@ -290,11 +297,19 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
entry
.response_tx
.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.time,
start: entry.batch_time.unwrap(),
}))
.unwrap_or(());
} else {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))
.unwrap_or(());
}
});
}
@@ -302,11 +317,12 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(Vec<Token>),
Prefill(PrefillTokens),
// Intermediate messages
Token(Token),
// Last message
End {
token: Token,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
@@ -330,4 +346,6 @@ pub enum InferError {
Overloaded(#[from] TryAcquireError),
#[error("Input validation error: {0}")]
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
}
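The new `IncompleteGeneration` variant slots into the existing `thiserror` enum; the `#[from]` attribute on `Overloaded` is what lets the `try_acquire_owned()?` call above convert a semaphore `TryAcquireError` automatically. A small sketch of both paths with a stand-in enum (only the "Incomplete generation" message comes from the diff; the other message is illustrative):

```rust
// Sketch: a thiserror enum shaped like InferError above. `#[from]` powers the
// `?` conversion from TryAcquireError; IncompleteGeneration is built explicitly
// when a stream finishes without an End message.
use std::sync::Arc;

use thiserror::Error;
use tokio::sync::{Semaphore, TryAcquireError};

#[derive(Debug, Error)]
enum DemoInferError {
    #[error("Model is overloaded")]
    Overloaded(#[from] TryAcquireError),
    #[error("Incomplete generation")]
    IncompleteGeneration,
}

fn try_reserve(semaphore: Arc<Semaphore>) -> Result<(), DemoInferError> {
    // `?` applies the #[from] conversion when no permit is available.
    let _permit = semaphore.try_acquire_owned()?;
    Ok(())
}

fn main() {
    let full = Arc::new(Semaphore::new(0));
    assert_eq!(
        try_reserve(full).unwrap_err().to_string(),
        "Model is overloaded"
    );
    assert_eq!(
        DemoInferError::IncompleteGeneration.to_string(),
        "Incomplete generation"
    );
}
```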

View File

@@ -87,6 +87,14 @@ pub(crate) struct GeneratedText {
pub details: Option<Details>,
}
#[derive(Serialize)]
pub(crate) struct StreamToken {
pub token: Token,
pub end: bool,
pub finish_reason: Option<String>,
pub generated_text: Option<String>,
}
#[derive(Serialize)]
pub(crate) struct ErrorResponse {
pub error: String,
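`StreamToken` is the per-event payload of the new streaming route: intermediate events carry only the token, while the final event flips `end` and fills in `finish_reason` and `generated_text`. A sketch of what such a payload serializes to, using stand-in types (the tuple layout of `Token` as id/text/logprob is taken from the router code above; the exact JSON of the real types may differ):

```rust
// Sketch: serializing a StreamToken-shaped final event with serde_json.
// TokenDemo and StreamTokenDemo are stand-ins for the router's Token/StreamToken.
use serde::Serialize;

#[derive(Serialize)]
struct TokenDemo(u32, String, f32);

#[derive(Serialize)]
struct StreamTokenDemo {
    token: TokenDemo,
    end: bool,
    finish_reason: Option<String>,
    generated_text: Option<String>,
}

fn main() {
    let last = StreamTokenDemo {
        token: TokenDemo(42, " world".to_string(), -0.5),
        end: true,
        finish_reason: Some("length".to_string()),
        generated_text: Some("Hello world".to_string()),
    };
    // Prints roughly:
    // {"token":[42," world",-0.5],"end":true,"finish_reason":"length","generated_text":"Hello world"}
    println!("{}", serde_json::to_string(&last).unwrap());
}
```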

View File

@@ -1,7 +1,8 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, Validation,
Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
@@ -10,6 +11,7 @@ use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use futures::Stream;
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
@@ -60,6 +62,7 @@ async fn generate(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
let start_time = Instant::now();
// Inference
@@ -111,12 +114,12 @@ async fn generate(
);
// Tracing metadata
tracing::Span::current().record("total_time", format!("{:?}", total_time));
tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
tracing::Span::current().record("seed", format!("{:?}", response.seed));
span.record("total_time", format!("{:?}", total_time));
span.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span.record("inference_time", format!("{:?}", inference_time));
span.record("time_per_token", format!("{:?}", time_per_token));
span.record("seed", format!("{:?}", response.seed));
tracing::info!("Output: {}", response.generated_text.text);
// Send response
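Both handlers now grab `tracing::Span::current()` once and record onto that handle instead of re-resolving the current span for every field. A minimal sketch of that pattern, assuming `tracing` 0.1.31+ (where `Span::record` takes values directly) and `tracing-subscriber`; the field names and the `Empty` declarations are illustrative, since the handlers' `#[instrument]` attributes are not part of this diff:

```rust
// Sketch: declare fields up front as Empty, capture the span handle once,
// and record the values when they become available.
#[tracing::instrument(fields(total_time = tracing::field::Empty, queue_time = tracing::field::Empty))]
fn handle_request() {
    let span = tracing::Span::current();
    // ... do the work, then record the timings once at the end ...
    span.record("total_time", "1.2ms");
    span.record("queue_time", "0.3ms");
    tracing::info!("Output: done");
}

fn main() {
    tracing_subscriber::fmt().init();
    handle_request();
}
```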
@@ -141,57 +144,97 @@ async fn generate(
async fn generate_stream(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, InferError>>> {
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let span = tracing::Span::current();
let start_time = Instant::now();
let stream = async_stream::stream! {
let start_time = Instant::now();
// Inference
let mut response_stream = infer.generate_stream(req.0).await?;
let mut end_reached = false;
let mut error = false;
// Server Side Event stream
while let Some(response) = response_stream.next().await {
match response {
Ok(response) => {
match infer.generate_stream(req.0).await {
Ok(mut response_stream) => {
// Server Side Event stream
while let Some(response) = response_stream.next().await {
match response {
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
yield Ok(Event::default().json_data(token).unwrap())
}
// End is used for timings metadata and logging
InferStreamResponse::End {
generated_text,
start,
queued,
} => {
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
Ok(response) => {
match response {
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
// StreamToken
let stream_token = StreamToken {
token,
end: end_reached,
finish_reason: None,
generated_text: None,
};
// Tracing metadata
tracing::Span::current().record("total_time", format!("{:?}", total_time));
tracing::Span::current()
.record("validation_time", format!("{:?}", validation_time));
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
tracing::Span::current()
.record("inference_time", format!("{:?}", inference_time));
tracing::Span::current()
.record("time_per_token", format!("{:?}", time_per_token));
tracing::info!("Output: {}", generated_text.text);
yield Ok(Event::default().json_data(stream_token).unwrap())
}
// End is used for timings metadata and logging
InferStreamResponse::End {
token,
generated_text,
start,
queued,
} => {
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span
.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span
.record("inference_time", format!("{:?}", inference_time));
span
.record("time_per_token", format!("{:?}", time_per_token));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// StreamToken
end_reached = true;
let stream_token = StreamToken {
token,
end: end_reached,
finish_reason: Some(generated_text.finish_reason),
generated_text: Some(generated_text.text),
};
yield Ok(Event::default().json_data(stream_token).unwrap())
}
}
}
// Trace and yield error
Err(err) => {
error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
}
}
// Trace and yield error
Err(err) => {
tracing::error!("{}", err.to_string());
yield Err(err);
}
},
// Trace and yield error
Err(err) => {
error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
}
// Check if generation reached the end
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
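The stream's item type is now `Result<Event, Infallible>`, so every outcome, including errors wrapped by the new `From<InferError> for Event` impl, is delivered as a regular SSE event instead of aborting the stream. A trimmed-down sketch of that shape, assuming axum 0.6 with default features, `async-stream`, `futures` and tokio (the `/demo` route and payloads are made up):

```rust
// Sketch: an SSE handler whose stream can never fail; anything worth telling
// the client, including errors, is serialized into the event payload itself.
use std::convert::Infallible;

use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::get;
use axum::Router;
use futures::Stream;

async fn demo_stream() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = async_stream::stream! {
        for i in 0..3u32 {
            // json_data serializes the payload into the event's data field.
            yield Ok(Event::default().json_data(i).unwrap());
        }
    };
    Sse::new(stream).keep_alive(KeepAlive::default())
}

#[tokio::main]
async fn main() {
    let app = Router::new().route("/demo", get(demo_stream));
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}
```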
@@ -264,13 +307,14 @@ async fn shutdown_signal() {
tracing::info!("signal received, starting graceful shutdown");
}
/// Convert to Axum supported format
/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {
let status_code = match err {
InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
};
(
@@ -281,3 +325,13 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
)
}
}
impl From<InferError> for Event {
fn from(err: InferError) -> Self {
Event::default()
.json_data(ErrorResponse {
error: err.to_string(),
})
.unwrap()
}
}