rust code cleanup

This commit is contained in:
OlivierDehaene 2023-01-28 09:31:37 +01:00
parent 48d095733a
commit 122c137b56
7 changed files with 191 additions and 170 deletions

View File

@@ -70,7 +70,7 @@ impl Client {
     /// Generate one token for each request in the given batch
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip(self))]
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
@@ -84,9 +84,9 @@ impl Client {
        Ok((response.generations, response.batch))
    }

-    /// Generate one token for each request in the given cached batch
+    /// Generate one token for each request in the given cached batches
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batches
     /// and the next cached batch
     #[instrument(skip(self))]
    pub async fn decode(

View File

@@ -37,9 +37,19 @@ impl ShardedClient {
        Self::from_master_client(master_client).await
    }

+    /// Clear the past generations cache
+    pub async fn clear_cache(&mut self) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.clear_cache())
+            .collect();
+        join_all(futures).await.into_iter().collect()
+    }
+
     /// Generate one token for each request in the given batch
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batch
     /// and the next cached batch
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
        let futures: Vec<_> = self
@@ -52,9 +62,9 @@ impl ShardedClient {
        result
    }

-    /// Generate one token for each request in the given cached batch
+    /// Generate one token for each request in the given cached batches
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batches
     /// and the next cached batch
    pub async fn decode(
        &mut self,
@@ -69,14 +79,4 @@ impl ShardedClient {
        let (result, _, _) = select_all(futures).await;
        result
    }
-
-    /// Clear the past generations cache
-    pub async fn clear_cache(&mut self) -> Result<()> {
-        let futures: Vec<_> = self
-            .clients
-            .iter_mut()
-            .map(|client| client.clear_cache())
-            .collect();
-        join_all(futures).await.into_iter().collect()
-    }
 }
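
Note: the relocated `clear_cache` above fans the call out to every shard and then collapses the per-shard results into a single `Result`. A minimal, self-contained sketch of that `join_all` + `collect` pattern follows; the `ShardError` type and `clear_one` helper are illustrative stand-ins, not part of the crate.

use futures::future::join_all;

#[derive(Debug)]
struct ShardError(String);

async fn clear_one(shard_id: usize, healthy: bool) -> Result<(), ShardError> {
    if healthy {
        Ok(())
    } else {
        Err(ShardError(format!("shard {shard_id} failed")))
    }
}

fn main() {
    // Fan the call out to every "shard", then collapse Vec<Result<()>> into Result<()>:
    // collect() over an iterator of Results stops at and returns the first Err.
    let result: Result<(), ShardError> = futures::executor::block_on(async {
        let futures: Vec<_> = (0..4).map(|i| clear_one(i, i != 2)).collect();
        join_all(futures).await.into_iter().collect()
    });

    println!("{result:?}"); // Err(ShardError("shard 2 failed"))
}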

View File

@@ -1,6 +1,6 @@
-use crate::batcher::InferError;
 /// This code is massively inspired by Tokio mini-redis
-use crate::batcher::InferStreamResponse;
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
 use crate::{GenerateParameters, GenerateRequest};
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use parking_lot::Mutex;
@@ -17,7 +17,7 @@ use tokio::time::Instant;
 pub(crate) struct Entry {
     /// Request
     pub request: GenerateRequest,
-    /// Response sender to communicate between the Batcher and the batching_task
+    /// Response sender to communicate between the Infer struct and the batching_task
     pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Number of tokens in the input
     pub input_length: usize,
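
Note: `Entry.response_tx` is the sending half of an unbounded MPSC channel; the batching task pushes `InferStreamResponse` messages into it and the HTTP side reads them back as a stream. A small stand-alone sketch of that wiring, where `DemoStreamResponse` is a made-up stand-in for `InferStreamResponse`:

use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

// Made-up stand-in for InferStreamResponse.
#[derive(Debug)]
enum DemoStreamResponse {
    Token(String),
    End,
}

#[tokio::main]
async fn main() {
    // One channel per request: the sender lives in the Entry, the receiver becomes the stream.
    let (response_tx, response_rx) = mpsc::unbounded_channel();

    // Stand-in for the batching task pushing messages for one request.
    tokio::spawn(async move {
        response_tx
            .send(DemoStreamResponse::Token("Hello".to_string()))
            .unwrap_or(());
        response_tx.send(DemoStreamResponse::End).unwrap_or(());
    });

    // Stand-in for the handler side consuming the receiver as a stream.
    let mut stream = UnboundedReceiverStream::new(response_rx);
    while let Some(msg) = stream.next().await {
        println!("{msg:?}");
    }
}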

View File

@@ -1,43 +1,49 @@
 /// Batching and inference logic
+use crate::validation::{Validation, ValidationError};
+use crate::GenerateRequest;
 use crate::{Db, Entry, Token};
-use crate::{ErrorResponse, GenerateRequest};
-use axum::http::StatusCode;
-use axum::Json;
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
 use thiserror::Error;
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
 use tracing::instrument;

-/// Batcher
+/// Inference struct
 #[derive(Clone)]
-pub struct Batcher {
+pub struct Infer {
+    /// Validation
+    validation: Validation,
     /// Request database
     db: Db,
     /// Shared state
     shared: Arc<Shared>,
 }

-/// Batcher shared state
+/// Infer shared state
 struct Shared {
+    /// Inference limit
+    limit_concurrent_requests: Semaphore,
     /// Batching background Tokio task notifier
     batching_task: Notify,
 }

-impl Batcher {
+impl Infer {
    pub(crate) fn new(
        client: ShardedClient,
+        validation: Validation,
        max_batch_size: usize,
        max_waiting_tokens: usize,
+        max_concurrent_requests: usize,
    ) -> Self {
-        // Batcher shared state
+        // Infer shared state
        let db = Db::new();
        let shared = Arc::new(Shared {
+            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
            batching_task: Notify::new(),
        });
@@ -50,21 +56,30 @@ impl Batcher {
            shared.clone(),
        ));

-        Self { db, shared }
+        Self {
+            validation,
+            db,
+            shared,
+        }
    }

-    /// Add a new request to the database and return a stream of tokens
-    pub(crate) fn infer_stream(
+    /// Add a new request to the database and return a stream of InferStreamResponse
+    pub(crate) async fn generate_stream(
        &self,
-        input_length: usize,
        request: GenerateRequest,
-    ) -> UnboundedReceiverStream<Result<InferStreamResponse, InferError>> {
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
+
+        // Validate request
+        let (input_length, validated_request) = self.validation.validate(request).await?;
+
        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = mpsc::unbounded_channel();

        // Try to append the request to the database
        self.db.append(Entry {
-            request,
+            request: validated_request,
            response_tx,
            input_length,
            time: Instant::now(),
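
Note: `generate_stream` in the hunk above now gates admission with a sized `tokio::sync::Semaphore`: `try_acquire` either hands back a permit or fails immediately instead of queueing. A minimal sketch of that pattern; `try_admit` is a hypothetical helper, not the crate's API.

use tokio::sync::{Semaphore, TryAcquireError};

// Hypothetical helper mirroring the admission check in Infer::generate_stream.
fn try_admit(limiter: &Semaphore) -> Result<(), TryAcquireError> {
    // `?` propagates TryAcquireError when no permit is free; Infer maps that
    // into InferError::Overloaded through the #[from] conversion.
    let _permit = limiter.try_acquire()?;
    // ... the request would run here, holding the permit until it is dropped ...
    Ok(())
}

fn main() {
    let limiter = Semaphore::new(1);

    // Hold the only permit: the next caller is rejected instead of queued.
    let held = limiter.try_acquire().unwrap();
    assert!(matches!(try_admit(&limiter), Err(TryAcquireError::NoPermits)));

    // Releasing the permit lets the next request through.
    drop(held);
    assert!(try_admit(&limiter).is_ok());
    println!("admission control works as expected");
}
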
@@ -76,27 +91,34 @@ impl Batcher {
        self.shared.batching_task.notify_one();

        // Return stream
-        UnboundedReceiverStream::new(response_rx)
+        Ok(UnboundedReceiverStream::new(response_rx))
    }

-    pub(crate) async fn infer(
+    /// Add a new request to the database and return a InferResponse
+    pub(crate) async fn generate(
        &self,
-        input_length: usize,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
-        let mut stream = self.infer_stream(input_length, request);
+        // Create stream
+        let mut stream = self.generate_stream(request).await?;

+        // Return values
        let mut result_tokens = Vec::new();
        let mut result_generated_text = None;
        let mut result_start = None;
        let mut result_queued = None;

+        // Iterate on stream
        while let Some(response) = stream.next().await {
            match response? {
+                // Add prefill tokens
                InferStreamResponse::Prefill(prefill_tokens) => {
                    result_tokens.extend(prefill_tokens)
                }
+                // Push last token
                InferStreamResponse::Token(token) => result_tokens.push(token),
+                // Final message
+                // Set return values
                InferStreamResponse::End {
                    generated_text,
                    start,
@@ -108,6 +130,7 @@ impl Batcher {
                }
            }
        }
+        // Unwrap is safe here
        Ok(InferResponse {
            tokens: result_tokens,
            generated_text: result_generated_text.unwrap(),
@@ -134,7 +157,7 @@ async fn batching_task(
    // Infinite loop
    loop {
-        // Wait for a notification from the Batcher struct
+        // Wait for a notification from the Infer struct
        shared.batching_task.notified().await;

        // Get the next batch from the DB
@@ -185,14 +208,14 @@ async fn batching_task(
    }
 }

-/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
+/// Wrap a future inside a match statement to handle errors and send the responses to Infer
 async fn wrap_future(
    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
    entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
    match future.await {
        Ok((generations, next_batch)) => {
-            send_generated(generations, entries);
+            send_generations(generations, entries);
            next_batch
        }
        // If we have an error, we discard the whole batch
@@ -203,7 +226,7 @@ async fn wrap_future(
    }
 }

-/// Send errors to the Batcher for all `entries`
+/// Send errors to Infer for all `entries`
 fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // unwrap_or is valid here as we don't care if the receiver is gone.
@@ -214,14 +237,18 @@ fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    });
 }

-/// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    generations.into_iter().for_each(|generation| {
+        // Get entry
+        // We can `expect` here as the request id should always be in the entries
        let entry = entries
            .get(&generation.request_id)
            .expect("ID not found in entries. This is a bug.");

        if let Some(prefill_tokens) = generation.prefill_tokens {
+            // Create Token objects
+            // We do that here instead of in the Python code as Rust for loops are faster
            let tokens = prefill_tokens
                .ids
                .into_iter()
@@ -229,27 +256,37 @@ fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>
                .zip(prefill_tokens.texts.into_iter())
                .map(|((id, logprob), text)| Token(id, text, logprob))
                .collect();
+
+            // Send message
+            // unwrap_or is valid here as we don't care if the receiver is gone.
            entry
                .response_tx
                .send(Ok(InferStreamResponse::Prefill(tokens)))
                .unwrap_or(());
        }

+        // Create last Token
        let token = Token(
            generation.token_id,
            generation.token_text,
            generation.token_logprob,
        );
+
+        // Send message
+        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry
            .response_tx
            .send(Ok(InferStreamResponse::Token(token)))
            .unwrap_or(());

        if let Some(generated_text) = generation.generated_text {
+            // Remove entry as this is the last message
+            // We can `expect` here as the request id should always be in the entries
            let entry = entries
                .remove(&generation.request_id)
                .expect("ID not found in entries. This is a bug.");
+
+            // Send message
+            // unwrap_or is valid here as we don't care if the receiver is gone.
            entry
                .response_tx
                .send(Ok(InferStreamResponse::End {
@@ -264,8 +301,11 @@ fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
+    // Optional first message
    Prefill(Vec<Token>),
+    // Intermediate messages
    Token(Token),
+    // Last message
    End {
        generated_text: GeneratedText,
        start: Instant,
@@ -286,18 +326,8 @@ pub(crate) struct InferResponse {
 pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
-}
-
-/// Convert to Axum supported format
-impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
-    fn from(err: InferError) -> Self {
-        match err {
-            InferError::GenerationError(_) => (
-                StatusCode::FAILED_DEPENDENCY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                }),
-            ),
-        }
-    }
+    #[error("Model is overloaded")]
+    Overloaded(#[from] TryAcquireError),
+    #[error("Input validation error: {0}")]
+    ValidationError(#[from] ValidationError),
 }
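
Note: the new variants lean on thiserror: `#[error(...)]` supplies the Display text and `#[from]` generates the `From` impls that let `?` lift `TryAcquireError` and `ValidationError` into `InferError`. A self-contained sketch of the same pattern with a made-up `DemoError`, using `std::num::ParseIntError` as the wrapped source error:

use thiserror::Error;

// Made-up error enum showing the same #[error] / #[from] pattern as InferError.
#[derive(Debug, Error)]
enum DemoError {
    #[error("Input validation error: {0}")]
    Validation(#[from] std::num::ParseIntError),
    #[error("Request failed during generation: {0}")]
    Generation(String),
}

fn parse_max_new_tokens(raw: &str) -> Result<u32, DemoError> {
    // `?` uses the From<ParseIntError> impl generated by #[from].
    Ok(raw.parse::<u32>()?)
}

fn main() {
    let err = parse_max_new_tokens("not-a-number").unwrap_err();
    // The Display text comes from the #[error(...)] attribute.
    println!("{err}"); // Input validation error: invalid digit found in string

    let other = DemoError::Generation("shard disconnected".to_string());
    println!("{other}"); // Request failed during generation: shard disconnected
}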

View File

@@ -1,11 +1,11 @@
 /// Text Generation Inference Webserver
-mod batcher;
 mod db;
+mod infer;
 pub mod server;
 mod validation;

-use batcher::Batcher;
 use db::{Db, Entry};
+use infer::Infer;
 use serde::{Deserialize, Serialize};
 use validation::Validation;

View File

@@ -1,75 +1,52 @@
-use crate::batcher::InferStreamResponse;
+/// HTTP Server logic
+use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
 use axum::response::IntoResponse;
 use axum::routing::{get, post};
-use axum::{BoxError, Json, Router};
+use axum::{Json, Router};
 use futures::Stream;
 use std::net::SocketAddr;
-use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tokio_stream::StreamExt;
 use tracing::instrument;

-// Server shared state
-#[derive(Clone)]
-struct ServerState {
-    validation: Validation,
-    batcher: Batcher,
-    limit_concurrent_requests: Arc<Semaphore>,
-}
-
 /// Health check method
-#[instrument(skip(state), fields(time, time_per_token))]
-async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+#[instrument(skip(infer))]
+async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    // be a bit too slow for a health check.
    // What we should do instead if check if the gRPC channels are still healthy.

-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
-
    // Send a small inference request
-    state
-        .batcher
-        .infer(
-            1,
-            GenerateRequest {
-                inputs: "liveness".to_string(),
-                parameters: GenerateParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    do_sample: false,
-                    max_new_tokens: 1,
-                    stop: vec![],
-                    details: false,
-                    seed: None,
-                },
-            },
-        )
+    infer
+        .generate(GenerateRequest {
+            inputs: "liveness".to_string(),
+            parameters: GenerateParameters {
+                temperature: 1.0,
+                top_k: 0,
+                top_p: 1.0,
+                do_sample: false,
+                max_new_tokens: 1,
+                stop: vec![],
+                details: false,
+                seed: None,
+            },
+        })
        .await?;
    Ok(())
 }

 /// Generate method
 #[instrument(
-    skip(state),
+    skip(infer),
    fields(
        total_time,
        validation_time,
@@ -80,38 +57,17 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
    )
 )]
 async fn generate(
-    state: Extension<ServerState>,
+    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let start_time = Instant::now();

-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        tracing::error!("Model is overloaded");
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
-
-    // Validate request
-    let details = req.0.parameters.details;
-    let (input_length, validated_request) =
-        state.validation.validate(req.0).await.map_err(|err| {
-            tracing::error!("{}", err.to_string());
-            err
-        })?;
-
    // Inference
-    let response = state
-        .batcher
-        .infer(input_length, validated_request)
-        .await
-        .map_err(|err| {
-            tracing::error!("{}", err.to_string());
-            err
-        })?;
+    let details = req.0.parameters.details;
+    let response = infer.generate(req.0).await.map_err(|err| {
+        tracing::error!("{}", err.to_string());
+        err
+    })?;

    // Token details
    let details = match details {
@@ -171,39 +127,68 @@ async fn generate(
    Ok((headers, Json(response)))
 }

+/// Generate stream method
+#[instrument(
+    skip(infer),
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token
+    )
+)]
 async fn generate_stream(
-    state: Extension<ServerState>,
+    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
-) -> Sse<impl Stream<Item = Result<Event, BoxError>>> {
+) -> Sse<impl Stream<Item = Result<Event, InferError>>> {
    let stream = async_stream::stream! {
-        // Limit concurrent requests by acquiring a permit from the semaphore
-        let _permit = state.limit_concurrent_requests.try_acquire().map_err(| err | {
-            tracing::error!("Model is overloaded");
-            err
-        })?;
-
-        // Validate request
-        let (input_length, validated_request) =
-            state.validation.validate(req.0).await.map_err(|err| {
-                tracing::error!("{}", err);
-                err
-            })?;
+        let start_time = Instant::now();

        // Inference
-        let mut response_stream = state
-            .batcher
-            .infer_stream(input_length, validated_request);
+        let mut response_stream = infer.generate_stream(req.0).await?;

+        // Server Side Event stream
        while let Some(response) = response_stream.next().await {
            match response {
                Ok(response) => {
-                    if let InferStreamResponse::Token(token) = response {
-                        yield Ok(Event::default().json_data(token).unwrap());
+                    match response {
+                        // Prefill is ignored
+                        InferStreamResponse::Prefill(_) => {}
+                        // Yield event for every new token
+                        InferStreamResponse::Token(token) => {
+                            yield Ok(Event::default().json_data(token).unwrap())
+                        }
+                        // End is used for timings metadata and logging
+                        InferStreamResponse::End {
+                            generated_text,
+                            start,
+                            queued,
+                        } => {
+                            // Timings
+                            let total_time = start_time.elapsed();
+                            let validation_time = queued - start_time;
+                            let queue_time = start - queued;
+                            let inference_time = Instant::now() - start;
+                            let time_per_token = inference_time / generated_text.generated_tokens;
+
+                            // Tracing metadata
+                            tracing::Span::current().record("total_time", format!("{:?}", total_time));
+                            tracing::Span::current()
+                                .record("validation_time", format!("{:?}", validation_time));
+                            tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
+                            tracing::Span::current()
+                                .record("inference_time", format!("{:?}", inference_time));
+                            tracing::Span::current()
+                                .record("time_per_token", format!("{:?}", time_per_token));
+                            tracing::info!("Output: {}", generated_text.text);
+                        }
                    }
                }
+                // Trace and yield error
                Err(err) => {
                    tracing::error!("{}", err.to_string());
-                    yield Ok(Event::default().data(err.to_string()));
+                    yield Err(err);
                }
            }
        }
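
Note: the handler above wraps an async_stream block in axum's `Sse` response, yielding one `Event` per token; the stream's error type only has to convert into axum's `BoxError`, which `InferError` gets from its thiserror-derived `std::error::Error` impl. A stripped-down, runnable sketch of that shape; `demo_stream` and the hard-coded token list are stand-ins for the real inference stream.

use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::get;
use axum::Router;
use futures::Stream;
use std::convert::Infallible;

// Hypothetical handler: same shape as generate_stream, minus inference and timings.
async fn demo_stream() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = async_stream::stream! {
        for token in ["Hello", " world", "!"] {
            // The real handler serializes a Token with Event::default().json_data(token).
            yield Ok(Event::default().data(token));
        }
    };
    Sse::new(stream).keep_alive(KeepAlive::default())
}

#[tokio::main]
async fn main() {
    let app = Router::new().route("/generate_stream", get(demo_stream));
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}
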
@@ -225,13 +210,14 @@ pub async fn run(
    addr: SocketAddr,
 ) {
    // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
-    let shared_state = ServerState {
+    let infer = Infer::new(
+        client,
        validation,
-        batcher,
-        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
-    };
+        max_batch_size,
+        max_waiting_tokens,
+        max_concurrent_requests,
+    );

    // Create router
    let app = Router::new()
@@ -240,7 +226,7 @@ pub async fn run(
        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
-        .layer(Extension(shared_state.clone()));
+        .layer(Extension(infer));

    // Run server
    axum::Server::bind(&addr)
@@ -277,3 +263,21 @@ async fn shutdown_signal() {
    tracing::info!("signal received, starting graceful shutdown");
 }
+
+/// Convert to Axum supported format
+impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
+    fn from(err: InferError) -> Self {
+        let status_code = match err {
+            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
+            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
+            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+        };
+
+        (
+            status_code,
+            Json(ErrorResponse {
+                error: err.to_string(),
+            }),
+        )
+    }
+}
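
Note: centralizing the error mapping in this `From` impl is what lets the handlers keep returning `Result<_, (StatusCode, Json<ErrorResponse>)>` and rely on `?` (or a bare `map_err`) to produce the HTTP response. A compact sketch of the mechanism with simplified stand-in types (this `InferError` and `ErrorResponse` are reduced versions, not the router's full definitions):

use axum::http::StatusCode;
use axum::Json;
use serde::Serialize;

// Simplified stand-ins for the router's types.
#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}

#[derive(Debug)]
enum InferError {
    Overloaded,
}

impl std::fmt::Display for InferError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Model is overloaded")
    }
}

impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        (
            StatusCode::TOO_MANY_REQUESTS,
            Json(ErrorResponse {
                error: err.to_string(),
            }),
        )
    }
}

fn do_inference() -> Result<String, InferError> {
    Err(InferError::Overloaded)
}

// Axum-style handler: `?` inserts the From conversion into the error tuple,
// and axum renders (StatusCode, Json<T>) as an HTTP response.
async fn handler() -> Result<String, (StatusCode, Json<ErrorResponse>)> {
    let output = do_inference()?;
    Ok(output)
}

#[tokio::main]
async fn main() {
    match handler().await {
        Ok(text) => println!("{text}"),
        Err((status, Json(body))) => println!("{status}: {}", body.error),
    }
}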

View File

@@ -1,7 +1,5 @@
 /// Payload validation logic
-use crate::{ErrorResponse, GenerateRequest};
-use axum::http::StatusCode;
-use axum::Json;
+use crate::GenerateRequest;
 use rand::rngs::ThreadRng;
 use rand::Rng;
 use thiserror::Error;
@@ -172,14 +170,3 @@ pub enum ValidationError {
    #[error("tokenizer error {0}")]
    Tokenizer(String),
 }
-
-impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
-    fn from(err: ValidationError) -> Self {
-        (
-            StatusCode::UNPROCESSABLE_ENTITY,
-            Json(ErrorResponse {
-                error: err.to_string(),
-            }),
-        )
-    }
-}