mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
feat: consolidate streaming and event creation logic and add tests for streaming generations
This commit is contained in:
parent
330f2e419f
commit
efb20054aa
@ -754,59 +754,6 @@ pub(crate) struct Function {
|
|||||||
pub arguments: String,
|
pub arguments: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
impl ChatCompletionChunk {
|
|
||||||
pub(crate) fn new(
|
|
||||||
model: String,
|
|
||||||
system_fingerprint: String,
|
|
||||||
delta: Option<String>,
|
|
||||||
tool_calls: Option<Vec<String>>,
|
|
||||||
created: u64,
|
|
||||||
logprobs: Option<ChatCompletionLogprobs>,
|
|
||||||
finish_reason: Option<String>,
|
|
||||||
usage: Option<Usage>,
|
|
||||||
tool_name: Option<String>,
|
|
||||||
) -> Self {
|
|
||||||
let delta = match (delta, tool_calls) {
|
|
||||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: delta,
|
|
||||||
..Default::default()
|
|
||||||
}),
|
|
||||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
tool_calls: vec![DeltaToolCall {
|
|
||||||
index: 0,
|
|
||||||
id: String::new(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: Function {
|
|
||||||
name: tool_name,
|
|
||||||
arguments: tool_calls[0].to_string(),
|
|
||||||
},
|
|
||||||
}],
|
|
||||||
}),
|
|
||||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: "".to_string(),
|
|
||||||
..Default::default()
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
Self {
|
|
||||||
id: String::new(),
|
|
||||||
created,
|
|
||||||
model,
|
|
||||||
system_fingerprint,
|
|
||||||
choices: vec![ChatCompletionChoice {
|
|
||||||
index: 0,
|
|
||||||
delta,
|
|
||||||
logprobs,
|
|
||||||
finish_reason,
|
|
||||||
}],
|
|
||||||
usage,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq, Default))]
|
#[cfg_attr(test, derive(Debug, PartialEq, Default))]
|
||||||
pub(crate) struct ChatRequest {
|
pub(crate) struct ChatRequest {
|
||||||
@ -1021,7 +968,7 @@ impl ChatRequest {
|
|||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
struct StreamOptions {
|
pub(crate) struct StreamOptions {
|
||||||
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
||||||
#[schema(example = "true")]
|
#[schema(example = "true")]
|
||||||
include_usage: bool,
|
include_usage: bool,
|
||||||
@ -1844,9 +1791,306 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_event(
|
||||||
|
token_text: &str,
|
||||||
|
model_id: &str,
|
||||||
|
system_fingerprint: &str,
|
||||||
|
tool_name: Option<&str>,
|
||||||
|
is_tool_arg: bool,
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
) -> Event {
|
||||||
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
// Create the delta based on direct pattern matching of parameters
|
||||||
|
let (delta, finish) = match (tool_name, is_tool_arg, finish_reason) {
|
||||||
|
(Some(name), _, _) => (
|
||||||
|
// Tool call name event
|
||||||
|
ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
tool_calls: vec![DeltaToolCall {
|
||||||
|
index: 0,
|
||||||
|
id: String::new(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: Some(name.to_string()),
|
||||||
|
arguments: String::new(),
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(None, true, _) => (
|
||||||
|
// Tool call argument event
|
||||||
|
ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
tool_calls: vec![DeltaToolCall {
|
||||||
|
index: 0,
|
||||||
|
id: String::new(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: None,
|
||||||
|
arguments: token_text.to_string(),
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(None, false, reason) => (
|
||||||
|
// Regular text event
|
||||||
|
ChatCompletionDelta::Chat(TextMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: token_text.to_string(),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
reason,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create the ChatCompletionChunk with the appropriate delta
|
||||||
|
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||||
|
id: String::new(),
|
||||||
|
created: current_time,
|
||||||
|
model: model_id.to_string(),
|
||||||
|
system_fingerprint: system_fingerprint.to_string(),
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
delta,
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: finish,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
Event::default()
|
||||||
|
.json_data(chat_complete)
|
||||||
|
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||||
|
pub struct ParseFunction {
|
||||||
|
#[serde(rename = "_name")]
|
||||||
|
name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||||
|
pub struct ToolDecision {
|
||||||
|
#[serde(rename = "function")]
|
||||||
|
function: ParseFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||||
|
pub struct NoToolDecision {
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
use axum::response::sse::Event;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::convert::Infallible;
|
||||||
|
|
||||||
|
fn get_timestamp() -> u64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process stream token into events and mutates the buffers to keep track of the current state
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
async fn process_stream_token(
|
||||||
|
token_text: String,
|
||||||
|
json_buffer: &mut String,
|
||||||
|
name_found: &mut bool,
|
||||||
|
no_tool_chosen: &mut bool,
|
||||||
|
first_quote_removed: &mut bool,
|
||||||
|
using_tools: bool,
|
||||||
|
model_id: &str,
|
||||||
|
system_fingerprint: &str,
|
||||||
|
stream_token: &StreamResponse,
|
||||||
|
stream_options: Option<&StreamOptions>,
|
||||||
|
) -> (Vec<Result<Event, Infallible>>, bool) {
|
||||||
|
let mut events = Vec::new();
|
||||||
|
let mut should_break = false;
|
||||||
|
|
||||||
|
// Get usage information
|
||||||
|
let usage = stream_token.details.as_ref().map(|d| Usage {
|
||||||
|
completion_tokens: d.generated_tokens,
|
||||||
|
prompt_tokens: d.input_length,
|
||||||
|
total_tokens: d.input_length + d.generated_tokens,
|
||||||
|
});
|
||||||
|
|
||||||
|
json_buffer.push_str(&token_text);
|
||||||
|
|
||||||
|
// Phase 1: Function name discovery
|
||||||
|
if !*name_found {
|
||||||
|
// NOTE: when tools are supplied `name_found` is false until the generated buffer contains
|
||||||
|
// a partial JSON object with $.function._name value. This name determines the type
|
||||||
|
// of events to emit. If the name is "no_tool", we'll emit the "content" field as a chat
|
||||||
|
// completion event. Otherwise, we'll emit a tool call name event followed by a tool call
|
||||||
|
// argument event. In both cases we'll buffer tokens to get the name and then reset the buffer
|
||||||
|
// to collect the arguments.
|
||||||
|
if let Ok(ParseResult {
|
||||||
|
value: ToolDecision {
|
||||||
|
function: ParseFunction { name },
|
||||||
|
},
|
||||||
|
last_value_whole,
|
||||||
|
}) = parse_partial_json(json_buffer)
|
||||||
|
{
|
||||||
|
if !last_value_whole {
|
||||||
|
return (events, should_break);
|
||||||
|
}
|
||||||
|
*name_found = true;
|
||||||
|
if name == "no_tool" {
|
||||||
|
*no_tool_chosen = true;
|
||||||
|
} else {
|
||||||
|
events.push(Ok(create_event(
|
||||||
|
&token_text,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
Some(name.as_str()),
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
events.push(Ok(create_event(
|
||||||
|
"{",
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
None,
|
||||||
|
true,
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset buffer for arguments
|
||||||
|
json_buffer.clear();
|
||||||
|
json_buffer.push('{');
|
||||||
|
}
|
||||||
|
|
||||||
|
return (events, should_break);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: Content processing
|
||||||
|
let is_complete_json = json_buffer.ends_with('}')
|
||||||
|
&& serde_json::from_str::<Value>(&json_buffer[..json_buffer.len() - 1]).is_ok();
|
||||||
|
let mut edited_token = token_text;
|
||||||
|
|
||||||
|
// Handle different flows based on context
|
||||||
|
if using_tools {
|
||||||
|
if *no_tool_chosen && !is_complete_json {
|
||||||
|
// Content-only flow
|
||||||
|
if let Ok(ParseResult {
|
||||||
|
value: _,
|
||||||
|
last_value_whole,
|
||||||
|
}) = parse_partial_json::<NoToolDecision>(json_buffer)
|
||||||
|
{
|
||||||
|
let cleaned_token = if !*first_quote_removed {
|
||||||
|
// trim start until the first quote
|
||||||
|
*first_quote_removed = true;
|
||||||
|
edited_token
|
||||||
|
.trim_start()
|
||||||
|
.strip_prefix('"')
|
||||||
|
.unwrap_or(&edited_token)
|
||||||
|
.to_string()
|
||||||
|
} else if last_value_whole {
|
||||||
|
should_break = true;
|
||||||
|
// trim end until the last quote
|
||||||
|
edited_token
|
||||||
|
.trim_end()
|
||||||
|
.strip_suffix('"')
|
||||||
|
.unwrap_or(&edited_token)
|
||||||
|
.to_string()
|
||||||
|
} else {
|
||||||
|
edited_token.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
if !cleaned_token.is_empty() {
|
||||||
|
events.push(Ok(create_event(
|
||||||
|
&cleaned_token,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Tool with arguments flow
|
||||||
|
if is_complete_json {
|
||||||
|
edited_token.truncate(edited_token.len() - 1);
|
||||||
|
should_break = true;
|
||||||
|
}
|
||||||
|
events.push(Ok(create_event(
|
||||||
|
&edited_token,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
None,
|
||||||
|
true,
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Standard chat completion flow
|
||||||
|
if let Some(details) = stream_token.details.as_ref() {
|
||||||
|
let finish_reason = details.finish_reason.format(true);
|
||||||
|
let text = if details.finish_reason == FinishReason::Length {
|
||||||
|
&edited_token
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
events.push(Ok(create_event(
|
||||||
|
text,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
Some(finish_reason),
|
||||||
|
)));
|
||||||
|
should_break = true;
|
||||||
|
} else {
|
||||||
|
events.push(Ok(create_event(
|
||||||
|
&edited_token,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit usage data when requested
|
||||||
|
if let (Some(usage_data), true) = (
|
||||||
|
usage,
|
||||||
|
stream_options.as_ref().is_some_and(|o| o.include_usage),
|
||||||
|
) {
|
||||||
|
let current_time = get_timestamp();
|
||||||
|
|
||||||
|
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||||
|
id: String::new(),
|
||||||
|
created: current_time,
|
||||||
|
model: model_id.to_string(),
|
||||||
|
system_fingerprint: system_fingerprint.to_string(),
|
||||||
|
choices: vec![],
|
||||||
|
usage: Some(usage_data),
|
||||||
|
});
|
||||||
|
|
||||||
|
events.push(Ok(Event::default()
|
||||||
|
.json_data(chat_complete)
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
InferError::StreamSerializationError(e.to_string()).into()
|
||||||
|
})));
|
||||||
|
}
|
||||||
|
|
||||||
|
(events, should_break)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tool_streaming_tests {
|
mod tool_streaming_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
// Test json balancing and completion
|
// Test json balancing and completion
|
||||||
#[test]
|
#[test]
|
||||||
@ -1971,7 +2215,7 @@ mod tool_streaming_tests {
|
|||||||
|
|
||||||
// Function decision
|
// Function decision
|
||||||
let result = parse_partial_json::<ToolDecision>(&json_buffers[0]);
|
let result = parse_partial_json::<ToolDecision>(&json_buffers[0]);
|
||||||
println!("{:?}", result);
|
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
assert_eq!(result.unwrap().value.function.name, "no_tool");
|
assert_eq!(result.unwrap().value.function.name, "no_tool");
|
||||||
|
|
||||||
@ -1990,7 +2234,7 @@ mod tool_streaming_tests {
|
|||||||
|
|
||||||
// Function decision
|
// Function decision
|
||||||
let result = parse_partial_json::<ToolDecision>(&json_buffers[0]);
|
let result = parse_partial_json::<ToolDecision>(&json_buffers[0]);
|
||||||
println!("{:?}", result);
|
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
assert_eq!(result.unwrap().value.function.name, "no_tool");
|
assert_eq!(result.unwrap().value.function.name, "no_tool");
|
||||||
|
|
||||||
@ -2053,4 +2297,205 @@ mod tool_streaming_tests {
|
|||||||
assert_eq!(parsed.value.as_object().unwrap()["format"], "fahrenheit");
|
assert_eq!(parsed.value.as_object().unwrap()["format"], "fahrenheit");
|
||||||
assert_eq!(parsed.last_value_whole, true);
|
assert_eq!(parsed.last_value_whole, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streaming_no_tool_decision() {
|
||||||
|
let tokens_to_stream = vec![
|
||||||
|
"{\"".to_string(),
|
||||||
|
"function".to_string(),
|
||||||
|
"\":".to_string(),
|
||||||
|
" {\"".to_string(),
|
||||||
|
"_".to_string(),
|
||||||
|
"name".to_string(),
|
||||||
|
"\":".to_string(),
|
||||||
|
" \"".to_string(),
|
||||||
|
"no".to_string(),
|
||||||
|
"_tool".to_string(),
|
||||||
|
"\",".to_string(),
|
||||||
|
" \"".to_string(),
|
||||||
|
"content".to_string(),
|
||||||
|
"\":".to_string(),
|
||||||
|
" \"".to_string(),
|
||||||
|
"I".to_string(), // Event 1
|
||||||
|
" am".to_string(), // Event 2
|
||||||
|
" a".to_string(), // Event 3
|
||||||
|
" helpful".to_string(), // Event 4
|
||||||
|
" assistant".to_string(), // Event 5
|
||||||
|
"!\"".to_string(), // Event 6 (with trailing quore removed)
|
||||||
|
];
|
||||||
|
|
||||||
|
let event_to_stream = tokens_to_stream
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, token)| StreamResponse {
|
||||||
|
index: i as u32,
|
||||||
|
token: Token {
|
||||||
|
id: 0,
|
||||||
|
text: token,
|
||||||
|
logprob: 0.0,
|
||||||
|
special: false,
|
||||||
|
},
|
||||||
|
top_tokens: vec![],
|
||||||
|
generated_text: None,
|
||||||
|
details: None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Create a stream from our test events
|
||||||
|
let stream = futures::stream::iter(
|
||||||
|
event_to_stream
|
||||||
|
.into_iter()
|
||||||
|
.map(Ok::<StreamResponse, Infallible>),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Initialize variables
|
||||||
|
let mut json_buffer = String::new();
|
||||||
|
let mut name_found = false;
|
||||||
|
let mut no_tool_chosen = false;
|
||||||
|
let mut first_quote_removed = false;
|
||||||
|
let mut events = Vec::new();
|
||||||
|
|
||||||
|
let using_tools = true;
|
||||||
|
let model_id = "gpt2";
|
||||||
|
let system_fingerprint = "test";
|
||||||
|
|
||||||
|
let stream_options = Some(StreamOptions {
|
||||||
|
include_usage: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Use StreamExt to get access to next() method
|
||||||
|
use futures::StreamExt;
|
||||||
|
let mut stream = Box::pin(stream);
|
||||||
|
|
||||||
|
// Process the stream asynchronously
|
||||||
|
while let Some(Ok(stream_token)) = stream.next().await {
|
||||||
|
let (new_events, should_break) = process_stream_token(
|
||||||
|
stream_token.token.text.clone(),
|
||||||
|
&mut json_buffer,
|
||||||
|
&mut name_found,
|
||||||
|
&mut no_tool_chosen,
|
||||||
|
&mut first_quote_removed,
|
||||||
|
using_tools,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
&stream_token,
|
||||||
|
stream_options.as_ref(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
events.extend(new_events);
|
||||||
|
if should_break {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Expect 6 events (the relevant tokens within content)
|
||||||
|
assert_eq!(events.len(), 6);
|
||||||
|
// "I am a helpful assistant!"
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streaming_tool_decision() {
|
||||||
|
let tokens_to_stream = vec![
|
||||||
|
"{\"".to_string(),
|
||||||
|
"function".to_string(),
|
||||||
|
"\":".to_string(),
|
||||||
|
" {\"".to_string(),
|
||||||
|
"_".to_string(),
|
||||||
|
"name".to_string(),
|
||||||
|
"\":".to_string(),
|
||||||
|
" \"".to_string(),
|
||||||
|
"get".to_string(),
|
||||||
|
"_current".to_string(),
|
||||||
|
"_weather".to_string(),
|
||||||
|
"\",".to_string(),
|
||||||
|
// Event 1 is the function name
|
||||||
|
// Event 2 is the start of the arguments "{"
|
||||||
|
" \"".to_string(), // Event 3
|
||||||
|
"location".to_string(), // Event 4
|
||||||
|
"\":".to_string(), // Event 5
|
||||||
|
" \"".to_string(), // Event 6
|
||||||
|
"San".to_string(), // Event 7
|
||||||
|
" Francisco".to_string(), // Event 8
|
||||||
|
",".to_string(), // Event 9
|
||||||
|
" CA".to_string(), // Event 10
|
||||||
|
"\",".to_string(), // Event 11
|
||||||
|
" \"".to_string(), // Event 12
|
||||||
|
"format".to_string(), // Event 13
|
||||||
|
"\":".to_string(), // Event 14
|
||||||
|
" \"".to_string(), // Event 15
|
||||||
|
"c".to_string(), // Event 16
|
||||||
|
"elsius".to_string(), // Event 17
|
||||||
|
"\"}}".to_string(), // Event 18 retained (trailing brace removed)
|
||||||
|
];
|
||||||
|
|
||||||
|
let event_to_stream = tokens_to_stream
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, token)| StreamResponse {
|
||||||
|
index: i as u32,
|
||||||
|
token: Token {
|
||||||
|
id: 0,
|
||||||
|
text: token,
|
||||||
|
logprob: 0.0,
|
||||||
|
special: false,
|
||||||
|
},
|
||||||
|
top_tokens: vec![],
|
||||||
|
generated_text: None,
|
||||||
|
details: None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Create a stream from our test events
|
||||||
|
let stream = futures::stream::iter(
|
||||||
|
event_to_stream
|
||||||
|
.into_iter()
|
||||||
|
.map(Ok::<StreamResponse, Infallible>),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Initialize variables
|
||||||
|
let mut json_buffer = String::new();
|
||||||
|
let mut name_found = false;
|
||||||
|
let mut no_tool_chosen = false;
|
||||||
|
let mut first_quote_removed = false;
|
||||||
|
let mut events = Vec::new();
|
||||||
|
|
||||||
|
let using_tools = true;
|
||||||
|
let model_id = "gpt2";
|
||||||
|
let system_fingerprint = "test";
|
||||||
|
|
||||||
|
let stream_options = Some(StreamOptions {
|
||||||
|
include_usage: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut stream = Box::pin(stream);
|
||||||
|
|
||||||
|
// Process the stream asynchronously
|
||||||
|
while let Some(Ok(stream_token)) = stream.next().await {
|
||||||
|
let (new_events, should_break) = process_stream_token(
|
||||||
|
stream_token.token.text.clone(),
|
||||||
|
&mut json_buffer,
|
||||||
|
&mut name_found,
|
||||||
|
&mut no_tool_chosen,
|
||||||
|
&mut first_quote_removed,
|
||||||
|
using_tools,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
&stream_token,
|
||||||
|
stream_options.as_ref(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
events.extend(new_events);
|
||||||
|
if should_break {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for event in &events {
|
||||||
|
println!("{:?}", event);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 18);
|
||||||
|
// "{ "location": "San Francisco, CA", "format": "celsius"}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ use crate::kserve::{
|
|||||||
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
||||||
kserve_model_metadata, kserve_model_metadata_ready,
|
kserve_model_metadata, kserve_model_metadata_ready,
|
||||||
};
|
};
|
||||||
|
use crate::process_stream_token;
|
||||||
use crate::sagemaker::{
|
use crate::sagemaker::{
|
||||||
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
|
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
|
||||||
__path_sagemaker_compatibility,
|
__path_sagemaker_compatibility,
|
||||||
@ -13,7 +14,6 @@ use crate::sagemaker::{
|
|||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::vertex::vertex_compatibility;
|
use crate::vertex::vertex_compatibility;
|
||||||
use crate::ChatTokenizeResponse;
|
use crate::ChatTokenizeResponse;
|
||||||
use crate::{parse_partial_json, ParseResult};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
@ -1114,116 +1114,6 @@ pub(crate) async fn completions(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an event based on the token text and event type parameters.
|
|
||||||
/// `token_text` - The text to include (extract from StreamResponse.token.text or str)
|
|
||||||
/// `model_id` - Model identifier string
|
|
||||||
/// `system_fingerprint` - System fingerprint string
|
|
||||||
/// `tool_name` - If provided, creates a tool call name event
|
|
||||||
/// `is_tool_arg` - If true, creates a tool call argument event
|
|
||||||
fn create_event(
|
|
||||||
token_text: &str,
|
|
||||||
model_id: &str,
|
|
||||||
system_fingerprint: &str,
|
|
||||||
tool_name: Option<&str>,
|
|
||||||
is_tool_arg: bool,
|
|
||||||
finish_reason: Option<String>,
|
|
||||||
) -> Event {
|
|
||||||
let current_time = std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs();
|
|
||||||
|
|
||||||
let chat_complete = if let Some(tool_name) = tool_name {
|
|
||||||
// Tool call name event
|
|
||||||
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
tool_calls: vec![DeltaToolCall {
|
|
||||||
index: 0,
|
|
||||||
id: String::new(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: Function {
|
|
||||||
name: Some(tool_name.to_string()),
|
|
||||||
arguments: "".to_string(),
|
|
||||||
},
|
|
||||||
}],
|
|
||||||
});
|
|
||||||
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
|
||||||
id: String::new(),
|
|
||||||
created: current_time,
|
|
||||||
model: model_id.to_string(),
|
|
||||||
system_fingerprint: system_fingerprint.to_string(),
|
|
||||||
choices: vec![ChatCompletionChoice {
|
|
||||||
index: 0,
|
|
||||||
delta: tool_delta,
|
|
||||||
logprobs: None,
|
|
||||||
finish_reason: None,
|
|
||||||
}],
|
|
||||||
usage: None,
|
|
||||||
})
|
|
||||||
} else if is_tool_arg {
|
|
||||||
// Tool call argument event
|
|
||||||
let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
tool_calls: vec![DeltaToolCall {
|
|
||||||
index: 0,
|
|
||||||
id: String::new(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: Function {
|
|
||||||
name: None,
|
|
||||||
arguments: token_text.to_string(),
|
|
||||||
},
|
|
||||||
}],
|
|
||||||
});
|
|
||||||
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
|
||||||
id: String::new(),
|
|
||||||
created: current_time,
|
|
||||||
model: model_id.to_string(),
|
|
||||||
system_fingerprint: system_fingerprint.to_string(),
|
|
||||||
choices: vec![ChatCompletionChoice {
|
|
||||||
index: 0,
|
|
||||||
delta: tool_delta,
|
|
||||||
logprobs: None,
|
|
||||||
finish_reason: None,
|
|
||||||
}],
|
|
||||||
usage: None,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// usage, finish_reason
|
|
||||||
if finish_reason.is_some() {
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
|
||||||
model_id.to_string(),
|
|
||||||
system_fingerprint.to_string(),
|
|
||||||
Some(token_text.to_string()),
|
|
||||||
None,
|
|
||||||
current_time,
|
|
||||||
None,
|
|
||||||
finish_reason,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
// Chat completion event
|
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
|
||||||
model_id.to_string(),
|
|
||||||
system_fingerprint.to_string(),
|
|
||||||
Some(token_text.to_string()),
|
|
||||||
None,
|
|
||||||
current_time,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Event::default()
|
|
||||||
.json_data(chat_complete)
|
|
||||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate tokens
|
/// Generate tokens
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
@ -1292,22 +1182,6 @@ pub(crate) async fn chat_completions(
|
|||||||
.as_secs()
|
.as_secs()
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
|
||||||
pub struct Function {
|
|
||||||
#[serde(rename = "_name")]
|
|
||||||
name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
|
||||||
pub struct ToolDecision {
|
|
||||||
function: Function,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
|
||||||
pub struct NoToolDecision {
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) =
|
||||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||||
@ -1322,184 +1196,25 @@ pub(crate) async fn chat_completions(
|
|||||||
// Process stream tokens
|
// Process stream tokens
|
||||||
while let Some(Ok(stream_token)) = response_stream.next().await {
|
while let Some(Ok(stream_token)) = response_stream.next().await {
|
||||||
let token_text = stream_token.token.text.clone();
|
let token_text = stream_token.token.text.clone();
|
||||||
let mut events = Vec::new();
|
// Process stream token into a series of events and a break signal
|
||||||
let mut should_break = false;
|
let (events, should_break) = process_stream_token(
|
||||||
|
token_text,
|
||||||
// Get usage information
|
&mut json_buffer,
|
||||||
let usage = stream_token.details.as_ref().map(|d| Usage {
|
&mut name_found,
|
||||||
completion_tokens: d.generated_tokens,
|
&mut no_tool_chosen,
|
||||||
prompt_tokens: d.input_length,
|
&mut first_quote_removed,
|
||||||
total_tokens: d.input_length + d.generated_tokens,
|
using_tools,
|
||||||
});
|
|
||||||
|
|
||||||
json_buffer.push_str(&token_text);
|
|
||||||
|
|
||||||
// Phase 1: Function name discovery
|
|
||||||
if !name_found {
|
|
||||||
// NOTE: when tools are supplied `name_found` is false until the generated buffer contains
|
|
||||||
// a partial JSON object with $.function._name value. This name determines the type
|
|
||||||
// of events to emit. If the name is "no_tool", we'll emit the "content" field as a chat
|
|
||||||
// completion event. Otherwise, we'll emit a tool call name event followed by a tool call
|
|
||||||
// argument event. In both cases we'll buffer tokens to get the name and then reset the buffer
|
|
||||||
// to collect the arguments.
|
|
||||||
if let Ok(ParseResult {
|
|
||||||
value: ToolDecision {
|
|
||||||
function: Function { name },
|
|
||||||
},
|
|
||||||
last_value_whole,
|
|
||||||
}) = parse_partial_json(&json_buffer)
|
|
||||||
{
|
|
||||||
if !last_value_whole {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
name_found = true;
|
|
||||||
if name == "no_tool" {
|
|
||||||
no_tool_chosen = true;
|
|
||||||
} else {
|
|
||||||
events.push(create_event(
|
|
||||||
&token_text,
|
|
||||||
&model_id,
|
&model_id,
|
||||||
&system_fingerprint,
|
&system_fingerprint,
|
||||||
Some(name.as_str()),
|
&stream_token,
|
||||||
false,
|
stream_options.as_ref(),
|
||||||
None,
|
).await;
|
||||||
));
|
|
||||||
events.push(create_event(
|
|
||||||
"{",
|
|
||||||
&model_id,
|
|
||||||
&system_fingerprint,
|
|
||||||
None,
|
|
||||||
true,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset buffer for arguments
|
|
||||||
json_buffer.clear();
|
|
||||||
json_buffer.push('{');
|
|
||||||
}
|
|
||||||
|
|
||||||
for event in events {
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Phase 2: Content processing
|
|
||||||
let is_complete_json = json_buffer.ends_with('}')
|
|
||||||
&& serde_json::from_str::<Value>(&json_buffer[..json_buffer.len() - 1]).is_ok();
|
|
||||||
let mut edited_token = token_text;
|
|
||||||
|
|
||||||
// Handle different flows based on context
|
|
||||||
if using_tools {
|
|
||||||
if no_tool_chosen && !is_complete_json {
|
|
||||||
// Content-only flow
|
|
||||||
if let Ok(ParseResult {
|
|
||||||
value: _,
|
|
||||||
last_value_whole,
|
|
||||||
}) = parse_partial_json::<NoToolDecision>(&json_buffer)
|
|
||||||
{
|
|
||||||
let cleaned_token = if !first_quote_removed {
|
|
||||||
// trim start unil the first quote
|
|
||||||
first_quote_removed = true;
|
|
||||||
edited_token
|
|
||||||
.trim_start()
|
|
||||||
.strip_prefix('"')
|
|
||||||
.unwrap_or(&edited_token)
|
|
||||||
.to_string()
|
|
||||||
} else if last_value_whole {
|
|
||||||
should_break = true;
|
|
||||||
// trim end until the last quote
|
|
||||||
edited_token
|
|
||||||
.trim_end()
|
|
||||||
.strip_suffix('"')
|
|
||||||
.unwrap_or(&edited_token)
|
|
||||||
.to_string()
|
|
||||||
} else {
|
|
||||||
edited_token.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
if !cleaned_token.is_empty() {
|
|
||||||
events.push(create_event(
|
|
||||||
&cleaned_token,
|
|
||||||
&model_id,
|
|
||||||
&system_fingerprint,
|
|
||||||
None,
|
|
||||||
false,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Tool with arguments flow
|
|
||||||
if is_complete_json {
|
|
||||||
edited_token.truncate(edited_token.len() - 1);
|
|
||||||
should_break = true;
|
|
||||||
}
|
|
||||||
events.push(create_event(
|
|
||||||
&edited_token,
|
|
||||||
&model_id,
|
|
||||||
&system_fingerprint,
|
|
||||||
None,
|
|
||||||
true,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Standard chat completion flow
|
|
||||||
if let Some(details) = stream_token.details.as_ref() {
|
|
||||||
let finish_reason = details.finish_reason.format(true);
|
|
||||||
let text = if details.finish_reason == FinishReason::Length {
|
|
||||||
&edited_token
|
|
||||||
} else {
|
|
||||||
""
|
|
||||||
};
|
|
||||||
events.push(create_event(
|
|
||||||
text,
|
|
||||||
&model_id,
|
|
||||||
&system_fingerprint,
|
|
||||||
None,
|
|
||||||
false,
|
|
||||||
Some(finish_reason),
|
|
||||||
));
|
|
||||||
should_break = true;
|
|
||||||
} else {
|
|
||||||
events.push(create_event(
|
|
||||||
&edited_token,
|
|
||||||
&model_id,
|
|
||||||
&system_fingerprint,
|
|
||||||
None,
|
|
||||||
false,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit all collected events
|
// Emit all collected events
|
||||||
for event in events {
|
for event in events {
|
||||||
yield Ok::<Event, Infallible>(event);
|
yield event;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emit usage data when requested
|
|
||||||
if let (Some(usage_data), true) = (
|
|
||||||
usage,
|
|
||||||
stream_options.as_ref().is_some_and(|o| o.include_usage)
|
|
||||||
) {
|
|
||||||
let current_time = get_timestamp();
|
|
||||||
|
|
||||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
|
||||||
id: String::new(),
|
|
||||||
created: current_time,
|
|
||||||
model: model_id.clone(),
|
|
||||||
system_fingerprint: system_fingerprint.clone(),
|
|
||||||
choices: vec![],
|
|
||||||
usage: Some(usage_data),
|
|
||||||
});
|
|
||||||
|
|
||||||
yield Ok(Event::default()
|
|
||||||
.json_data(chat_complete)
|
|
||||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
|
|
||||||
}
|
|
||||||
if should_break {
|
if should_break {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user