mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Improved version
This commit is contained in:
parent
0b34905557
commit
046801278e
@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
float_eq = "1.0.1"
|
float_eq = "1.0.1"
|
||||||
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
||||||
serde = "1.0.150"
|
serde = { version = "1.0.150", features = ["derive"] }
|
||||||
|
@ -8,7 +8,7 @@ mod sharded_client;
|
|||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
pub use pb::generate::v1::{
|
pub use pb::generate::v1::{
|
||||||
Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
|
Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
|
||||||
StoppingCriteriaParameters,
|
StoppingCriteriaParameters, PrefillTokens
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
@ -10,6 +10,7 @@ use text_generation_client::{
|
|||||||
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use tokio::sync::mpsc::UnboundedSender;
|
use tokio::sync::mpsc::UnboundedSender;
|
||||||
|
use tokio::sync::OwnedSemaphorePermit;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
|
||||||
/// Database entry
|
/// Database entry
|
||||||
@ -25,6 +26,8 @@ pub(crate) struct Entry {
|
|||||||
pub time: Instant,
|
pub time: Instant,
|
||||||
/// Instant when this entry was added to a batch
|
/// Instant when this entry was added to a batch
|
||||||
pub batch_time: Option<Instant>,
|
pub batch_time: Option<Instant>,
|
||||||
|
/// Permit
|
||||||
|
pub _permit: OwnedSemaphorePermit,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Request Database
|
/// Request Database
|
||||||
|
@ -5,7 +5,7 @@ use crate::{Db, Entry, Token};
|
|||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
|
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
@ -22,12 +22,12 @@ pub struct Infer {
|
|||||||
db: Db,
|
db: Db,
|
||||||
/// Shared state
|
/// Shared state
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
|
/// Inference limit
|
||||||
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer shared state
|
/// Infer shared state
|
||||||
struct Shared {
|
struct Shared {
|
||||||
/// Inference limit
|
|
||||||
limit_concurrent_requests: Semaphore,
|
|
||||||
/// Batching background Tokio task notifier
|
/// Batching background Tokio task notifier
|
||||||
batching_task: Notify,
|
batching_task: Notify,
|
||||||
}
|
}
|
||||||
@ -43,7 +43,6 @@ impl Infer {
|
|||||||
// Infer shared state
|
// Infer shared state
|
||||||
let db = Db::new();
|
let db = Db::new();
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
|
|
||||||
batching_task: Notify::new(),
|
batching_task: Notify::new(),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -56,10 +55,14 @@ impl Infer {
|
|||||||
shared.clone(),
|
shared.clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Inference limit with a semaphore
|
||||||
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
db,
|
db,
|
||||||
shared,
|
shared,
|
||||||
|
limit_concurrent_requests: semaphore,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +72,8 @@ impl Infer {
|
|||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
|
// This permit will live as long as Entry
|
||||||
|
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
|
||||||
|
|
||||||
// Validate request
|
// Validate request
|
||||||
let (input_length, validated_request) = self.validation.validate(request).await?;
|
let (input_length, validated_request) = self.validation.validate(request).await?;
|
||||||
@ -84,6 +88,7 @@ impl Infer {
|
|||||||
input_length,
|
input_length,
|
||||||
time: Instant::now(),
|
time: Instant::now(),
|
||||||
batch_time: None,
|
batch_time: None,
|
||||||
|
_permit: permit
|
||||||
});
|
});
|
||||||
|
|
||||||
// Notify the background task that we have a new entry in the database that needs
|
// Notify the background task that we have a new entry in the database that needs
|
||||||
@ -113,30 +118,48 @@ impl Infer {
|
|||||||
match response? {
|
match response? {
|
||||||
// Add prefill tokens
|
// Add prefill tokens
|
||||||
InferStreamResponse::Prefill(prefill_tokens) => {
|
InferStreamResponse::Prefill(prefill_tokens) => {
|
||||||
result_tokens.extend(prefill_tokens)
|
// Create Token objects
|
||||||
|
// We do that here instead of in the Python code as Rust for loops are faster
|
||||||
|
let prefill_tokens = prefill_tokens
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(prefill_tokens.logprobs.into_iter())
|
||||||
|
.zip(prefill_tokens.texts.into_iter())
|
||||||
|
.map(|((id, logprob), text)| Token(id, text, logprob))
|
||||||
|
.collect();
|
||||||
|
result_tokens = prefill_tokens;
|
||||||
}
|
}
|
||||||
// Push last token
|
// Push last token
|
||||||
InferStreamResponse::Token(token) => result_tokens.push(token),
|
InferStreamResponse::Token(token) => result_tokens.push(token),
|
||||||
// Final message
|
// Final message
|
||||||
// Set return values
|
// Set return values
|
||||||
InferStreamResponse::End {
|
InferStreamResponse::End {
|
||||||
|
token,
|
||||||
generated_text,
|
generated_text,
|
||||||
start,
|
start,
|
||||||
queued,
|
queued,
|
||||||
} => {
|
} => {
|
||||||
|
result_tokens.push(token);
|
||||||
result_generated_text = Some(generated_text);
|
result_generated_text = Some(generated_text);
|
||||||
result_start = Some(start);
|
result_start = Some(start);
|
||||||
result_queued = Some(queued)
|
result_queued = Some(queued)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Unwrap is safe here
|
|
||||||
|
// Check that we received a `InferStreamResponse::End` message
|
||||||
|
if let (Some(generated_text), Some(queued), Some(start)) =
|
||||||
|
(result_generated_text, result_queued, result_start)
|
||||||
|
{
|
||||||
Ok(InferResponse {
|
Ok(InferResponse {
|
||||||
tokens: result_tokens,
|
tokens: result_tokens,
|
||||||
generated_text: result_generated_text.unwrap(),
|
generated_text,
|
||||||
queued: result_queued.unwrap(),
|
queued,
|
||||||
start: result_start.unwrap(),
|
start,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
Err(InferError::IncompleteGeneration)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,7 +233,7 @@ async fn batching_task(
|
|||||||
|
|
||||||
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
||||||
async fn wrap_future(
|
async fn wrap_future(
|
||||||
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
future: impl Future<Output=Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> Option<Batch> {
|
||||||
match future.await {
|
match future.await {
|
||||||
@ -247,20 +270,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||||||
.expect("ID not found in entries. This is a bug.");
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
// Create Token objects
|
|
||||||
// We do that here instead of in the Python code as Rust for loops are faster
|
|
||||||
let tokens = prefill_tokens
|
|
||||||
.ids
|
|
||||||
.into_iter()
|
|
||||||
.zip(prefill_tokens.logprobs.into_iter())
|
|
||||||
.zip(prefill_tokens.texts.into_iter())
|
|
||||||
.map(|((id, logprob), text)| Token(id, text, logprob))
|
|
||||||
.collect();
|
|
||||||
// Send message
|
// Send message
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
entry
|
entry
|
||||||
.response_tx
|
.response_tx
|
||||||
.send(Ok(InferStreamResponse::Prefill(tokens)))
|
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
|
||||||
.unwrap_or(());
|
.unwrap_or(());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,13 +285,6 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||||||
generation.token_logprob,
|
generation.token_logprob,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Send message
|
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
|
||||||
entry
|
|
||||||
.response_tx
|
|
||||||
.send(Ok(InferStreamResponse::Token(token)))
|
|
||||||
.unwrap_or(());
|
|
||||||
|
|
||||||
if let Some(generated_text) = generation.generated_text {
|
if let Some(generated_text) = generation.generated_text {
|
||||||
// Remove entry as this is the last message
|
// Remove entry as this is the last message
|
||||||
// We can `expect` here as the request id should always be in the entries
|
// We can `expect` here as the request id should always be in the entries
|
||||||
@ -290,11 +297,19 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||||||
entry
|
entry
|
||||||
.response_tx
|
.response_tx
|
||||||
.send(Ok(InferStreamResponse::End {
|
.send(Ok(InferStreamResponse::End {
|
||||||
|
token,
|
||||||
generated_text,
|
generated_text,
|
||||||
queued: entry.time,
|
queued: entry.time,
|
||||||
start: entry.batch_time.unwrap(),
|
start: entry.batch_time.unwrap(),
|
||||||
}))
|
}))
|
||||||
.unwrap_or(());
|
.unwrap_or(());
|
||||||
|
} else {
|
||||||
|
// Send message
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Token(token)))
|
||||||
|
.unwrap_or(());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -302,11 +317,12 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) enum InferStreamResponse {
|
pub(crate) enum InferStreamResponse {
|
||||||
// Optional first message
|
// Optional first message
|
||||||
Prefill(Vec<Token>),
|
Prefill(PrefillTokens),
|
||||||
// Intermediate messages
|
// Intermediate messages
|
||||||
Token(Token),
|
Token(Token),
|
||||||
// Last message
|
// Last message
|
||||||
End {
|
End {
|
||||||
|
token: Token,
|
||||||
generated_text: GeneratedText,
|
generated_text: GeneratedText,
|
||||||
start: Instant,
|
start: Instant,
|
||||||
queued: Instant,
|
queued: Instant,
|
||||||
@ -330,4 +346,6 @@ pub enum InferError {
|
|||||||
Overloaded(#[from] TryAcquireError),
|
Overloaded(#[from] TryAcquireError),
|
||||||
#[error("Input validation error: {0}")]
|
#[error("Input validation error: {0}")]
|
||||||
ValidationError(#[from] ValidationError),
|
ValidationError(#[from] ValidationError),
|
||||||
|
#[error("Incomplete generation")]
|
||||||
|
IncompleteGeneration,
|
||||||
}
|
}
|
||||||
|
@ -87,6 +87,14 @@ pub(crate) struct GeneratedText {
|
|||||||
pub details: Option<Details>,
|
pub details: Option<Details>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub(crate) struct StreamToken {
|
||||||
|
pub token: Token,
|
||||||
|
pub end: bool,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
pub generated_text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub(crate) struct ErrorResponse {
|
pub(crate) struct ErrorResponse {
|
||||||
pub error: String,
|
pub error: String,
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::infer::{InferError, InferStreamResponse};
|
use crate::infer::{InferError, InferStreamResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, Validation,
|
Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
|
||||||
|
Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
@ -10,6 +11,7 @@ use axum::response::IntoResponse;
|
|||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
@ -60,6 +62,7 @@ async fn generate(
|
|||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
@ -111,12 +114,12 @@ async fn generate(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Tracing metadata
|
// Tracing metadata
|
||||||
tracing::Span::current().record("total_time", format!("{:?}", total_time));
|
span.record("total_time", format!("{:?}", total_time));
|
||||||
tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
|
span.record("validation_time", format!("{:?}", validation_time));
|
||||||
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
|
span.record("queue_time", format!("{:?}", queue_time));
|
||||||
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
|
span.record("inference_time", format!("{:?}", inference_time));
|
||||||
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
|
span.record("time_per_token", format!("{:?}", time_per_token));
|
||||||
tracing::Span::current().record("seed", format!("{:?}", response.seed));
|
span.record("seed", format!("{:?}", response.seed));
|
||||||
tracing::info!("Output: {}", response.generated_text.text);
|
tracing::info!("Output: {}", response.generated_text.text);
|
||||||
|
|
||||||
// Send response
|
// Send response
|
||||||
@ -141,13 +144,17 @@ async fn generate(
|
|||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Sse<impl Stream<Item = Result<Event, InferError>>> {
|
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||||
let stream = async_stream::stream! {
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let stream = async_stream::stream! {
|
||||||
// Inference
|
// Inference
|
||||||
let mut response_stream = infer.generate_stream(req.0).await?;
|
let mut end_reached = false;
|
||||||
|
let mut error = false;
|
||||||
|
|
||||||
|
match infer.generate_stream(req.0).await {
|
||||||
|
Ok(mut response_stream) => {
|
||||||
// Server Side Event stream
|
// Server Side Event stream
|
||||||
while let Some(response) = response_stream.next().await {
|
while let Some(response) = response_stream.next().await {
|
||||||
match response {
|
match response {
|
||||||
@ -157,10 +164,19 @@ async fn generate_stream(
|
|||||||
InferStreamResponse::Prefill(_) => {}
|
InferStreamResponse::Prefill(_) => {}
|
||||||
// Yield event for every new token
|
// Yield event for every new token
|
||||||
InferStreamResponse::Token(token) => {
|
InferStreamResponse::Token(token) => {
|
||||||
yield Ok(Event::default().json_data(token).unwrap())
|
// StreamToken
|
||||||
|
let stream_token = StreamToken {
|
||||||
|
token,
|
||||||
|
end: end_reached,
|
||||||
|
finish_reason: None,
|
||||||
|
generated_text: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
yield Ok(Event::default().json_data(stream_token).unwrap())
|
||||||
}
|
}
|
||||||
// End is used for timings metadata and logging
|
// End is used for timings metadata and logging
|
||||||
InferStreamResponse::End {
|
InferStreamResponse::End {
|
||||||
|
token,
|
||||||
generated_text,
|
generated_text,
|
||||||
start,
|
start,
|
||||||
queued,
|
queued,
|
||||||
@ -173,25 +189,52 @@ async fn generate_stream(
|
|||||||
let time_per_token = inference_time / generated_text.generated_tokens;
|
let time_per_token = inference_time / generated_text.generated_tokens;
|
||||||
|
|
||||||
// Tracing metadata
|
// Tracing metadata
|
||||||
tracing::Span::current().record("total_time", format!("{:?}", total_time));
|
span.record("total_time", format!("{:?}", total_time));
|
||||||
tracing::Span::current()
|
span
|
||||||
.record("validation_time", format!("{:?}", validation_time));
|
.record("validation_time", format!("{:?}", validation_time));
|
||||||
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
|
span.record("queue_time", format!("{:?}", queue_time));
|
||||||
tracing::Span::current()
|
span
|
||||||
.record("inference_time", format!("{:?}", inference_time));
|
.record("inference_time", format!("{:?}", inference_time));
|
||||||
tracing::Span::current()
|
span
|
||||||
.record("time_per_token", format!("{:?}", time_per_token));
|
.record("time_per_token", format!("{:?}", time_per_token));
|
||||||
tracing::info!("Output: {}", generated_text.text);
|
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
||||||
|
|
||||||
|
// StreamToken
|
||||||
|
end_reached = true;
|
||||||
|
let stream_token = StreamToken {
|
||||||
|
token,
|
||||||
|
end: end_reached,
|
||||||
|
finish_reason: Some(generated_text.finish_reason),
|
||||||
|
generated_text: Some(generated_text.text),
|
||||||
|
};
|
||||||
|
|
||||||
|
yield Ok(Event::default().json_data(stream_token).unwrap())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Trace and yield error
|
// Trace and yield error
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
error = true;
|
||||||
tracing::error!("{}", err.to_string());
|
tracing::error!("{}", err.to_string());
|
||||||
yield Err(err);
|
yield Ok(Event::from(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
// Trace and yield error
|
||||||
|
Err(err) => {
|
||||||
|
error = true;
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
yield Ok(Event::from(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if generation reached the end
|
||||||
|
// Skip if we already sent an error
|
||||||
|
if !end_reached && !error {
|
||||||
|
let err = InferError::IncompleteGeneration;
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
yield Ok(Event::from(err))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||||
@ -264,13 +307,14 @@ async fn shutdown_signal() {
|
|||||||
tracing::info!("signal received, starting graceful shutdown");
|
tracing::info!("signal received, starting graceful shutdown");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert to Axum supported format
|
/// Convert to Axum supported formats
|
||||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
fn from(err: InferError) -> Self {
|
fn from(err: InferError) -> Self {
|
||||||
let status_code = match err {
|
let status_code = match err {
|
||||||
InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
|
InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
|
||||||
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
|
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
|
||||||
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
};
|
};
|
||||||
|
|
||||||
(
|
(
|
||||||
@ -281,3 +325,13 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<InferError> for Event {
|
||||||
|
fn from(err: InferError) -> Self {
|
||||||
|
Event::default()
|
||||||
|
.json_data(ErrorResponse {
|
||||||
|
error: err.to_string(),
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user