mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 20:42:06 +00:00
Fixing the tool call id.
This commit is contained in:
parent
3e731a7c2f
commit
a73cd56075
@ -86,6 +86,7 @@ fn create_event_from_stream_token(
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
function_name: Option<String>,
|
||||
id: String,
|
||||
) -> CompletionType {
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
@ -122,7 +123,7 @@ fn create_event_from_stream_token(
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: function_name,
|
||||
@ -172,6 +173,7 @@ pub struct ChatState {
|
||||
model_id: String,
|
||||
fingerprint: String,
|
||||
logprobs: bool,
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl ChatState {
|
||||
@ -181,6 +183,7 @@ impl ChatState {
|
||||
fingerprint: String,
|
||||
model_id: String,
|
||||
logprobs: bool,
|
||||
id: String,
|
||||
) -> Self {
|
||||
let state = if using_tools {
|
||||
StreamState::Buffering
|
||||
@ -195,13 +198,13 @@ impl ChatState {
|
||||
fingerprint,
|
||||
model_id,
|
||||
logprobs,
|
||||
id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
|
||||
let mut events = vec![];
|
||||
let token_text = &stream_token.token.text;
|
||||
println!("Got {token_text:?} - State {:?}", self.state);
|
||||
match self.state {
|
||||
StreamState::Buffering => {
|
||||
self.text.push_str(token_text);
|
||||
@ -218,6 +221,7 @@ impl ChatState {
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
|
||||
events.push(chat_complete);
|
||||
@ -237,6 +241,7 @@ impl ChatState {
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
Some(call.function._name),
|
||||
self.id.clone(),
|
||||
);
|
||||
|
||||
events.push(chat_complete);
|
||||
@ -261,6 +266,7 @@ impl ChatState {
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
events.push(chat_complete);
|
||||
} else {
|
||||
@ -271,8 +277,8 @@ impl ChatState {
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
|
||||
events.push(chat_complete);
|
||||
}
|
||||
}
|
||||
@ -301,10 +307,8 @@ impl ChatState {
|
||||
if text.ends_with("\"") {
|
||||
text = &text[..text.len() - 1];
|
||||
}
|
||||
println!("Detected end of content {text:?}");
|
||||
stream_token.token.text = text.to_string();
|
||||
self.state = StreamState::NoToolFinish;
|
||||
println!("NNew state {:?}", self.state);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -317,6 +321,7 @@ impl ChatState {
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
|
||||
events.push(chat_complete);
|
||||
@ -329,6 +334,7 @@ impl ChatState {
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
|
||||
events.push(chat_complete);
|
||||
@ -414,7 +420,7 @@ mod tests {
|
||||
function,
|
||||
} = &tool_calls[0];
|
||||
assert_eq!(*index, 0);
|
||||
assert_eq!(id, "");
|
||||
assert_eq!(id, "0");
|
||||
assert_eq!(r#type, "function");
|
||||
(function.name.as_ref(), &function.arguments)
|
||||
} else {
|
||||
@ -435,6 +441,7 @@ mod tests {
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let events = chat_state.push(StreamResponse {
|
||||
@ -480,6 +487,7 @@ mod tests {
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let events = chat_state.push(StreamResponse {
|
||||
@ -544,6 +552,7 @@ mod tests {
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
@ -621,6 +630,7 @@ mod tests {
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
@ -729,6 +739,7 @@ mod tests {
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
@ -810,6 +821,7 @@ mod tests {
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
@ -880,6 +892,7 @@ mod tests {
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
|
@ -22,6 +22,7 @@ use tokenizers::Encoding;
|
||||
use tracing::warn;
|
||||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Clone)]
|
||||
@ -995,6 +996,29 @@ impl ChatRequest {
|
||||
using_tools,
|
||||
))
|
||||
}
|
||||
|
||||
fn next_int_id(&self) -> Result<String, Box<dyn std::error::Error>>{
|
||||
let mut id: usize = 0;
|
||||
for message in &self.messages{
|
||||
if let MessageBody::Tool{tool_calls} = &message.body {
|
||||
for tool_call in tool_calls{
|
||||
let new_id: usize = tool_call.id.parse()?;
|
||||
id = std::cmp::max(id, new_id + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(id.to_string())
|
||||
}
|
||||
|
||||
/// Try to have linearly increasing id
|
||||
/// or resort to using Uuid if the initial
|
||||
/// scheme is not understood
|
||||
fn next_tool_call_id(&self) -> String{
|
||||
self.next_int_id().unwrap_or_else(|_|{
|
||||
let uid = Uuid::new_v4().to_string();
|
||||
uid.to_string()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
||||
|
@ -1164,6 +1164,7 @@ pub(crate) async fn chat_completions(
|
||||
} = chat.clone();
|
||||
|
||||
tracing::debug!("Got chat_template {:?}", infer.chat_template);
|
||||
let id = chat.next_tool_call_id();
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.try_into_generate(&infer)?;
|
||||
span.record("parameters", format!("{:?}", generate_request.parameters));
|
||||
@ -1182,7 +1183,7 @@ pub(crate) async fn chat_completions(
|
||||
|
||||
let response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs);
|
||||
let mut state = ChatState::new(using_tools, stream_options, system_fingerprint, model_id, logprobs, id);
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result{
|
||||
Ok(stream_token) => {
|
||||
|
Loading…
Reference in New Issue
Block a user