Fixing the tool call id.

This commit is contained in:
Nicolas Patry 2025-03-10 17:07:22 +01:00
parent 3e731a7c2f
commit a73cd56075
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
3 changed files with 45 additions and 7 deletions

View File

@ -86,6 +86,7 @@ fn create_event_from_stream_token(
system_fingerprint: String,
model_id: String,
function_name: Option<String>,
id: String,
) -> CompletionType {
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -122,7 +123,7 @@ fn create_event_from_stream_token(
role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
id,
r#type: "function".to_string(),
function: Function {
name: function_name,
@ -172,6 +173,7 @@ pub struct ChatState {
model_id: String,
fingerprint: String,
logprobs: bool,
id: String,
}
impl ChatState {
@ -181,6 +183,7 @@ impl ChatState {
fingerprint: String,
model_id: String,
logprobs: bool,
id: String,
) -> Self {
let state = if using_tools {
StreamState::Buffering
@ -195,13 +198,13 @@ impl ChatState {
fingerprint,
model_id,
logprobs,
id,
}
}
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
let mut events = vec![];
let token_text = &stream_token.token.text;
println!("Got {token_text:?} - State {:?}", self.state);
match self.state {
StreamState::Buffering => {
self.text.push_str(token_text);
@ -218,6 +221,7 @@ impl ChatState {
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
@ -237,6 +241,7 @@ impl ChatState {
self.fingerprint.clone(),
self.model_id.clone(),
Some(call.function._name),
self.id.clone(),
);
events.push(chat_complete);
@ -261,6 +266,7 @@ impl ChatState {
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
} else {
@ -271,8 +277,8 @@ impl ChatState {
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
}
}
@ -301,10 +307,8 @@ impl ChatState {
if text.ends_with("\"") {
text = &text[..text.len() - 1];
}
println!("Detected end of content {text:?}");
stream_token.token.text = text.to_string();
self.state = StreamState::NoToolFinish;
println!("NNew state {:?}", self.state);
}
}
}
@ -317,6 +321,7 @@ impl ChatState {
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
@ -329,6 +334,7 @@ impl ChatState {
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
@ -414,7 +420,7 @@ mod tests {
function,
} = &tool_calls[0];
assert_eq!(*index, 0);
assert_eq!(id, "");
assert_eq!(id, "0");
assert_eq!(r#type, "function");
(function.name.as_ref(), &function.arguments)
} else {
@ -435,6 +441,7 @@ mod tests {
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let events = chat_state.push(StreamResponse {
@ -480,6 +487,7 @@ mod tests {
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let events = chat_state.push(StreamResponse {
@ -544,6 +552,7 @@ mod tests {
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
@ -621,6 +630,7 @@ mod tests {
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
@ -729,6 +739,7 @@ mod tests {
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
@ -810,6 +821,7 @@ mod tests {
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
@ -880,6 +892,7 @@ mod tests {
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![

View File

@ -22,6 +22,7 @@ use tokenizers::Encoding;
use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
use uuid::Uuid;
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
@ -995,6 +996,29 @@ impl ChatRequest {
using_tools,
))
}
fn next_int_id(&self) -> Result<String, Box<dyn std::error::Error>>{
let mut id: usize = 0;
for message in &self.messages{
if let MessageBody::Tool{tool_calls} = &message.body {
for tool_call in tool_calls{
let new_id: usize = tool_call.id.parse()?;
id = std::cmp::max(id, new_id + 1);
}
}
}
Ok(id.to_string())
}
/// Returns the id to use for the next tool call.
///
/// Tries to keep ids linearly increasing via `next_int_id`; if that fails
/// (the existing id scheme is not understood), falls back to a freshly
/// generated version-4 UUID string.
fn next_tool_call_id(&self) -> String {
    // `Uuid::to_string` already yields an owned `String`; no second copy needed.
    self.next_int_id()
        .unwrap_or_else(|_| Uuid::new_v4().to_string())
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]

View File

@ -1164,6 +1164,7 @@ pub(crate) async fn chat_completions(
} = chat.clone();
tracing::debug!("Got chat_template {:?}", infer.chat_template);
let id = chat.next_tool_call_id();
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters));
@ -1182,7 +1183,7 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs);
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id);
while let Some(result) = response_stream.next().await {
match result{
Ok(stream_token) => {