mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: return streaming errors as an event formatted for openai's client
This commit is contained in:
parent
34a3bdedc3
commit
84cd8434b0
@ -1274,99 +1274,108 @@ pub(crate) async fn chat_completions(
|
|||||||
};
|
};
|
||||||
let mut response_as_tool = using_tools;
|
let mut response_as_tool = using_tools;
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
if let Ok(stream_token) = result {
|
match result {
|
||||||
let token_text = &stream_token.token.text.clone();
|
Ok(stream_token) => {
|
||||||
match state {
|
let token_text = &stream_token.token.text.clone();
|
||||||
StreamState::Buffering => {
|
match state {
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
StreamState::Buffering => {
|
||||||
buffer.push(stream_token);
|
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
buffer.push(stream_token);
|
||||||
let function_name = captures[1].to_string();
|
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||||
if function_name == "no_tool" {
|
let function_name = captures[1].to_string();
|
||||||
state = StreamState::BufferTrailing;
|
if function_name == "no_tool" {
|
||||||
response_as_tool = false;
|
state = StreamState::BufferTrailing;
|
||||||
buffer.clear();
|
response_as_tool = false;
|
||||||
json_buffer.clear();
|
buffer.clear();
|
||||||
} else {
|
json_buffer.clear();
|
||||||
state = StreamState::Content {
|
} else {
|
||||||
skip_close_quote: false,
|
state = StreamState::Content {
|
||||||
};
|
skip_close_quote: false,
|
||||||
// send all the buffered messages
|
};
|
||||||
for stream_token in &buffer {
|
// send all the buffered messages
|
||||||
let event = create_event_from_stream_token(
|
for stream_token in &buffer {
|
||||||
stream_token,
|
let event = create_event_from_stream_token(
|
||||||
logprobs,
|
stream_token,
|
||||||
stream_options.clone(),
|
logprobs,
|
||||||
response_as_tool,
|
stream_options.clone(),
|
||||||
system_fingerprint.clone(),
|
response_as_tool,
|
||||||
model_id.clone(),
|
system_fingerprint.clone(),
|
||||||
);
|
model_id.clone(),
|
||||||
yield Ok::<Event, Infallible>(event);
|
);
|
||||||
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
||||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
StreamState::BufferTrailing => {
|
||||||
StreamState::BufferTrailing => {
|
let infix_text = "\"content\":\"";
|
||||||
let infix_text = "\"content\":\"";
|
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
// keep capturing until we find the infix text
|
||||||
// keep capturing until we find the infix text
|
match json_buffer.find(infix_text) {
|
||||||
match json_buffer.find(infix_text) {
|
Some(content_key_index) => {
|
||||||
Some(content_key_index) => {
|
json_buffer =
|
||||||
json_buffer =
|
json_buffer[content_key_index + infix_text.len()..].to_string();
|
||||||
json_buffer[content_key_index + infix_text.len()..].to_string();
|
}
|
||||||
|
None => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
None => {
|
// if there is leftover text after removing the infix text, we need to send it
|
||||||
continue;
|
if !json_buffer.is_empty() {
|
||||||
|
let event = Event::default();
|
||||||
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
|
.as_secs();
|
||||||
|
let chat_complete =
|
||||||
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||||
|
model_id.clone(),
|
||||||
|
system_fingerprint.clone(),
|
||||||
|
Some(json_buffer.clone()),
|
||||||
|
None,
|
||||||
|
current_time,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
));
|
||||||
|
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||||
|
InferError::StreamSerializationError(e.to_string()).into()
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
// cleanup the buffers
|
||||||
|
buffer.clear();
|
||||||
|
json_buffer.clear();
|
||||||
|
state = StreamState::Content {
|
||||||
|
skip_close_quote: true,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
// if there is leftover text after removing the infix text, we need to send it
|
StreamState::Content { skip_close_quote } => {
|
||||||
if !json_buffer.is_empty() {
|
if skip_close_quote && token_text.contains('"') {
|
||||||
let event = Event::default();
|
break;
|
||||||
let current_time = std::time::SystemTime::now()
|
}
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
// send the content
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
let event = create_event_from_stream_token(
|
||||||
.as_secs();
|
&stream_token,
|
||||||
let chat_complete =
|
logprobs,
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
stream_options.clone(),
|
||||||
model_id.clone(),
|
response_as_tool,
|
||||||
system_fingerprint.clone(),
|
system_fingerprint.clone(),
|
||||||
Some(json_buffer.clone()),
|
model_id.clone(),
|
||||||
None,
|
);
|
||||||
current_time,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
|
||||||
InferError::StreamSerializationError(e.to_string()).into()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
// cleanup the buffers
|
|
||||||
buffer.clear();
|
|
||||||
json_buffer.clear();
|
|
||||||
state = StreamState::Content {
|
|
||||||
skip_close_quote: true,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
StreamState::Content { skip_close_quote } => {
|
|
||||||
if skip_close_quote && token_text.contains('"') {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the content
|
yield Ok::<Event, Infallible>(event);
|
||||||
let event = create_event_from_stream_token(
|
}
|
||||||
&stream_token,
|
|
||||||
logprobs,
|
|
||||||
stream_options.clone(),
|
|
||||||
response_as_tool,
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
model_id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(err) => {
|
||||||
|
let error_event: ErrorEvent = err.into();
|
||||||
|
let event = Event::default().json_data(error_event).unwrap_or_else(|e| {
|
||||||
|
InferError::StreamSerializationError(e.to_string()).into()
|
||||||
|
});
|
||||||
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||||
@ -2517,6 +2526,26 @@ impl From<InferError> for Event {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
pub struct ErrorWithMessage {
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
pub struct ErrorEvent {
|
||||||
|
error: ErrorWithMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InferError> for ErrorEvent {
|
||||||
|
fn from(err: InferError) -> Self {
|
||||||
|
ErrorEvent {
|
||||||
|
error: ErrorWithMessage {
|
||||||
|
message: err.to_string(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum WebServerError {
|
pub enum WebServerError {
|
||||||
#[error("Axum error: {0}")]
|
#[error("Axum error: {0}")]
|
||||||
|
Loading…
Reference in New Issue
Block a user