mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: process token stream before returning to client
This commit is contained in:
parent
6db3bcb700
commit
6def99d61b
@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
|
||||
use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use pyo3::types::IntoPyDict;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::File;
|
||||
@ -452,13 +453,27 @@ async fn generate_stream(
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
) {
|
||||
let span = tracing::Span::current();
|
||||
let on_message_callback = |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
event.json_data(stream_token).unwrap()
|
||||
};
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
generate_stream_internal(infer, compute_type, Json(req), span).await;
|
||||
|
||||
let final_response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
while let Some(raw_event) = response_stream.next().await {
|
||||
match raw_event {
|
||||
Ok(stream_token) => {
|
||||
let event = Event::default();
|
||||
let event = event.json_data(stream_token).unwrap();
|
||||
yield Ok(event);
|
||||
}
|
||||
Err(_err) => {
|
||||
let event = Event::default();
|
||||
yield Ok(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let sse = Sse::new(final_response_stream).keep_alive(KeepAlive::default());
|
||||
(headers, sse)
|
||||
}
|
||||
|
||||
@ -466,9 +481,11 @@ async fn generate_stream_internal(
|
||||
infer: Infer,
|
||||
ComputeType(compute_type): ComputeType,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||
span: tracing::Span,
|
||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||
) -> (
|
||||
HeaderMap,
|
||||
impl Stream<Item = Result<StreamResponse, InferError>>,
|
||||
) {
|
||||
let start_time = Instant::now();
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
@ -500,12 +517,12 @@ async fn generate_stream_internal(
|
||||
let err = InferError::from(ValidationError::BestOfStream);
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
} else if req.parameters.decoder_input_details {
|
||||
let err = InferError::from(ValidationError::PrefillDetailsStream);
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
} else {
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
// Keep permit as long as generate_stream lives
|
||||
@ -535,8 +552,7 @@ async fn generate_stream_internal(
|
||||
generated_text: None,
|
||||
details: None,
|
||||
};
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
yield Ok(stream_token);
|
||||
}
|
||||
// Yield event for last token and compute timings
|
||||
InferStreamResponse::End {
|
||||
@ -600,9 +616,7 @@ async fn generate_stream_internal(
|
||||
details
|
||||
};
|
||||
|
||||
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
yield Ok(stream_token);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -610,7 +624,7 @@ async fn generate_stream_internal(
|
||||
// yield error
|
||||
Err(err) => {
|
||||
error = true;
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -619,7 +633,7 @@ async fn generate_stream_internal(
|
||||
// yield error
|
||||
Err(err) => {
|
||||
error = true;
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
}
|
||||
}
|
||||
// Check if generation reached the end
|
||||
@ -628,7 +642,7 @@ async fn generate_stream_internal(
|
||||
let err = InferError::IncompleteGenerationStream;
|
||||
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -771,75 +785,88 @@ async fn completions(
|
||||
|
||||
// Create a future for each generate_stream_internal call.
|
||||
let generate_future = async move {
|
||||
let on_message_callback = move |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let message = match stream_token.details {
|
||||
Some(details) => {
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
|
||||
Completion::Final(CompletionFinal {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
None => Completion::Chunk(Chunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: String::new(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
event
|
||||
.json_data(message)
|
||||
.unwrap_or_else(|_e| Event::default())
|
||||
};
|
||||
|
||||
let (header_tx, header_rx) = oneshot::channel();
|
||||
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (header_map, sse) = generate_stream_internal(
|
||||
let (headers, response_stream) = generate_stream_internal(
|
||||
infer_clone.clone(),
|
||||
compute_type_clone.clone(),
|
||||
Json(generate_request),
|
||||
on_message_callback,
|
||||
span_clone.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let final_response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
|
||||
while let Some(stream_token) = response_stream.next().await {
|
||||
match stream_token {
|
||||
Ok(stream_token) => {
|
||||
let event = Event::default();
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let message = match stream_token.details {
|
||||
Some(details) => {
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
|
||||
Completion::Final(CompletionFinal {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
None => Completion::Chunk(Chunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: String::new(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
let event = event
|
||||
.json_data(message)
|
||||
.unwrap_or_else(|_e| Event::default());
|
||||
|
||||
yield Ok(event);
|
||||
}
|
||||
Err(_err) => {
|
||||
let event = Event::default();
|
||||
yield Ok(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// send and don't wait for response
|
||||
let _ = header_tx.send(header_map);
|
||||
let _ = header_tx.send(headers);
|
||||
|
||||
// pin and emit messages to the sse_tx
|
||||
let mut sse = Box::pin(sse);
|
||||
let mut sse = Box::pin(final_response_stream);
|
||||
while let Some(event) = sse.next().await {
|
||||
if sse_tx.send(event).is_err() {
|
||||
tracing::error!("Failed to send event. Receiver dropped.");
|
||||
@ -1072,6 +1099,84 @@ async fn completions(
|
||||
}
|
||||
}
|
||||
|
||||
/// Progress of the chat-completions token stream while tool output is being inspected.
///
/// The stream starts in [`StreamState::Buffering`] when tools are in use and in
/// [`StreamState::Content`] otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Tokens are being accumulated (not yet sent) until the generated function
    /// name can be extracted from the buffered JSON.
    Buffering,
    /// The function name was `notify_error`: the buffered tokens were discarded and
    /// incoming tokens are consumed until the `"error":"` key has been seen.
    BufferTrailing,
    /// Tokens are forwarded to the client as events.
    Content {
        /// When `true`, the stream stops at the first token containing a `"`
        /// instead of forwarding it.
        skip_close_quote: bool,
    },
}
|
||||
|
||||
/// Convert a StreamResponse into an Event to be sent over SSE
|
||||
fn create_event_from_stream_token(
|
||||
stream_token: &StreamResponse,
|
||||
logprobs: bool,
|
||||
stream_options: Option<StreamOptions>,
|
||||
inner_using_tools: bool,
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
) -> Event {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let logprobs = logprobs.then(|| {
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(content, None)
|
||||
};
|
||||
|
||||
let (usage, finish_reason) = match &stream_token.details {
|
||||
Some(details) => {
|
||||
let usage = if stream_options
|
||||
.as_ref()
|
||||
.map(|s| s.include_usage)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
Some(Usage {
|
||||
completion_tokens,
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(usage, Some(details.finish_reason.format(true)))
|
||||
}
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
finish_reason,
|
||||
usage,
|
||||
));
|
||||
|
||||
event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
Event::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
#[utoipa::path(
|
||||
post,
|
||||
@ -1128,90 +1233,160 @@ async fn chat_completions(
|
||||
// static values that will be returned in all cases
|
||||
let model_id = info.model_id.clone();
|
||||
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||
let send_function_name = false; // TODO: fix to send function name
|
||||
|
||||
// switch on stream
|
||||
if stream {
|
||||
// pass this callback to the stream generation and build the required event structure
|
||||
let on_message_callback = move |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let logprobs = logprobs.then(|| {
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if using_tools {
|
||||
(None, Some(vec![stream_token.token.text]))
|
||||
let final_response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
let mut buffer = Vec::new();
|
||||
let mut json_buffer = String::new();
|
||||
// let mut content_buffer = String::new();
|
||||
let mut state = if using_tools {
|
||||
StreamState::Buffering
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(content, None)
|
||||
};
|
||||
|
||||
let (usage, finish_reason) = match stream_token.details {
|
||||
Some(details) => {
|
||||
let usage = if stream_options
|
||||
.as_ref()
|
||||
.map(|s| s.include_usage)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
Some(Usage {
|
||||
completion_tokens,
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(usage, Some(details.finish_reason.format(true)))
|
||||
StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
}
|
||||
None => (None, None),
|
||||
};
|
||||
event
|
||||
.json_data(CompletionType::ChatCompletionChunk(
|
||||
ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
finish_reason,
|
||||
usage,
|
||||
),
|
||||
))
|
||||
.unwrap_or_else(|e| {
|
||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
Event::default()
|
||||
})
|
||||
let mut response_as_tool = using_tools;
|
||||
|
||||
// Regex to match any function name
|
||||
let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();
|
||||
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(stream_token) => {
|
||||
let token_text = &stream_token.token.text.clone();
|
||||
|
||||
match state {
|
||||
StreamState::Buffering => {
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
buffer.push(stream_token);
|
||||
|
||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||
let function_name = captures[1].to_string();
|
||||
if function_name == "notify_error" {
|
||||
state = StreamState::BufferTrailing;
|
||||
response_as_tool = false;
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
} else {
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
|
||||
if send_function_name {
|
||||
// send a message with the function name
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let chat_complete = CompletionType::ChatCompletionChunk(
|
||||
ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
None,
|
||||
Some(vec![function_name.clone()]),
|
||||
current_time,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
);
|
||||
|
||||
let event = event.json_data(chat_complete).unwrap();
|
||||
yield Ok(event);
|
||||
}
|
||||
|
||||
// send all the buffered messages
|
||||
for stream_token in &buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
||||
StreamState::BufferTrailing => {
|
||||
let infix_text = "\"error\":\"";
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
if !json_buffer.contains(infix_text) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let error_index = json_buffer.find(infix_text).unwrap();
|
||||
json_buffer =
|
||||
json_buffer[error_index + infix_text.len()..].to_string();
|
||||
|
||||
if json_buffer.is_empty() {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let chat_complete =
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
Some(json_buffer.clone()),
|
||||
None,
|
||||
current_time,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
));
|
||||
|
||||
let event = event.json_data(chat_complete).unwrap();
|
||||
yield Ok(event);
|
||||
}
|
||||
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: true,
|
||||
};
|
||||
}
|
||||
StreamState::Content { skip_close_quote } => {
|
||||
if skip_close_quote && token_text.contains('"') {
|
||||
break;
|
||||
}
|
||||
|
||||
// send the content
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_err) => {
|
||||
yield Ok::<Event, Infallible>(Event::default());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
||||
};
|
||||
|
||||
let (headers, response_stream) = generate_stream_internal(
|
||||
infer,
|
||||
compute_type,
|
||||
Json(generate_request),
|
||||
on_message_callback,
|
||||
span,
|
||||
)
|
||||
.await;
|
||||
|
||||
let response_stream = response_stream.chain(futures::stream::once(async {
|
||||
Ok(Event::default().data("[DONE]"))
|
||||
}));
|
||||
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
let sse = Sse::new(final_response_stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
let (headers, Json(generation)) =
|
||||
@ -1246,17 +1421,33 @@ async fn chat_completions(
|
||||
if let Value::Object(ref mut props) = arguments {
|
||||
props.remove("_name");
|
||||
}
|
||||
|
||||
let tool_calls = vec![ToolCall {
|
||||
id: "0".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name,
|
||||
arguments,
|
||||
},
|
||||
}];
|
||||
(Some(tool_calls), None)
|
||||
match name.as_str() {
|
||||
"notify_error" => {
|
||||
// parse the error message
|
||||
let error_message = arguments
|
||||
.get("error")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
InferError::ToolError(
|
||||
"No error message found in generated text".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_string();
|
||||
(None, Some(error_message))
|
||||
}
|
||||
_ => {
|
||||
let tool_calls = vec![ToolCall {
|
||||
id: "0".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name,
|
||||
arguments,
|
||||
},
|
||||
}];
|
||||
(Some(tool_calls), None)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(None, Some(generation.generated_text))
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user