mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00
fix: adjust streaming tool response
This commit is contained in:
parent
5f030140be
commit
dbce04e4d3
@ -765,6 +765,7 @@ impl ChatCompletionChunk {
|
|||||||
logprobs: Option<ChatCompletionLogprobs>,
|
logprobs: Option<ChatCompletionLogprobs>,
|
||||||
finish_reason: Option<String>,
|
finish_reason: Option<String>,
|
||||||
usage: Option<Usage>,
|
usage: Option<Usage>,
|
||||||
|
tool_name: Option<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let delta = match (delta, tool_calls) {
|
let delta = match (delta, tool_calls) {
|
||||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
@ -779,7 +780,7 @@ impl ChatCompletionChunk {
|
|||||||
id: String::new(),
|
id: String::new(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: Function {
|
function: Function {
|
||||||
name: None,
|
name: tool_name,
|
||||||
arguments: tool_calls[0].to_string(),
|
arguments: tool_calls[0].to_string(),
|
||||||
},
|
},
|
||||||
}],
|
}],
|
||||||
@ -1364,7 +1365,7 @@ pub struct SimpleToken {
|
|||||||
stop: usize,
|
stop: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
#[derive(Debug, Serialize, ToSchema, Clone, PartialEq)]
|
||||||
#[serde(rename_all(serialize = "snake_case"))]
|
#[serde(rename_all(serialize = "snake_case"))]
|
||||||
#[schema(example = "Length")]
|
#[schema(example = "Length")]
|
||||||
pub enum FinishReason {
|
pub enum FinishReason {
|
||||||
|
@ -1128,6 +1128,7 @@ fn create_event_from_stream_token(
|
|||||||
inner_using_tools: bool,
|
inner_using_tools: bool,
|
||||||
system_fingerprint: String,
|
system_fingerprint: String,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
|
tool_name: Option<String>,
|
||||||
) -> Event {
|
) -> Event {
|
||||||
let event = Event::default();
|
let event = Event::default();
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
@ -1141,7 +1142,9 @@ fn create_event_from_stream_token(
|
|||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
// replace the content with the tool calls if grammar is present
|
||||||
let (content, tool_calls) = if inner_using_tools {
|
let (content, tool_calls) = if inner_using_tools {
|
||||||
(None, Some(vec![stream_token.token.text.clone()]))
|
// escape double quotes in the token text so it can be embedded in a JSON string
|
||||||
|
let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
|
||||||
|
(None, Some(vec![escaped_text]))
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
let content = if !stream_token.token.special {
|
||||||
Some(stream_token.token.text.clone())
|
Some(stream_token.token.text.clone())
|
||||||
@ -1184,6 +1187,7 @@ fn create_event_from_stream_token(
|
|||||||
logprobs,
|
logprobs,
|
||||||
finish_reason,
|
finish_reason,
|
||||||
usage,
|
usage,
|
||||||
|
tool_name,
|
||||||
));
|
));
|
||||||
|
|
||||||
event.json_data(chat_complete).unwrap_or_else(|e| {
|
event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||||
@ -1283,6 +1287,7 @@ pub(crate) async fn chat_completions(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let mut response_as_tool = using_tools;
|
let mut response_as_tool = using_tools;
|
||||||
|
let mut global_function_name = String::new();
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
match result{
|
match result{
|
||||||
Ok(stream_token) => {
|
Ok(stream_token) => {
|
||||||
@ -1292,8 +1297,8 @@ pub(crate) async fn chat_completions(
|
|||||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||||
buffer.push(stream_token);
|
buffer.push(stream_token);
|
||||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||||
let function_name = captures[1].to_string();
|
global_function_name = captures[1].to_string();
|
||||||
if function_name == "no_tool" {
|
if global_function_name == "no_tool" {
|
||||||
state = StreamState::BufferTrailing;
|
state = StreamState::BufferTrailing;
|
||||||
response_as_tool = false;
|
response_as_tool = false;
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
@ -1302,18 +1307,7 @@ pub(crate) async fn chat_completions(
|
|||||||
state = StreamState::Content {
|
state = StreamState::Content {
|
||||||
skip_close_quote: false,
|
skip_close_quote: false,
|
||||||
};
|
};
|
||||||
// send all the buffered messages
|
buffer = buffer.drain(0..1).collect();
|
||||||
for stream_token in &buffer {
|
|
||||||
let event = create_event_from_stream_token(
|
|
||||||
stream_token,
|
|
||||||
logprobs,
|
|
||||||
stream_options.clone(),
|
|
||||||
response_as_tool,
|
|
||||||
system_fingerprint.clone(),
|
|
||||||
model_id.clone(),
|
|
||||||
);
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1348,6 +1342,7 @@ pub(crate) async fn chat_completions(
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
Some(global_function_name.clone()),
|
||||||
));
|
));
|
||||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||||
InferError::StreamSerializationError(e.to_string()).into()
|
InferError::StreamSerializationError(e.to_string()).into()
|
||||||
@ -1365,17 +1360,22 @@ pub(crate) async fn chat_completions(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// send the content
|
buffer.push(stream_token);
|
||||||
let event = create_event_from_stream_token(
|
// send the buffer in FIFO order, but hold back the last two elements (closing '}' and EOS token)
|
||||||
&stream_token,
|
for stream_token in &buffer[..buffer.len() - 2] {
|
||||||
logprobs,
|
let event = create_event_from_stream_token(
|
||||||
stream_options.clone(),
|
stream_token,
|
||||||
response_as_tool,
|
logprobs,
|
||||||
system_fingerprint.clone(),
|
stream_options.clone(),
|
||||||
model_id.clone(),
|
response_as_tool,
|
||||||
);
|
system_fingerprint.clone(),
|
||||||
|
model_id.clone(),
|
||||||
|
Some(global_function_name.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
yield Ok::<Event, Infallible>(event);
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
}
|
||||||
|
buffer = buffer.drain(buffer.len() - 2..).collect();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user