Fixing some corner cases.

This commit is contained in:
Nicolas Patry 2025-03-10 12:06:44 +01:00
parent 0b710f9671
commit 3e731a7c2f
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
5 changed files with 322 additions and 78 deletions

View File

@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies] [workspace.dependencies]
base64 = "0.22.0" base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] } tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.4.1", features = ["tokio"] } hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" } metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] } metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] } minijinja = { version = "2.2.0", features = ["json"] }

View File

@ -151,6 +151,7 @@ fn create_event_from_stream_token(
)) ))
} }
#[derive(Debug)]
enum StreamState { enum StreamState {
/// Before the tools was parsed /// Before the tools was parsed
Buffering, Buffering,
@ -200,6 +201,7 @@ impl ChatState {
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> { pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
let mut events = vec![]; let mut events = vec![];
let token_text = &stream_token.token.text; let token_text = &stream_token.token.text;
println!("Got {token_text:?} - State {:?}", self.state);
match self.state { match self.state {
StreamState::Buffering => { StreamState::Buffering => {
self.text.push_str(token_text); self.text.push_str(token_text);
@ -223,9 +225,9 @@ impl ChatState {
// XXX Caution, here we do not postfix the quote, so that the current output // XXX Caution, here we do not postfix the quote, so that the current output
// Is necessarily finished with quotes for us to be able to parse. // Is necessarily finished with quotes for us to be able to parse.
let partial = &self.text; let partial = &self.text;
let partial = partial.trim_end(); let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
let partial = partial.trim_end_matches(',');
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) { if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
// This can be no_tool before the content has been emitted
if call.function._name != "no_tool" { if call.function._name != "no_tool" {
stream_token.token.text = "{".to_string(); stream_token.token.text = "{".to_string();
let chat_complete = create_event_from_stream_token( let chat_complete = create_event_from_stream_token(
@ -279,30 +281,35 @@ impl ChatState {
StreamState::NoToolFinish => {} StreamState::NoToolFinish => {}
StreamState::NoTool => { StreamState::NoTool => {
self.text.push_str(token_text); self.text.push_str(token_text);
if token_text.contains("\"") || token_text.contains("}") { if token_text.contains("\"") {
let total_text = &self.text; let mut text = self
let total_text = total_text.trim_end(); .text
let total_text = total_text.trim_end_matches('}'); .trim_end_matches(|c: char| c.is_whitespace() || c == '}');
let total_text = total_text.trim_end(); // Trim once
let total_text = total_text.trim_end_matches('"'); if text.ends_with("\"") {
if let Ok(value) = // Verify we have actually trimmed something
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", total_text)) // The opposite can happen if the model is outputting inline JSON.
text = &text[..text.len() - 1];
if let Ok(_value) =
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", text))
{ {
if !value.function.content.is_empty() { let mut text = token_text
let text = token_text.trim_end(); .trim_end_matches(|c: char| c.is_whitespace() || c == '}');
let text = text.trim_end_matches('}');
let mut text = text.trim_end();
// Effectively trim_end_match('"', 1) // Effectively trim_end_match('"', 1)
// because we do not want to eventually trim finishing escaped quotes // because we do not want to eventually trim finishing escaped quotes
// {{"\"Something\""}} // {{"\"Something\""}}
if text.ends_with("\"") { if text.ends_with("\"") {
text = &text[..text.len() - 1]; text = &text[..text.len() - 1];
} }
println!("Detected end of content {text:?}");
stream_token.token.text = text.to_string(); stream_token.token.text = text.to_string();
self.state = StreamState::NoToolFinish; self.state = StreamState::NoToolFinish;
println!("NNew state {:?}", self.state);
} }
} }
} }
// This escaping is usually inline json escaping and we can therefore remove it.
stream_token.token.text = stream_token.token.text.replace("\\", "");
let chat_complete = create_event_from_stream_token( let chat_complete = create_event_from_stream_token(
&stream_token, &stream_token,
self.logprobs, self.logprobs,
@ -372,6 +379,52 @@ mod tests {
use super::*; use super::*;
fn get_text_content(event: &CompletionType) -> &String {
match event {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(choices.len(), 1);
if let ChatCompletionChoice {
delta: ChatCompletionDelta::Chat(TextMessage { content, .. }),
..
} = &choices[0]
{
content
} else {
panic!("Expected plain message");
}
}
_ => panic!("Unexpected chunk"),
}
}
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
match event {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(choices.len(), 1);
if let ChatCompletionChoice {
delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
..
} = &choices[0]
{
assert_eq!(tool_calls.len(), 1);
let DeltaToolCall {
index,
id,
r#type,
function,
} = &tool_calls[0];
assert_eq!(*index, 0);
assert_eq!(id, "");
assert_eq!(r#type, "function");
(function.name.as_ref(), &function.arguments)
} else {
panic!("Expected plain message");
}
}
_ => panic!("Unexpected chunk"),
}
}
#[test] #[test]
fn test_chat_stream() { fn test_chat_stream() {
let mut chat_state = ChatState::new( let mut chat_state = ChatState::new(
@ -518,6 +571,83 @@ mod tests {
"}".to_string(), "}".to_string(),
"}".to_string(), "}".to_string(),
]; ];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..14] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
// No tool output
let mut output = String::new();
for token in &tokens[14..14 + 7] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 1);
let content = get_text_content(&events[0]);
output.push_str(content);
}
assert_eq!(output, "I am a helpful assistant!");
// No tool finish
for token in &tokens[14 + 7..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
}
#[test]
fn test_chat_stream_tool_no_tool_many_quotes() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":".to_string(),
" \"".to_string(), // Token 14
"I".to_string(), // Event 1
" am".to_string(), // Event 2
" a".to_string(), // Event 3
" helpful".to_string(), // Event 4
" assistant".to_string(), // Event 5
"!\\\"\"".to_string(), // Extra inside the string quote that would get removed
"}".to_string(),
"}".to_string(),
];
// Initial ignored output // Initial ignored output
for text in &tokens[..14] { for text in &tokens[..14] {
@ -569,7 +699,7 @@ mod tests {
} }
} }
assert_eq!(output, "I am a helpful assistant!"); assert_eq!(output, "I am a helpful assistant!\"");
// No tool finish // No tool finish
for text in &tokens[14 + 7..] { for text in &tokens[14 + 7..] {
@ -589,6 +719,157 @@ mod tests {
} }
} }
#[test]
fn test_chat_stream_tool_no_tool_inline_json() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":".to_string(),
" \"".to_string(), // Token 14
"{\\\"".to_string(), // Event 1
"a".to_string(), // Event 1
"\\\":".to_string(), // Event 1
"2".to_string(), // Event 2
",\\".to_string(), // Event 2
"\"".to_string(), // Event 2
"b".to_string(), // Event 3
"\\\": ".to_string(), // Event 4
"1".to_string(), // Event 5
"}".to_string(), // Event 5
"\"}".to_string(), // Extra inside the string quote that would get removed
"}".to_string(),
];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..14] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
// No tool output
let mut output = String::new();
for token in &tokens[14..14 + 12] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 1, "Current text is {output:?}");
let content = get_text_content(&events[0]);
output.push_str(content);
}
assert_eq!(output, "{\"a\":2,\"b\": 1}");
// No tool finish
for token in &tokens[14 + 12..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0, "Extra events {events:?}");
}
}
#[test]
fn test_chat_stream_tool_no_tool_empty() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":\"".to_string(),
"\"}".to_string(), // Token 13
"}".to_string(), // Event 1
];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..13] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
// No tool output
let mut output = String::new();
for token in &tokens[13..13 + 2] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 1, "Current text is {output:?}");
let content = get_text_content(&events[0]);
output.push_str(content);
}
assert_eq!(output, "");
// No tool finish
for token in &tokens[13 + 2..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0, "Extra events {events:?}");
}
}
#[test] #[test]
fn test_chat_stream_tool_get_weather() { fn test_chat_stream_tool_get_weather() {
let mut chat_state = ChatState::new( let mut chat_state = ChatState::new(
@ -633,10 +914,9 @@ mod tests {
"elsius".to_string(), // Event 17 "elsius".to_string(), // Event 17
"\"}}".to_string(), // Event 18 retained (trailing brace removed) "\"}}".to_string(), // Event 18 retained (trailing brace removed)
]; ];
let tokens: Vec<_> = tokens
// Initial ignored output .into_iter()
for text in &tokens[..11] { .map(|text| StreamResponse {
let events = chat_state.push(StreamResponse {
generated_text: None, generated_text: None,
token: Token { token: Token {
id: 42, id: 42,
@ -647,56 +927,27 @@ mod tests {
top_tokens: vec![], top_tokens: vec![],
index: 0, index: 0,
details: None, details: None,
}); })
.collect();
// Initial ignored output
for token in &tokens[..11] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0, "{events:?}"); assert_eq!(events.len(), 0, "{events:?}");
} }
// No tool output // No tool output
let mut output = String::new(); let mut output = String::new();
let mut output_name = String::new(); let mut output_name = String::new();
for text in &tokens[11..11 + 17] { for token in &tokens[11..11 + 17] {
let events = chat_state.push(StreamResponse { let events = chat_state.push(token.clone());
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
match &events[0] { let (name, arguments) = get_tool_call_content(&events[0]);
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => { if let Some(name) = name {
assert_eq!(choices.len(), 1);
if let ChatCompletionChoice {
delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
..
} = &choices[0]
{
assert_eq!(tool_calls.len(), 1);
let DeltaToolCall {
index,
id,
r#type,
function,
} = &tool_calls[0];
assert_eq!(*index, 0);
assert_eq!(id, "");
assert_eq!(r#type, "function");
if let Some(name) = &function.name {
assert_eq!(name, "get_current_weather"); assert_eq!(name, "get_current_weather");
output_name.push_str(&name); output_name.push_str(&name);
} }
output.push_str(&function.arguments); output.push_str(arguments);
} else {
panic!("Expected plain message");
}
}
_ => panic!("Unexpected chunk"),
}
} }
assert_eq!(output_name, "get_current_weather"); assert_eq!(output_name, "get_current_weather");
@ -706,19 +957,8 @@ mod tests {
); );
// No tool finish // No tool finish
for text in &tokens[11 + 17..] { for token in &tokens[11 + 17..] {
let events = chat_state.push(StreamResponse { let events = chat_state.push(token.clone());
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
assert_eq!(events.len(), 0); assert_eq!(events.len(), 0);
} }
} }

View File

@ -16,7 +16,7 @@ pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Erro
Ok(Local::now().format(&format_str).to_string()) Ok(Local::now().format(&format_str).to_string())
} }
#[derive(Clone)] #[derive(Debug, Clone)]
pub(crate) struct ChatTemplate { pub(crate) struct ChatTemplate {
template: Template<'static, 'static>, template: Template<'static, 'static>,
bos_token: Option<String>, bos_token: Option<String>,

View File

@ -52,7 +52,7 @@ pub struct Infer {
/// Request backend /// Request backend
backend: Arc<dyn Backend + Send + Sync>, backend: Arc<dyn Backend + Send + Sync>,
/// Chat template /// Chat template
chat_template: Option<ChatTemplate>, pub(crate) chat_template: Option<ChatTemplate>,
/// Inference limit /// Inference limit
limit_concurrent_requests: Arc<Semaphore>, limit_concurrent_requests: Arc<Semaphore>,
/// Backend health /// Backend health

View File

@ -1162,6 +1162,8 @@ pub(crate) async fn chat_completions(
logprobs, logprobs,
.. ..
} = chat.clone(); } = chat.clone();
tracing::debug!("Got chat_template {:?}", infer.chat_template);
let (generate_request, using_tools): (GenerateRequest, bool) = let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?; chat.try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters)); span.record("parameters", format!("{:?}", generate_request.parameters));
@ -1565,6 +1567,7 @@ pub async fn run(
) )
} }
Type::Cache(cache) => { Type::Cache(cache) => {
tracing::info!("Cache {cache:?}");
let repo = cache.repo(Repo::with_revision( let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(), tokenizer_name.to_string(),
RepoType::Model, RepoType::Model,
@ -1581,6 +1584,7 @@ pub async fn run(
}; };
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}");
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{ {
HubTokenizerConfig::from_file(filename) HubTokenizerConfig::from_file(filename)