mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
fix: adjust streaming tool response
This commit is contained in:
parent
5f030140be
commit
dbce04e4d3
@ -765,6 +765,7 @@ impl ChatCompletionChunk {
|
||||
logprobs: Option<ChatCompletionLogprobs>,
|
||||
finish_reason: Option<String>,
|
||||
usage: Option<Usage>,
|
||||
tool_name: Option<String>,
|
||||
) -> Self {
|
||||
let delta = match (delta, tool_calls) {
|
||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||
@ -779,7 +780,7 @@ impl ChatCompletionChunk {
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: None,
|
||||
name: tool_name,
|
||||
arguments: tool_calls[0].to_string(),
|
||||
},
|
||||
}],
|
||||
@ -1364,7 +1365,7 @@ pub struct SimpleToken {
|
||||
stop: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
||||
#[derive(Debug, Serialize, ToSchema, Clone, PartialEq)]
|
||||
#[serde(rename_all(serialize = "snake_case"))]
|
||||
#[schema(example = "Length")]
|
||||
pub enum FinishReason {
|
||||
|
@ -1128,6 +1128,7 @@ fn create_event_from_stream_token(
|
||||
inner_using_tools: bool,
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
tool_name: Option<String>,
|
||||
) -> Event {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
@ -1141,7 +1142,9 @@ fn create_event_from_stream_token(
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
// escape the token text so its a json string
|
||||
let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
|
||||
(None, Some(vec![escaped_text]))
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
@ -1184,6 +1187,7 @@ fn create_event_from_stream_token(
|
||||
logprobs,
|
||||
finish_reason,
|
||||
usage,
|
||||
tool_name,
|
||||
));
|
||||
|
||||
event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
@ -1283,6 +1287,7 @@ pub(crate) async fn chat_completions(
|
||||
}
|
||||
};
|
||||
let mut response_as_tool = using_tools;
|
||||
let mut global_function_name = String::new();
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result{
|
||||
Ok(stream_token) => {
|
||||
@ -1292,8 +1297,8 @@ pub(crate) async fn chat_completions(
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
buffer.push(stream_token);
|
||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||
let function_name = captures[1].to_string();
|
||||
if function_name == "no_tool" {
|
||||
global_function_name = captures[1].to_string();
|
||||
if global_function_name == "no_tool" {
|
||||
state = StreamState::BufferTrailing;
|
||||
response_as_tool = false;
|
||||
buffer.clear();
|
||||
@ -1302,18 +1307,7 @@ pub(crate) async fn chat_completions(
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
// send all the buffered messages
|
||||
for stream_token in &buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
buffer = buffer.drain(0..1).collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1348,6 +1342,7 @@ pub(crate) async fn chat_completions(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(global_function_name.clone()),
|
||||
));
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
@ -1365,17 +1360,22 @@ pub(crate) async fn chat_completions(
|
||||
break;
|
||||
}
|
||||
|
||||
// send the content
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
buffer.push(stream_token);
|
||||
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
|
||||
for stream_token in &buffer[..buffer.len() - 2] {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
Some(global_function_name.clone()),
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
buffer = buffer.drain(buffer.len() - 2..).collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user