mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Reworked the implementation.
This commit is contained in:
parent
22d205aa47
commit
df72deac26
@ -14,6 +14,7 @@ use chat_template::ChatTemplate;
|
||||
use futures::future::try_join_all;
|
||||
use futures::Stream;
|
||||
use minijinja::ErrorKind;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
@ -373,4 +374,25 @@ impl InferError {
|
||||
InferError::StreamSerializationError(_) => "stream_serialization_error",
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_openai_event(self) -> Event {
|
||||
let message = self.to_string();
|
||||
Event::default().json_data(OpenaiErrorEvent {
|
||||
error: APIError {
|
||||
message,
|
||||
http_status_code: 422,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct APIError {
|
||||
message: String,
|
||||
http_status_code: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct OpenaiErrorEvent {
|
||||
error: APIError,
|
||||
}
|
||||
|
@ -866,14 +866,7 @@ pub(crate) async fn completions(
|
||||
|
||||
yield Ok(event);
|
||||
}
|
||||
Err(err) => {
|
||||
let event = Event::default()
|
||||
.json_data(ErrorEvent::into_api_error(err, 422))
|
||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
|
||||
println!("{:?}", event);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
break
|
||||
}
|
||||
Err(err) => yield Ok(err.into_openai_event()),
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -1281,107 +1274,102 @@ pub(crate) async fn chat_completions(
|
||||
};
|
||||
let mut response_as_tool = using_tools;
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(stream_token) => {
|
||||
let token_text = &stream_token.token.text.clone();
|
||||
match state {
|
||||
StreamState::Buffering => {
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
buffer.push(stream_token);
|
||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||
let function_name = captures[1].to_string();
|
||||
if function_name == "no_tool" {
|
||||
state = StreamState::BufferTrailing;
|
||||
response_as_tool = false;
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
} else {
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
// send all the buffered messages
|
||||
for stream_token in &buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
||||
StreamState::BufferTrailing => {
|
||||
let infix_text = "\"content\":\"";
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
// keep capturing until we find the infix text
|
||||
match json_buffer.find(infix_text) {
|
||||
Some(content_key_index) => {
|
||||
json_buffer =
|
||||
json_buffer[content_key_index + infix_text.len()..].to_string();
|
||||
}
|
||||
None => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// if there is leftover text after removing the infix text, we need to send it
|
||||
if !json_buffer.is_empty() {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
let chat_complete =
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
match result{
|
||||
Ok(stream_tokens) => {
|
||||
let token_text = &stream_token.token.text.clone();
|
||||
match state {
|
||||
StreamState::Buffering => {
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
buffer.push(stream_token);
|
||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||
let function_name = captures[1].to_string();
|
||||
if function_name == "no_tool" {
|
||||
state = StreamState::BufferTrailing;
|
||||
response_as_tool = false;
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
} else {
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
// send all the buffered messages
|
||||
for stream_token in &buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
Some(json_buffer.clone()),
|
||||
None,
|
||||
current_time,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
));
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
}));
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
// cleanup the buffers
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: true,
|
||||
};
|
||||
}
|
||||
StreamState::Content { skip_close_quote } => {
|
||||
if skip_close_quote && token_text.contains('"') {
|
||||
break;
|
||||
}
|
||||
// send the content
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
||||
StreamState::BufferTrailing => {
|
||||
let infix_text = "\"content\":\"";
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
// keep capturing until we find the infix text
|
||||
match json_buffer.find(infix_text) {
|
||||
Some(content_key_index) => {
|
||||
json_buffer =
|
||||
json_buffer[content_key_index + infix_text.len()..].to_string();
|
||||
}
|
||||
None => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// if there is leftover text after removing the infix text, we need to send it
|
||||
if !json_buffer.is_empty() {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
let chat_complete =
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
Some(json_buffer.clone()),
|
||||
None,
|
||||
current_time,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
));
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
}));
|
||||
}
|
||||
// cleanup the buffers
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: true,
|
||||
};
|
||||
}
|
||||
StreamState::Content { skip_close_quote } => {
|
||||
if skip_close_quote && token_text.contains('"') {
|
||||
break;
|
||||
}
|
||||
|
||||
// send the content
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let event = Event::default()
|
||||
.json_data(ErrorEvent::into_api_error(err, 422))
|
||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
break;
|
||||
}
|
||||
},
|
||||
Err(err) => yield Event::from_openai(err)
|
||||
}
|
||||
}
|
||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||
@ -2532,28 +2520,6 @@ impl From<InferError> for Event {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct APIError {
|
||||
message: String,
|
||||
http_status_code: usize,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct ErrorEvent {
|
||||
error: APIError,
|
||||
}
|
||||
|
||||
impl ErrorEvent {
|
||||
fn into_api_error(err: InferError, http_status_code: usize) -> Self {
|
||||
ErrorEvent {
|
||||
error: APIError {
|
||||
message: err.to_string(),
|
||||
http_status_code,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WebServerError {
|
||||
#[error("Axum error: {0}")]
|
||||
|
Loading…
Reference in New Issue
Block a user