Fixing the tool call id.

This commit is contained in:
Nicolas Patry 2025-03-10 17:07:22 +01:00
parent 3e731a7c2f
commit a73cd56075
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
3 changed files with 45 additions and 7 deletions

View File

@ -86,6 +86,7 @@ fn create_event_from_stream_token(
system_fingerprint: String, system_fingerprint: String,
model_id: String, model_id: String,
function_name: Option<String>, function_name: Option<String>,
id: String,
) -> CompletionType { ) -> CompletionType {
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -122,7 +123,7 @@ fn create_event_from_stream_token(
role: "assistant".to_string(), role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall { tool_calls: vec![DeltaToolCall {
index: 0, index: 0,
id: String::new(), id,
r#type: "function".to_string(), r#type: "function".to_string(),
function: Function { function: Function {
name: function_name, name: function_name,
@ -172,6 +173,7 @@ pub struct ChatState {
model_id: String, model_id: String,
fingerprint: String, fingerprint: String,
logprobs: bool, logprobs: bool,
id: String,
} }
impl ChatState { impl ChatState {
@ -181,6 +183,7 @@ impl ChatState {
fingerprint: String, fingerprint: String,
model_id: String, model_id: String,
logprobs: bool, logprobs: bool,
id: String,
) -> Self { ) -> Self {
let state = if using_tools { let state = if using_tools {
StreamState::Buffering StreamState::Buffering
@ -195,13 +198,13 @@ impl ChatState {
fingerprint, fingerprint,
model_id, model_id,
logprobs, logprobs,
id,
} }
} }
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> { pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
let mut events = vec![]; let mut events = vec![];
let token_text = &stream_token.token.text; let token_text = &stream_token.token.text;
println!("Got {token_text:?} - State {:?}", self.state);
match self.state { match self.state {
StreamState::Buffering => { StreamState::Buffering => {
self.text.push_str(token_text); self.text.push_str(token_text);
@ -218,6 +221,7 @@ impl ChatState {
self.fingerprint.clone(), self.fingerprint.clone(),
self.model_id.clone(), self.model_id.clone(),
None, None,
self.id.clone(),
); );
events.push(chat_complete); events.push(chat_complete);
@ -237,6 +241,7 @@ impl ChatState {
self.fingerprint.clone(), self.fingerprint.clone(),
self.model_id.clone(), self.model_id.clone(),
Some(call.function._name), Some(call.function._name),
self.id.clone(),
); );
events.push(chat_complete); events.push(chat_complete);
@ -261,6 +266,7 @@ impl ChatState {
self.fingerprint.clone(), self.fingerprint.clone(),
self.model_id.clone(), self.model_id.clone(),
None, None,
self.id.clone(),
); );
events.push(chat_complete); events.push(chat_complete);
} else { } else {
@ -271,8 +277,8 @@ impl ChatState {
self.fingerprint.clone(), self.fingerprint.clone(),
self.model_id.clone(), self.model_id.clone(),
None, None,
self.id.clone(),
); );
events.push(chat_complete); events.push(chat_complete);
} }
} }
@ -301,10 +307,8 @@ impl ChatState {
if text.ends_with("\"") { if text.ends_with("\"") {
text = &text[..text.len() - 1]; text = &text[..text.len() - 1];
} }
println!("Detected end of content {text:?}");
stream_token.token.text = text.to_string(); stream_token.token.text = text.to_string();
self.state = StreamState::NoToolFinish; self.state = StreamState::NoToolFinish;
println!("NNew state {:?}", self.state);
} }
} }
} }
@ -317,6 +321,7 @@ impl ChatState {
self.fingerprint.clone(), self.fingerprint.clone(),
self.model_id.clone(), self.model_id.clone(),
None, None,
self.id.clone(),
); );
events.push(chat_complete); events.push(chat_complete);
@ -329,6 +334,7 @@ impl ChatState {
self.fingerprint.clone(), self.fingerprint.clone(),
self.model_id.clone(), self.model_id.clone(),
None, None,
self.id.clone(),
); );
events.push(chat_complete); events.push(chat_complete);
@ -414,7 +420,7 @@ mod tests {
function, function,
} = &tool_calls[0]; } = &tool_calls[0];
assert_eq!(*index, 0); assert_eq!(*index, 0);
assert_eq!(id, ""); assert_eq!(id, "0");
assert_eq!(r#type, "function"); assert_eq!(r#type, "function");
(function.name.as_ref(), &function.arguments) (function.name.as_ref(), &function.arguments)
} else { } else {
@ -435,6 +441,7 @@ mod tests {
"fingerprint".to_string(), "fingerprint".to_string(),
"model_id".to_string(), "model_id".to_string(),
false, false,
"0".to_string(),
); );
let events = chat_state.push(StreamResponse { let events = chat_state.push(StreamResponse {
@ -480,6 +487,7 @@ mod tests {
"fingerprint".to_string(), "fingerprint".to_string(),
"model_id".to_string(), "model_id".to_string(),
false, false,
"0".to_string(),
); );
let events = chat_state.push(StreamResponse { let events = chat_state.push(StreamResponse {
@ -544,6 +552,7 @@ mod tests {
"fingerprint".to_string(), "fingerprint".to_string(),
"model_id".to_string(), "model_id".to_string(),
false, false,
"0".to_string(),
); );
let tokens = vec![ let tokens = vec![
@ -621,6 +630,7 @@ mod tests {
"fingerprint".to_string(), "fingerprint".to_string(),
"model_id".to_string(), "model_id".to_string(),
false, false,
"0".to_string(),
); );
let tokens = vec![ let tokens = vec![
@ -729,6 +739,7 @@ mod tests {
"fingerprint".to_string(), "fingerprint".to_string(),
"model_id".to_string(), "model_id".to_string(),
false, false,
"0".to_string(),
); );
let tokens = vec![ let tokens = vec![
@ -810,6 +821,7 @@ mod tests {
"fingerprint".to_string(), "fingerprint".to_string(),
"model_id".to_string(), "model_id".to_string(),
false, false,
"0".to_string(),
); );
let tokens = vec![ let tokens = vec![
@ -880,6 +892,7 @@ mod tests {
"fingerprint".to_string(), "fingerprint".to_string(),
"model_id".to_string(), "model_id".to_string(),
false, false,
"0".to_string(),
); );
let tokens = vec![ let tokens = vec![

View File

@ -22,6 +22,7 @@ use tokenizers::Encoding;
use tracing::warn; use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
use uuid::Uuid;
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Clone)] #[derive(Clone)]
@ -995,6 +996,29 @@ impl ChatRequest {
using_tools, using_tools,
)) ))
} }
fn next_int_id(&self) -> Result<String, Box<dyn std::error::Error>>{
let mut id: usize = 0;
for message in &self.messages{
if let MessageBody::Tool{tool_calls} = &message.body {
for tool_call in tool_calls{
let new_id: usize = tool_call.id.parse()?;
id = std::cmp::max(id, new_id + 1);
}
}
}
Ok(id.to_string())
}
/// Returns the id to assign to the next tool call.
///
/// Tries to keep ids linearly increasing by delegating to `next_int_id`;
/// if that fails (the existing id scheme is not understood), falls back
/// to a freshly generated v4 UUID rendered as a string.
fn next_tool_call_id(&self) -> String {
    // `Uuid::to_string` already yields an owned `String`; no second
    // conversion is needed.
    self.next_int_id()
        .unwrap_or_else(|_| Uuid::new_v4().to_string())
}
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)] #[derive(Clone, Deserialize, ToSchema, Serialize, Default)]

View File

@ -1164,6 +1164,7 @@ pub(crate) async fn chat_completions(
} = chat.clone(); } = chat.clone();
tracing::debug!("Got chat_template {:?}", infer.chat_template); tracing::debug!("Got chat_template {:?}", infer.chat_template);
let id = chat.next_tool_call_id();
let (generate_request, using_tools): (GenerateRequest, bool) = let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?; chat.try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters)); span.record("parameters", format!("{:?}", generate_request.parameters));
@ -1182,7 +1183,7 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs); let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id);
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
match result{ match result{
Ok(stream_token) => { Ok(stream_token) => {