mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Reworked the implementation.
This commit is contained in:
parent
22d205aa47
commit
df72deac26
@ -14,6 +14,7 @@ use chat_template::ChatTemplate;
|
|||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use minijinja::ErrorKind;
|
use minijinja::ErrorKind;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -373,4 +374,25 @@ impl InferError {
|
|||||||
InferError::StreamSerializationError(_) => "stream_serialization_error",
|
InferError::StreamSerializationError(_) => "stream_serialization_error",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn into_openai_event(self) -> Event {
|
||||||
|
let message = self.to_string();
|
||||||
|
Event::default().json_data(OpenaiErrorEvent {
|
||||||
|
error: APIError {
|
||||||
|
message,
|
||||||
|
http_status_code: 422,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct APIError {
|
||||||
|
message: String,
|
||||||
|
http_status_code: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct OpenaiErrorEvent {
|
||||||
|
error: APIError,
|
||||||
}
|
}
|
||||||
|
@ -866,14 +866,7 @@ pub(crate) async fn completions(
|
|||||||
|
|
||||||
yield Ok(event);
|
yield Ok(event);
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => yield Ok(err.into_openai_event()),
|
||||||
let event = Event::default()
|
|
||||||
.json_data(ErrorEvent::into_api_error(err, 422))
|
|
||||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
|
|
||||||
println!("{:?}", event);
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1281,8 +1274,8 @@ pub(crate) async fn chat_completions(
|
|||||||
};
|
};
|
||||||
let mut response_as_tool = using_tools;
|
let mut response_as_tool = using_tools;
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
match result {
|
match result{
|
||||||
Ok(stream_token) => {
|
Ok(stream_tokens) => {
|
||||||
let token_text = &stream_token.token.text.clone();
|
let token_text = &stream_token.token.text.clone();
|
||||||
match state {
|
match state {
|
||||||
StreamState::Buffering => {
|
StreamState::Buffering => {
|
||||||
@ -1361,6 +1354,7 @@ pub(crate) async fn chat_completions(
|
|||||||
if skip_close_quote && token_text.contains('"') {
|
if skip_close_quote && token_text.contains('"') {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// send the content
|
// send the content
|
||||||
let event = create_event_from_stream_token(
|
let event = create_event_from_stream_token(
|
||||||
&stream_token,
|
&stream_token,
|
||||||
@ -1374,14 +1368,8 @@ pub(crate) async fn chat_completions(
|
|||||||
yield Ok::<Event, Infallible>(event);
|
yield Ok::<Event, Infallible>(event);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
Err(err) => {
|
Err(err) => yield Event::from_openai(err)
|
||||||
let event = Event::default()
|
|
||||||
.json_data(ErrorEvent::into_api_error(err, 422))
|
|
||||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||||
@ -2532,28 +2520,6 @@ impl From<InferError> for Event {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
|
||||||
pub struct APIError {
|
|
||||||
message: String,
|
|
||||||
http_status_code: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
|
||||||
pub struct ErrorEvent {
|
|
||||||
error: APIError,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ErrorEvent {
|
|
||||||
fn into_api_error(err: InferError, http_status_code: usize) -> Self {
|
|
||||||
ErrorEvent {
|
|
||||||
error: APIError {
|
|
||||||
message: err.to_string(),
|
|
||||||
http_status_code,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum WebServerError {
|
pub enum WebServerError {
|
||||||
#[error("Axum error: {0}")]
|
#[error("Axum error: {0}")]
|
||||||
|
Loading…
Reference in New Issue
Block a user