From a73cd560756a5328077e40448724a7f7e3c3445e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 10 Mar 2025 17:07:22 +0100 Subject: [PATCH] Fixing the tool call id. --- router/src/chat.rs | 25 +++++++++++++++++++------ router/src/lib.rs | 24 ++++++++++++++++++++++++ router/src/server.rs | 3 ++- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/router/src/chat.rs b/router/src/chat.rs index 1a18f030..63bd53bf 100644 --- a/router/src/chat.rs +++ b/router/src/chat.rs @@ -86,6 +86,7 @@ fn create_event_from_stream_token( system_fingerprint: String, model_id: String, function_name: Option, + id: String, ) -> CompletionType { let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -122,7 +123,7 @@ fn create_event_from_stream_token( role: "assistant".to_string(), tool_calls: vec![DeltaToolCall { index: 0, - id: String::new(), + id, r#type: "function".to_string(), function: Function { name: function_name, @@ -172,6 +173,7 @@ pub struct ChatState { model_id: String, fingerprint: String, logprobs: bool, + id: String, } impl ChatState { @@ -181,6 +183,7 @@ impl ChatState { fingerprint: String, model_id: String, logprobs: bool, + id: String, ) -> Self { let state = if using_tools { StreamState::Buffering @@ -195,13 +198,13 @@ impl ChatState { fingerprint, model_id, logprobs, + id, } } pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec { let mut events = vec![]; let token_text = &stream_token.token.text; - println!("Got {token_text:?} - State {:?}", self.state); match self.state { StreamState::Buffering => { self.text.push_str(token_text); @@ -218,6 +221,7 @@ impl ChatState { self.fingerprint.clone(), self.model_id.clone(), None, + self.id.clone(), ); events.push(chat_complete); @@ -237,6 +241,7 @@ impl ChatState { self.fingerprint.clone(), self.model_id.clone(), Some(call.function._name), + self.id.clone(), ); events.push(chat_complete); @@ -261,6 +266,7 @@ impl ChatState { self.fingerprint.clone(), self.model_id.clone(), None, + self.id.clone(), ); events.push(chat_complete); } else { @@ -271,8 +277,8 @@ impl ChatState { self.fingerprint.clone(), self.model_id.clone(), None, + self.id.clone(), ); - events.push(chat_complete); } } @@ -301,10 +307,8 @@ impl ChatState { if text.ends_with("\"") { text = &text[..text.len() - 1]; } - println!("Detected end of content {text:?}"); stream_token.token.text = text.to_string(); self.state = StreamState::NoToolFinish; - println!("NNew state {:?}", self.state); } } } @@ -317,6 +321,7 @@ impl ChatState { self.fingerprint.clone(), self.model_id.clone(), None, + self.id.clone(), ); events.push(chat_complete); @@ -329,6 +334,7 @@ impl ChatState { self.fingerprint.clone(), self.model_id.clone(), None, + self.id.clone(), ); events.push(chat_complete); @@ -414,7 +420,7 @@ mod tests { function, } = &tool_calls[0]; assert_eq!(*index, 0); - assert_eq!(id, ""); + assert_eq!(id, "0"); assert_eq!(r#type, "function"); (function.name.as_ref(), &function.arguments) } else { @@ -435,6 +441,7 @@ mod tests { "fingerprint".to_string(), "model_id".to_string(), false, + "0".to_string(), ); let events = chat_state.push(StreamResponse { @@ -480,6 +487,7 @@ mod tests { "fingerprint".to_string(), "model_id".to_string(), false, + "0".to_string(), ); let events = chat_state.push(StreamResponse { @@ -544,6 +552,7 @@ mod tests { "fingerprint".to_string(), "model_id".to_string(), false, + "0".to_string(), ); let tokens = vec![ @@ -621,6 +630,7 @@ mod tests { "fingerprint".to_string(), "model_id".to_string(), false, + "0".to_string(), ); let tokens = vec![ @@ -729,6 +739,7 @@ mod tests { "fingerprint".to_string(), "model_id".to_string(), false, + "0".to_string(), ); let tokens = vec![ @@ -810,6 +821,7 @@ mod tests { "fingerprint".to_string(), "model_id".to_string(), false, + "0".to_string(), ); let tokens = vec![ @@ -880,6 +892,7 @@ mod tests { "fingerprint".to_string(), "model_id".to_string(), false, + "0".to_string(), ); let tokens = vec![ diff --git a/router/src/lib.rs b/router/src/lib.rs index 73792bab..7370477b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -22,6 +22,7 @@ use tokenizers::Encoding; use tracing::warn; use utoipa::ToSchema; use validation::Validation; +use uuid::Uuid; #[allow(clippy::large_enum_variant)] #[derive(Clone)] @@ -995,6 +996,29 @@ impl ChatRequest { using_tools, )) } + + fn next_int_id(&self) -> Result>{ + let mut id: usize = 0; + for message in &self.messages{ + if let MessageBody::Tool{tool_calls} = &message.body { + for tool_call in tool_calls{ + let new_id: usize = tool_call.id.parse()?; + id = std::cmp::max(id, new_id + 1); + } + } + } + Ok(id.to_string()) + } + + /// Try to have linearly increasing id + /// or resort to using Uuid if the initial + /// scheme is not understood + fn next_tool_call_id(&self) -> String{ + self.next_int_id().unwrap_or_else(|_|{ + let uid = Uuid::new_v4().to_string(); + uid.to_string() + }) + } } #[derive(Clone, Deserialize, ToSchema, Serialize, Default)] diff --git a/router/src/server.rs b/router/src/server.rs index d68353aa..689b2f50 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1164,6 +1164,7 @@ pub(crate) async fn chat_completions( } = chat.clone(); tracing::debug!("Got chat_template {:?}", infer.chat_template); + let id = chat.next_tool_call_id(); let (generate_request, using_tools): (GenerateRequest, bool) = chat.try_into_generate(&infer)?; span.record("parameters", format!("{:?}", generate_request.parameters)); @@ -1182,7 +1183,7 @@ pub(crate) async fn chat_completions( let response_stream = async_stream::stream! { let mut response_stream = Box::pin(response_stream); - let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs); + let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id); while let Some(result) = response_stream.next().await { match result{ Ok(stream_token) => {