mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-28 13:32:10 +00:00
Fixing the tool call id.
This commit is contained in:
parent
3e731a7c2f
commit
a73cd56075
@ -86,6 +86,7 @@ fn create_event_from_stream_token(
|
|||||||
system_fingerprint: String,
|
system_fingerprint: String,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
function_name: Option<String>,
|
function_name: Option<String>,
|
||||||
|
id: String,
|
||||||
) -> CompletionType {
|
) -> CompletionType {
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
@ -122,7 +123,7 @@ fn create_event_from_stream_token(
|
|||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
tool_calls: vec![DeltaToolCall {
|
tool_calls: vec![DeltaToolCall {
|
||||||
index: 0,
|
index: 0,
|
||||||
id: String::new(),
|
id,
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: Function {
|
function: Function {
|
||||||
name: function_name,
|
name: function_name,
|
||||||
@ -172,6 +173,7 @@ pub struct ChatState {
|
|||||||
model_id: String,
|
model_id: String,
|
||||||
fingerprint: String,
|
fingerprint: String,
|
||||||
logprobs: bool,
|
logprobs: bool,
|
||||||
|
id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatState {
|
impl ChatState {
|
||||||
@ -181,6 +183,7 @@ impl ChatState {
|
|||||||
fingerprint: String,
|
fingerprint: String,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
logprobs: bool,
|
logprobs: bool,
|
||||||
|
id: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let state = if using_tools {
|
let state = if using_tools {
|
||||||
StreamState::Buffering
|
StreamState::Buffering
|
||||||
@ -195,13 +198,13 @@ impl ChatState {
|
|||||||
fingerprint,
|
fingerprint,
|
||||||
model_id,
|
model_id,
|
||||||
logprobs,
|
logprobs,
|
||||||
|
id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
|
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
|
||||||
let mut events = vec![];
|
let mut events = vec![];
|
||||||
let token_text = &stream_token.token.text;
|
let token_text = &stream_token.token.text;
|
||||||
println!("Got {token_text:?} - State {:?}", self.state);
|
|
||||||
match self.state {
|
match self.state {
|
||||||
StreamState::Buffering => {
|
StreamState::Buffering => {
|
||||||
self.text.push_str(token_text);
|
self.text.push_str(token_text);
|
||||||
@ -218,6 +221,7 @@ impl ChatState {
|
|||||||
self.fingerprint.clone(),
|
self.fingerprint.clone(),
|
||||||
self.model_id.clone(),
|
self.model_id.clone(),
|
||||||
None,
|
None,
|
||||||
|
self.id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
@ -237,6 +241,7 @@ impl ChatState {
|
|||||||
self.fingerprint.clone(),
|
self.fingerprint.clone(),
|
||||||
self.model_id.clone(),
|
self.model_id.clone(),
|
||||||
Some(call.function._name),
|
Some(call.function._name),
|
||||||
|
self.id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
@ -261,6 +266,7 @@ impl ChatState {
|
|||||||
self.fingerprint.clone(),
|
self.fingerprint.clone(),
|
||||||
self.model_id.clone(),
|
self.model_id.clone(),
|
||||||
None,
|
None,
|
||||||
|
self.id.clone(),
|
||||||
);
|
);
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
} else {
|
} else {
|
||||||
@ -271,8 +277,8 @@ impl ChatState {
|
|||||||
self.fingerprint.clone(),
|
self.fingerprint.clone(),
|
||||||
self.model_id.clone(),
|
self.model_id.clone(),
|
||||||
None,
|
None,
|
||||||
|
self.id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -301,10 +307,8 @@ impl ChatState {
|
|||||||
if text.ends_with("\"") {
|
if text.ends_with("\"") {
|
||||||
text = &text[..text.len() - 1];
|
text = &text[..text.len() - 1];
|
||||||
}
|
}
|
||||||
println!("Detected end of content {text:?}");
|
|
||||||
stream_token.token.text = text.to_string();
|
stream_token.token.text = text.to_string();
|
||||||
self.state = StreamState::NoToolFinish;
|
self.state = StreamState::NoToolFinish;
|
||||||
println!("NNew state {:?}", self.state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -317,6 +321,7 @@ impl ChatState {
|
|||||||
self.fingerprint.clone(),
|
self.fingerprint.clone(),
|
||||||
self.model_id.clone(),
|
self.model_id.clone(),
|
||||||
None,
|
None,
|
||||||
|
self.id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
@ -329,6 +334,7 @@ impl ChatState {
|
|||||||
self.fingerprint.clone(),
|
self.fingerprint.clone(),
|
||||||
self.model_id.clone(),
|
self.model_id.clone(),
|
||||||
None,
|
None,
|
||||||
|
self.id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
@ -414,7 +420,7 @@ mod tests {
|
|||||||
function,
|
function,
|
||||||
} = &tool_calls[0];
|
} = &tool_calls[0];
|
||||||
assert_eq!(*index, 0);
|
assert_eq!(*index, 0);
|
||||||
assert_eq!(id, "");
|
assert_eq!(id, "0");
|
||||||
assert_eq!(r#type, "function");
|
assert_eq!(r#type, "function");
|
||||||
(function.name.as_ref(), &function.arguments)
|
(function.name.as_ref(), &function.arguments)
|
||||||
} else {
|
} else {
|
||||||
@ -435,6 +441,7 @@ mod tests {
|
|||||||
"fingerprint".to_string(),
|
"fingerprint".to_string(),
|
||||||
"model_id".to_string(),
|
"model_id".to_string(),
|
||||||
false,
|
false,
|
||||||
|
"0".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let events = chat_state.push(StreamResponse {
|
let events = chat_state.push(StreamResponse {
|
||||||
@ -480,6 +487,7 @@ mod tests {
|
|||||||
"fingerprint".to_string(),
|
"fingerprint".to_string(),
|
||||||
"model_id".to_string(),
|
"model_id".to_string(),
|
||||||
false,
|
false,
|
||||||
|
"0".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let events = chat_state.push(StreamResponse {
|
let events = chat_state.push(StreamResponse {
|
||||||
@ -544,6 +552,7 @@ mod tests {
|
|||||||
"fingerprint".to_string(),
|
"fingerprint".to_string(),
|
||||||
"model_id".to_string(),
|
"model_id".to_string(),
|
||||||
false,
|
false,
|
||||||
|
"0".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let tokens = vec![
|
let tokens = vec![
|
||||||
@ -621,6 +630,7 @@ mod tests {
|
|||||||
"fingerprint".to_string(),
|
"fingerprint".to_string(),
|
||||||
"model_id".to_string(),
|
"model_id".to_string(),
|
||||||
false,
|
false,
|
||||||
|
"0".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let tokens = vec![
|
let tokens = vec![
|
||||||
@ -729,6 +739,7 @@ mod tests {
|
|||||||
"fingerprint".to_string(),
|
"fingerprint".to_string(),
|
||||||
"model_id".to_string(),
|
"model_id".to_string(),
|
||||||
false,
|
false,
|
||||||
|
"0".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let tokens = vec![
|
let tokens = vec![
|
||||||
@ -810,6 +821,7 @@ mod tests {
|
|||||||
"fingerprint".to_string(),
|
"fingerprint".to_string(),
|
||||||
"model_id".to_string(),
|
"model_id".to_string(),
|
||||||
false,
|
false,
|
||||||
|
"0".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let tokens = vec![
|
let tokens = vec![
|
||||||
@ -880,6 +892,7 @@ mod tests {
|
|||||||
"fingerprint".to_string(),
|
"fingerprint".to_string(),
|
||||||
"model_id".to_string(),
|
"model_id".to_string(),
|
||||||
false,
|
false,
|
||||||
|
"0".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let tokens = vec![
|
let tokens = vec![
|
||||||
|
@ -22,6 +22,7 @@ use tokenizers::Encoding;
|
|||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[allow(clippy::large_enum_variant)]
|
#[allow(clippy::large_enum_variant)]
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -995,6 +996,29 @@ impl ChatRequest {
|
|||||||
using_tools,
|
using_tools,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn next_int_id(&self) -> Result<String, Box<dyn std::error::Error>>{
|
||||||
|
let mut id: usize = 0;
|
||||||
|
for message in &self.messages{
|
||||||
|
if let MessageBody::Tool{tool_calls} = &message.body {
|
||||||
|
for tool_call in tool_calls{
|
||||||
|
let new_id: usize = tool_call.id.parse()?;
|
||||||
|
id = std::cmp::max(id, new_id + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(id.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to have linearly increasing id
|
||||||
|
/// or resort to using Uuid if the initial
|
||||||
|
/// scheme is not understood
|
||||||
|
fn next_tool_call_id(&self) -> String{
|
||||||
|
self.next_int_id().unwrap_or_else(|_|{
|
||||||
|
let uid = Uuid::new_v4().to_string();
|
||||||
|
uid.to_string()
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
||||||
|
@ -1164,6 +1164,7 @@ pub(crate) async fn chat_completions(
|
|||||||
} = chat.clone();
|
} = chat.clone();
|
||||||
|
|
||||||
tracing::debug!("Got chat_template {:?}", infer.chat_template);
|
tracing::debug!("Got chat_template {:?}", infer.chat_template);
|
||||||
|
let id = chat.next_tool_call_id();
|
||||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
chat.try_into_generate(&infer)?;
|
chat.try_into_generate(&infer)?;
|
||||||
span.record("parameters", format!("{:?}", generate_request.parameters));
|
span.record("parameters", format!("{:?}", generate_request.parameters));
|
||||||
@ -1182,7 +1183,7 @@ pub(crate) async fn chat_completions(
|
|||||||
|
|
||||||
let response_stream = async_stream::stream! {
|
let response_stream = async_stream::stream! {
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs);
|
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id);
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
match result{
|
match result{
|
||||||
Ok(stream_token) => {
|
Ok(stream_token) => {
|
||||||
|
Loading…
Reference in New Issue
Block a user