mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
rust code cleanup
This commit is contained in:
parent
48d095733a
commit
122c137b56
@ -70,7 +70,7 @@ impl Client {
|
|||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
///
|
///
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
/// Returns Generation for each request in batch
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
@ -84,9 +84,9 @@ impl Client {
|
|||||||
Ok((response.generations, response.batch))
|
Ok((response.generations, response.batch))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given cached batch
|
/// Generate one token for each request in the given cached batches
|
||||||
///
|
///
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
/// Returns Generation for each request in batches
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn decode(
|
pub async fn decode(
|
||||||
|
@ -37,9 +37,19 @@ impl ShardedClient {
|
|||||||
Self::from_master_client(master_client).await
|
Self::from_master_client(master_client).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
pub async fn clear_cache(&mut self) -> Result<()> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.clear_cache())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.into_iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
///
|
///
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
/// Returns Generation for each request in batch
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
@ -52,9 +62,9 @@ impl ShardedClient {
|
|||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given cached batch
|
/// Generate one token for each request in the given cached batches
|
||||||
///
|
///
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
/// Returns Generation for each request in batches
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
pub async fn decode(
|
pub async fn decode(
|
||||||
&mut self,
|
&mut self,
|
||||||
@ -69,14 +79,4 @@ impl ShardedClient {
|
|||||||
let (result, _, _) = select_all(futures).await;
|
let (result, _, _) = select_all(futures).await;
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Clear the past generations cache
|
|
||||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
|
||||||
let futures: Vec<_> = self
|
|
||||||
.clients
|
|
||||||
.iter_mut()
|
|
||||||
.map(|client| client.clear_cache())
|
|
||||||
.collect();
|
|
||||||
join_all(futures).await.into_iter().collect()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::batcher::InferError;
|
|
||||||
/// This code is massively inspired by Tokio mini-redis
|
/// This code is massively inspired by Tokio mini-redis
|
||||||
use crate::batcher::InferStreamResponse;
|
use crate::infer::InferError;
|
||||||
|
use crate::infer::InferStreamResponse;
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
@ -17,7 +17,7 @@ use tokio::time::Instant;
|
|||||||
pub(crate) struct Entry {
|
pub(crate) struct Entry {
|
||||||
/// Request
|
/// Request
|
||||||
pub request: GenerateRequest,
|
pub request: GenerateRequest,
|
||||||
/// Response sender to communicate between the Batcher and the batching_task
|
/// Response sender to communicate between the Infer struct and the batching_task
|
||||||
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||||
/// Number of tokens in the input
|
/// Number of tokens in the input
|
||||||
pub input_length: usize,
|
pub input_length: usize,
|
||||||
|
@ -1,43 +1,49 @@
|
|||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
|
use crate::validation::{Validation, ValidationError};
|
||||||
|
use crate::GenerateRequest;
|
||||||
use crate::{Db, Entry, Token};
|
use crate::{Db, Entry, Token};
|
||||||
use crate::{ErrorResponse, GenerateRequest};
|
|
||||||
use axum::http::StatusCode;
|
|
||||||
use axum::Json;
|
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
|
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{mpsc, Notify};
|
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
/// Batcher
|
/// Inference struct
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Batcher {
|
pub struct Infer {
|
||||||
|
/// Validation
|
||||||
|
validation: Validation,
|
||||||
/// Request database
|
/// Request database
|
||||||
db: Db,
|
db: Db,
|
||||||
/// Shared state
|
/// Shared state
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Batcher shared state
|
/// Infer shared state
|
||||||
struct Shared {
|
struct Shared {
|
||||||
|
/// Inference limit
|
||||||
|
limit_concurrent_requests: Semaphore,
|
||||||
/// Batching background Tokio task notifier
|
/// Batching background Tokio task notifier
|
||||||
batching_task: Notify,
|
batching_task: Notify,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Batcher {
|
impl Infer {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
|
validation: Validation,
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Batcher shared state
|
// Infer shared state
|
||||||
let db = Db::new();
|
let db = Db::new();
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
|
limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
|
||||||
batching_task: Notify::new(),
|
batching_task: Notify::new(),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -50,21 +56,30 @@ impl Batcher {
|
|||||||
shared.clone(),
|
shared.clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
Self { db, shared }
|
Self {
|
||||||
|
validation,
|
||||||
|
db,
|
||||||
|
shared,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new request to the database and return a stream of tokens
|
/// Add a new request to the database and return a stream of InferStreamResponse
|
||||||
pub(crate) fn infer_stream(
|
pub(crate) async fn generate_stream(
|
||||||
&self,
|
&self,
|
||||||
input_length: usize,
|
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> UnboundedReceiverStream<Result<InferStreamResponse, InferError>> {
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
|
let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
|
||||||
|
|
||||||
|
// Validate request
|
||||||
|
let (input_length, validated_request) = self.validation.validate(request).await?;
|
||||||
|
|
||||||
// MPSC channel to communicate with the background batching task
|
// MPSC channel to communicate with the background batching task
|
||||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
// Try to append the request to the database
|
// Try to append the request to the database
|
||||||
self.db.append(Entry {
|
self.db.append(Entry {
|
||||||
request,
|
request: validated_request,
|
||||||
response_tx,
|
response_tx,
|
||||||
input_length,
|
input_length,
|
||||||
time: Instant::now(),
|
time: Instant::now(),
|
||||||
@ -76,27 +91,34 @@ impl Batcher {
|
|||||||
self.shared.batching_task.notify_one();
|
self.shared.batching_task.notify_one();
|
||||||
|
|
||||||
// Return stream
|
// Return stream
|
||||||
UnboundedReceiverStream::new(response_rx)
|
Ok(UnboundedReceiverStream::new(response_rx))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn infer(
|
/// Add a new request to the database and return a InferResponse
|
||||||
|
pub(crate) async fn generate(
|
||||||
&self,
|
&self,
|
||||||
input_length: usize,
|
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<InferResponse, InferError> {
|
) -> Result<InferResponse, InferError> {
|
||||||
let mut stream = self.infer_stream(input_length, request);
|
// Create stream
|
||||||
|
let mut stream = self.generate_stream(request).await?;
|
||||||
|
|
||||||
|
// Return values
|
||||||
let mut result_tokens = Vec::new();
|
let mut result_tokens = Vec::new();
|
||||||
let mut result_generated_text = None;
|
let mut result_generated_text = None;
|
||||||
let mut result_start = None;
|
let mut result_start = None;
|
||||||
let mut result_queued = None;
|
let mut result_queued = None;
|
||||||
|
|
||||||
|
// Iterate on stream
|
||||||
while let Some(response) = stream.next().await {
|
while let Some(response) = stream.next().await {
|
||||||
match response? {
|
match response? {
|
||||||
|
// Add prefill tokens
|
||||||
InferStreamResponse::Prefill(prefill_tokens) => {
|
InferStreamResponse::Prefill(prefill_tokens) => {
|
||||||
result_tokens.extend(prefill_tokens)
|
result_tokens.extend(prefill_tokens)
|
||||||
}
|
}
|
||||||
|
// Push last token
|
||||||
InferStreamResponse::Token(token) => result_tokens.push(token),
|
InferStreamResponse::Token(token) => result_tokens.push(token),
|
||||||
|
// Final message
|
||||||
|
// Set return values
|
||||||
InferStreamResponse::End {
|
InferStreamResponse::End {
|
||||||
generated_text,
|
generated_text,
|
||||||
start,
|
start,
|
||||||
@ -108,6 +130,7 @@ impl Batcher {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Unwrap is safe here
|
||||||
Ok(InferResponse {
|
Ok(InferResponse {
|
||||||
tokens: result_tokens,
|
tokens: result_tokens,
|
||||||
generated_text: result_generated_text.unwrap(),
|
generated_text: result_generated_text.unwrap(),
|
||||||
@ -134,7 +157,7 @@ async fn batching_task(
|
|||||||
|
|
||||||
// Infinite loop
|
// Infinite loop
|
||||||
loop {
|
loop {
|
||||||
// Wait for a notification from the Batcher struct
|
// Wait for a notification from the Infer struct
|
||||||
shared.batching_task.notified().await;
|
shared.batching_task.notified().await;
|
||||||
|
|
||||||
// Get the next batch from the DB
|
// Get the next batch from the DB
|
||||||
@ -185,14 +208,14 @@ async fn batching_task(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
|
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
||||||
async fn wrap_future(
|
async fn wrap_future(
|
||||||
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> Option<Batch> {
|
||||||
match future.await {
|
match future.await {
|
||||||
Ok((generations, next_batch)) => {
|
Ok((generations, next_batch)) => {
|
||||||
send_generated(generations, entries);
|
send_generations(generations, entries);
|
||||||
next_batch
|
next_batch
|
||||||
}
|
}
|
||||||
// If we have an error, we discard the whole batch
|
// If we have an error, we discard the whole batch
|
||||||
@ -203,7 +226,7 @@ async fn wrap_future(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send errors to the Batcher for all `entries`
|
/// Send errors to Infer for all `entries`
|
||||||
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
entries.drain().for_each(|(_, entry)| {
|
entries.drain().for_each(|(_, entry)| {
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
@ -214,14 +237,18 @@ fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send `generated_text` to the Batcher for all `finished`
|
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||||
fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||||
generations.into_iter().for_each(|generation| {
|
generations.into_iter().for_each(|generation| {
|
||||||
|
// Get entry
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
let entry = entries
|
let entry = entries
|
||||||
.get(&generation.request_id)
|
.get(&generation.request_id)
|
||||||
.expect("ID not found in entries. This is a bug.");
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
|
// Create Token objects
|
||||||
|
// We do that here instead of in the Python code as Rust for loops are faster
|
||||||
let tokens = prefill_tokens
|
let tokens = prefill_tokens
|
||||||
.ids
|
.ids
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@ -229,27 +256,37 @@ fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>
|
|||||||
.zip(prefill_tokens.texts.into_iter())
|
.zip(prefill_tokens.texts.into_iter())
|
||||||
.map(|((id, logprob), text)| Token(id, text, logprob))
|
.map(|((id, logprob), text)| Token(id, text, logprob))
|
||||||
.collect();
|
.collect();
|
||||||
|
// Send message
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
entry
|
entry
|
||||||
.response_tx
|
.response_tx
|
||||||
.send(Ok(InferStreamResponse::Prefill(tokens)))
|
.send(Ok(InferStreamResponse::Prefill(tokens)))
|
||||||
.unwrap_or(());
|
.unwrap_or(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create last Token
|
||||||
let token = Token(
|
let token = Token(
|
||||||
generation.token_id,
|
generation.token_id,
|
||||||
generation.token_text,
|
generation.token_text,
|
||||||
generation.token_logprob,
|
generation.token_logprob,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
entry
|
entry
|
||||||
.response_tx
|
.response_tx
|
||||||
.send(Ok(InferStreamResponse::Token(token)))
|
.send(Ok(InferStreamResponse::Token(token)))
|
||||||
.unwrap_or(());
|
.unwrap_or(());
|
||||||
|
|
||||||
if let Some(generated_text) = generation.generated_text {
|
if let Some(generated_text) = generation.generated_text {
|
||||||
|
// Remove entry as this is the last message
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
let entry = entries
|
let entry = entries
|
||||||
.remove(&generation.request_id)
|
.remove(&generation.request_id)
|
||||||
.expect("ID not found in entries. This is a bug.");
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
entry
|
entry
|
||||||
.response_tx
|
.response_tx
|
||||||
.send(Ok(InferStreamResponse::End {
|
.send(Ok(InferStreamResponse::End {
|
||||||
@ -264,8 +301,11 @@ fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) enum InferStreamResponse {
|
pub(crate) enum InferStreamResponse {
|
||||||
|
// Optional first message
|
||||||
Prefill(Vec<Token>),
|
Prefill(Vec<Token>),
|
||||||
|
// Intermediate messages
|
||||||
Token(Token),
|
Token(Token),
|
||||||
|
// Last message
|
||||||
End {
|
End {
|
||||||
generated_text: GeneratedText,
|
generated_text: GeneratedText,
|
||||||
start: Instant,
|
start: Instant,
|
||||||
@ -286,18 +326,8 @@ pub(crate) struct InferResponse {
|
|||||||
pub enum InferError {
|
pub enum InferError {
|
||||||
#[error("Request failed during generation: {0}")]
|
#[error("Request failed during generation: {0}")]
|
||||||
GenerationError(String),
|
GenerationError(String),
|
||||||
}
|
#[error("Model is overloaded")]
|
||||||
|
Overloaded(#[from] TryAcquireError),
|
||||||
/// Convert to Axum supported format
|
#[error("Input validation error: {0}")]
|
||||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
ValidationError(#[from] ValidationError),
|
||||||
fn from(err: InferError) -> Self {
|
|
||||||
match err {
|
|
||||||
InferError::GenerationError(_) => (
|
|
||||||
StatusCode::FAILED_DEPENDENCY,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: err.to_string(),
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -1,11 +1,11 @@
|
|||||||
/// Text Generation Inference Webserver
|
/// Text Generation Inference Webserver
|
||||||
mod batcher;
|
|
||||||
mod db;
|
mod db;
|
||||||
|
mod infer;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
mod validation;
|
mod validation;
|
||||||
|
|
||||||
use batcher::Batcher;
|
|
||||||
use db::{Db, Entry};
|
use db::{Db, Entry};
|
||||||
|
use infer::Infer;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
|
@ -1,75 +1,52 @@
|
|||||||
use crate::batcher::InferStreamResponse;
|
/// HTTP Server logic
|
||||||
|
use crate::infer::{InferError, InferStreamResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
|
Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{BoxError, Json, Router};
|
use axum::{Json, Router};
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tokio::sync::Semaphore;
|
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
// Server shared state
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct ServerState {
|
|
||||||
validation: Validation,
|
|
||||||
batcher: Batcher,
|
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Health check method
|
/// Health check method
|
||||||
#[instrument(skip(state), fields(time, time_per_token))]
|
#[instrument(skip(infer))]
|
||||||
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
// be a bit too slow for a health check.
|
// be a bit too slow for a health check.
|
||||||
// What we should do instead if check if the gRPC channels are still healthy.
|
// What we should do instead if check if the gRPC channels are still healthy.
|
||||||
|
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
||||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
|
||||||
(
|
|
||||||
StatusCode::TOO_MANY_REQUESTS,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Model is overloaded".to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Send a small inference request
|
// Send a small inference request
|
||||||
state
|
infer
|
||||||
.batcher
|
.generate(GenerateRequest {
|
||||||
.infer(
|
inputs: "liveness".to_string(),
|
||||||
1,
|
parameters: GenerateParameters {
|
||||||
GenerateRequest {
|
temperature: 1.0,
|
||||||
inputs: "liveness".to_string(),
|
top_k: 0,
|
||||||
parameters: GenerateParameters {
|
top_p: 1.0,
|
||||||
temperature: 1.0,
|
do_sample: false,
|
||||||
top_k: 0,
|
max_new_tokens: 1,
|
||||||
top_p: 1.0,
|
stop: vec![],
|
||||||
do_sample: false,
|
details: false,
|
||||||
max_new_tokens: 1,
|
seed: None,
|
||||||
stop: vec![],
|
|
||||||
details: false,
|
|
||||||
seed: None,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
})
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate method
|
/// Generate method
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(state),
|
skip(infer),
|
||||||
fields(
|
fields(
|
||||||
total_time,
|
total_time,
|
||||||
validation_time,
|
validation_time,
|
||||||
@ -80,38 +57,17 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
async fn generate(
|
async fn generate(
|
||||||
state: Extension<ServerState>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
||||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
|
||||||
tracing::error!("Model is overloaded");
|
|
||||||
(
|
|
||||||
StatusCode::TOO_MANY_REQUESTS,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Model is overloaded".to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Validate request
|
|
||||||
let details = req.0.parameters.details;
|
|
||||||
let (input_length, validated_request) =
|
|
||||||
state.validation.validate(req.0).await.map_err(|err| {
|
|
||||||
tracing::error!("{}", err.to_string());
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
let response = state
|
let details = req.0.parameters.details;
|
||||||
.batcher
|
let response = infer.generate(req.0).await.map_err(|err| {
|
||||||
.infer(input_length, validated_request)
|
tracing::error!("{}", err.to_string());
|
||||||
.await
|
err
|
||||||
.map_err(|err| {
|
})?;
|
||||||
tracing::error!("{}", err.to_string());
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Token details
|
// Token details
|
||||||
let details = match details {
|
let details = match details {
|
||||||
@ -171,39 +127,68 @@ async fn generate(
|
|||||||
Ok((headers, Json(response)))
|
Ok((headers, Json(response)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate stream method
|
||||||
|
#[instrument(
|
||||||
|
skip(infer),
|
||||||
|
fields(
|
||||||
|
total_time,
|
||||||
|
validation_time,
|
||||||
|
queue_time,
|
||||||
|
inference_time,
|
||||||
|
time_per_token
|
||||||
|
)
|
||||||
|
)]
|
||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
state: Extension<ServerState>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Sse<impl Stream<Item = Result<Event, BoxError>>> {
|
) -> Sse<impl Stream<Item = Result<Event, InferError>>> {
|
||||||
let stream = async_stream::stream! {
|
let stream = async_stream::stream! {
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
let start_time = Instant::now();
|
||||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(| err | {
|
|
||||||
tracing::error!("Model is overloaded");
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Validate request
|
|
||||||
let (input_length, validated_request) =
|
|
||||||
state.validation.validate(req.0).await.map_err(|err| {
|
|
||||||
tracing::error!("{}", err);
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
let mut response_stream = state
|
let mut response_stream = infer.generate_stream(req.0).await?;
|
||||||
.batcher
|
|
||||||
.infer_stream(input_length, validated_request);
|
|
||||||
|
|
||||||
|
// Server Side Event stream
|
||||||
while let Some(response) = response_stream.next().await {
|
while let Some(response) = response_stream.next().await {
|
||||||
match response {
|
match response {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
if let InferStreamResponse::Token(token) = response {
|
match response {
|
||||||
yield Ok(Event::default().json_data(token).unwrap());
|
// Prefill is ignored
|
||||||
|
InferStreamResponse::Prefill(_) => {}
|
||||||
|
// Yield event for every new token
|
||||||
|
InferStreamResponse::Token(token) => {
|
||||||
|
yield Ok(Event::default().json_data(token).unwrap())
|
||||||
|
}
|
||||||
|
// End is used for timings metadata and logging
|
||||||
|
InferStreamResponse::End {
|
||||||
|
generated_text,
|
||||||
|
start,
|
||||||
|
queued,
|
||||||
|
} => {
|
||||||
|
// Timings
|
||||||
|
let total_time = start_time.elapsed();
|
||||||
|
let validation_time = queued - start_time;
|
||||||
|
let queue_time = start - queued;
|
||||||
|
let inference_time = Instant::now() - start;
|
||||||
|
let time_per_token = inference_time / generated_text.generated_tokens;
|
||||||
|
|
||||||
|
// Tracing metadata
|
||||||
|
tracing::Span::current().record("total_time", format!("{:?}", total_time));
|
||||||
|
tracing::Span::current()
|
||||||
|
.record("validation_time", format!("{:?}", validation_time));
|
||||||
|
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
|
||||||
|
tracing::Span::current()
|
||||||
|
.record("inference_time", format!("{:?}", inference_time));
|
||||||
|
tracing::Span::current()
|
||||||
|
.record("time_per_token", format!("{:?}", time_per_token));
|
||||||
|
tracing::info!("Output: {}", generated_text.text);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Trace and yield error
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
tracing::error!("{}", err.to_string());
|
tracing::error!("{}", err.to_string());
|
||||||
yield Ok(Event::default().data(err.to_string()));
|
yield Err(err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -225,13 +210,14 @@ pub async fn run(
|
|||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
) {
|
) {
|
||||||
// Create state
|
// Create state
|
||||||
let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
|
|
||||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||||
let shared_state = ServerState {
|
let infer = Infer::new(
|
||||||
|
client,
|
||||||
validation,
|
validation,
|
||||||
batcher,
|
max_batch_size,
|
||||||
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
|
max_waiting_tokens,
|
||||||
};
|
max_concurrent_requests,
|
||||||
|
);
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
@ -240,7 +226,7 @@ pub async fn run(
|
|||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.layer(Extension(shared_state.clone()));
|
.layer(Extension(infer));
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
axum::Server::bind(&addr)
|
axum::Server::bind(&addr)
|
||||||
@ -277,3 +263,21 @@ async fn shutdown_signal() {
|
|||||||
|
|
||||||
tracing::info!("signal received, starting graceful shutdown");
|
tracing::info!("signal received, starting graceful shutdown");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert to Axum supported format
|
||||||
|
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
|
fn from(err: InferError) -> Self {
|
||||||
|
let status_code = match err {
|
||||||
|
InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
|
||||||
|
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
};
|
||||||
|
|
||||||
|
(
|
||||||
|
status_code,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: err.to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::{ErrorResponse, GenerateRequest};
|
use crate::GenerateRequest;
|
||||||
use axum::http::StatusCode;
|
|
||||||
use axum::Json;
|
|
||||||
use rand::rngs::ThreadRng;
|
use rand::rngs::ThreadRng;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -172,14 +170,3 @@ pub enum ValidationError {
|
|||||||
#[error("tokenizer error {0}")]
|
#[error("tokenizer error {0}")]
|
||||||
Tokenizer(String),
|
Tokenizer(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
|
|
||||||
fn from(err: ValidationError) -> Self {
|
|
||||||
(
|
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: err.to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user