Fixing some corner cases.

This commit is contained in:
Nicolas Patry 2025-03-10 12:06:44 +01:00
parent 0b710f9671
commit 3e731a7c2f
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
5 changed files with 322 additions and 78 deletions

View File

@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.4.1", features = ["tokio"] }
hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }

View File

@ -151,6 +151,7 @@ fn create_event_from_stream_token(
))
}
#[derive(Debug)]
enum StreamState {
/// Before the tools was parsed
Buffering,
@ -200,6 +201,7 @@ impl ChatState {
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
let mut events = vec![];
let token_text = &stream_token.token.text;
println!("Got {token_text:?} - State {:?}", self.state);
match self.state {
StreamState::Buffering => {
self.text.push_str(token_text);
@ -223,9 +225,9 @@ impl ChatState {
// XXX Caution, here we do not postfix the quote, so that the current output
// Is necessarily finished with quotes for us to be able to parse.
let partial = &self.text;
let partial = partial.trim_end();
let partial = partial.trim_end_matches(',');
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
// This can be no_tool before the content has been emitted
if call.function._name != "no_tool" {
stream_token.token.text = "{".to_string();
let chat_complete = create_event_from_stream_token(
@ -279,30 +281,35 @@ impl ChatState {
StreamState::NoToolFinish => {}
StreamState::NoTool => {
self.text.push_str(token_text);
if token_text.contains("\"") || token_text.contains("}") {
let total_text = &self.text;
let total_text = total_text.trim_end();
let total_text = total_text.trim_end_matches('}');
let total_text = total_text.trim_end();
let total_text = total_text.trim_end_matches('"');
if let Ok(value) =
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", total_text))
{
if !value.function.content.is_empty() {
let text = token_text.trim_end();
let text = text.trim_end_matches('}');
let mut text = text.trim_end();
if token_text.contains("\"") {
let mut text = self
.text
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
// Trim once
if text.ends_with("\"") {
// Verify we have actually trimmed something
// The opposite can happen if the model is outputting inline JSON.
text = &text[..text.len() - 1];
if let Ok(_value) =
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", text))
{
let mut text = token_text
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
// Effectively trim_end_match('"', 1)
// because we do not want to eventually trim finishing escaped quotes
// {{"\"Something\""}}
if text.ends_with("\"") {
text = &text[..text.len() - 1];
}
println!("Detected end of content {text:?}");
stream_token.token.text = text.to_string();
self.state = StreamState::NoToolFinish;
println!("NNew state {:?}", self.state);
}
}
}
// This escaping is usually inline json escaping and we can therefore remove it.
stream_token.token.text = stream_token.token.text.replace("\\", "");
let chat_complete = create_event_from_stream_token(
&stream_token,
self.logprobs,
@ -372,6 +379,52 @@ mod tests {
use super::*;
/// Test helper: unwraps a plain-text chat delta from a completion chunk.
///
/// Asserts the chunk carries exactly one choice and that the choice's delta
/// is a `Chat` text message; panics otherwise. Returns a borrow of the
/// delta's `content` so callers can accumulate the streamed text.
fn get_text_content(event: &CompletionType) -> &String {
    let CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) = event else {
        panic!("Unexpected chunk");
    };
    assert_eq!(choices.len(), 1);
    let ChatCompletionChoice {
        delta: ChatCompletionDelta::Chat(TextMessage { content, .. }),
        ..
    } = &choices[0]
    else {
        panic!("Expected plain message");
    };
    content
}
/// Test helper: unwraps a tool-call delta from a completion chunk.
///
/// Asserts exactly one choice carrying a `Tool` delta with exactly one
/// `DeltaToolCall` whose index is 0, id is empty and type is "function";
/// panics otherwise. Returns `(optional function name, arguments fragment)`
/// so callers can accumulate both streams separately.
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
    let CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) = event else {
        panic!("Unexpected chunk");
    };
    assert_eq!(choices.len(), 1);
    let ChatCompletionChoice {
        delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
        ..
    } = &choices[0]
    else {
        panic!("Expected plain message");
    };
    assert_eq!(tool_calls.len(), 1);
    let DeltaToolCall {
        index,
        id,
        r#type,
        function,
    } = &tool_calls[0];
    assert_eq!(*index, 0);
    assert_eq!(id, "");
    assert_eq!(r#type, "function");
    (function.name.as_ref(), &function.arguments)
}
#[test]
fn test_chat_stream() {
let mut chat_state = ChatState::new(
@ -518,6 +571,83 @@ mod tests {
"}".to_string(),
"}".to_string(),
];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..14] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
// No tool output
let mut output = String::new();
for token in &tokens[14..14 + 7] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 1);
let content = get_text_content(&events[0]);
output.push_str(content);
}
assert_eq!(output, "I am a helpful assistant!");
// No tool finish
for token in &tokens[14 + 7..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
}
#[test]
fn test_chat_stream_tool_no_tool_many_quotes() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":".to_string(),
" \"".to_string(), // Token 14
"I".to_string(), // Event 1
" am".to_string(), // Event 2
" a".to_string(), // Event 3
" helpful".to_string(), // Event 4
" assistant".to_string(), // Event 5
"!\\\"\"".to_string(), // Extra inside the string quote that would get removed
"}".to_string(),
"}".to_string(),
];
// Initial ignored output
for text in &tokens[..14] {
@ -569,7 +699,7 @@ mod tests {
}
}
assert_eq!(output, "I am a helpful assistant!");
assert_eq!(output, "I am a helpful assistant!\"");
// No tool finish
for text in &tokens[14 + 7..] {
@ -589,6 +719,157 @@ mod tests {
}
}
#[test]
/// no_tool streaming where the "content" field itself contains inline JSON:
/// the escaped quotes (`\"`) inside the string must not be mistaken for the
/// closing quote of the content, and the inline-JSON escaping is expected to
/// be stripped from the emitted deltas.
fn test_chat_stream_tool_no_tool_inline_json() {
    let mut chat_state = ChatState::new(
        true,
        StreamOptions {
            include_usage: true,
        },
        "fingerprint".to_string(),
        "model_id".to_string(),
        false,
    );
    let tokens = vec![
        "{\"".to_string(),
        "function".to_string(),
        "\":".to_string(),
        " {\"".to_string(),
        "_".to_string(),
        "name".to_string(),
        "\":".to_string(),
        " \"".to_string(),
        "no".to_string(),
        "_tool".to_string(),
        "\",".to_string(),
        " \"".to_string(),
        "content".to_string(),
        "\":".to_string(),
        " \"".to_string(),               // Token 14
        "{\\\"".to_string(),             // Event 1
        "a".to_string(),                 // Event 1
        "\\\":".to_string(),             // Event 1
        "2".to_string(),                 // Event 2
        ",\\".to_string(),               // Event 2
        "\"".to_string(),                // Event 2
        "b".to_string(),                 // Event 3
        "\\\": ".to_string(),            // Event 4
        "1".to_string(),                 // Event 5
        "}".to_string(),                 // Event 5
        "\"}".to_string(),               // Extra inside the string quote that would get removed
        "}".to_string(),
    ];
    // Wrap each text fragment in a full stream response.
    let tokens: Vec<_> = tokens
        .into_iter()
        .map(|text| StreamResponse {
            generated_text: None,
            token: Token {
                id: 42,
                // `text` is already an owned String; no need to clone it
                // with `to_string()` — use field shorthand instead.
                text,
                logprob: 0.0,
                special: false,
            },
            top_tokens: vec![],
            index: 0,
            details: None,
        })
        .collect();
    // Initial ignored output: buffering until no_tool is detected.
    for token in &tokens[..14] {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0);
    }
    // No tool output: each token produces exactly one text event.
    let mut output = String::new();
    for token in &tokens[14..14 + 12] {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 1, "Current text is {output:?}");
        let content = get_text_content(&events[0]);
        output.push_str(content);
    }
    // The inline-JSON escaping has been removed from the streamed content.
    assert_eq!(output, "{\"a\":2,\"b\": 1}");
    // No tool finish: the closing wrapper braces produce no events.
    for token in &tokens[14 + 12..] {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0, "Extra events {events:?}");
    }
}
#[test]
/// no_tool streaming where the content is the empty string: the state machine
/// must still emit the (empty) content events without panicking or leaking
/// the JSON wrapper characters.
fn test_chat_stream_tool_no_tool_empty() {
    let mut chat_state = ChatState::new(
        true,
        StreamOptions {
            include_usage: true,
        },
        "fingerprint".to_string(),
        "model_id".to_string(),
        false,
    );
    let tokens = vec![
        "{\"".to_string(),
        "function".to_string(),
        "\":".to_string(),
        " {\"".to_string(),
        "_".to_string(),
        "name".to_string(),
        "\":".to_string(),
        " \"".to_string(),
        "no".to_string(),
        "_tool".to_string(),
        "\",".to_string(),
        " \"".to_string(),
        "content".to_string(),
        "\":\"".to_string(),
        "\"}".to_string(),  // Token 13
        "}".to_string(),    // Event 1
    ];
    // Wrap each text fragment in a full stream response.
    let tokens: Vec<_> = tokens
        .into_iter()
        .map(|text| StreamResponse {
            generated_text: None,
            token: Token {
                id: 42,
                // `text` is already an owned String; no need to clone it
                // with `to_string()` — use field shorthand instead.
                text,
                logprob: 0.0,
                special: false,
            },
            top_tokens: vec![],
            index: 0,
            details: None,
        })
        .collect();
    // Initial ignored output: buffering until no_tool is detected.
    for token in &tokens[..13] {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0);
    }
    // No tool output: events are emitted but the accumulated text is empty.
    let mut output = String::new();
    for token in &tokens[13..13 + 2] {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 1, "Current text is {output:?}");
        let content = get_text_content(&events[0]);
        output.push_str(content);
    }
    assert_eq!(output, "");
    // No tool finish: trailing tokens produce no events.
    for token in &tokens[13 + 2..] {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0, "Extra events {events:?}");
    }
}
#[test]
fn test_chat_stream_tool_get_weather() {
let mut chat_state = ChatState::new(
@ -633,10 +914,9 @@ mod tests {
"elsius".to_string(), // Event 17
"\"}}".to_string(), // Event 18 retained (trailing brace removed)
];
// Initial ignored output
for text in &tokens[..11] {
let events = chat_state.push(StreamResponse {
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
@ -647,56 +927,27 @@ mod tests {
top_tokens: vec![],
index: 0,
details: None,
});
})
.collect();
// Initial ignored output
for token in &tokens[..11] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0, "{events:?}");
}
// No tool output
let mut output = String::new();
let mut output_name = String::new();
for text in &tokens[11..11 + 17] {
let events = chat_state.push(StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
for token in &tokens[11..11 + 17] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 1);
match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(choices.len(), 1);
if let ChatCompletionChoice {
delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
..
} = &choices[0]
{
assert_eq!(tool_calls.len(), 1);
let DeltaToolCall {
index,
id,
r#type,
function,
} = &tool_calls[0];
assert_eq!(*index, 0);
assert_eq!(id, "");
assert_eq!(r#type, "function");
if let Some(name) = &function.name {
assert_eq!(name, "get_current_weather");
output_name.push_str(&name);
}
output.push_str(&function.arguments);
} else {
panic!("Expected plain message");
}
}
_ => panic!("Unexpected chunk"),
let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name {
assert_eq!(name, "get_current_weather");
output_name.push_str(&name);
}
output.push_str(arguments);
}
assert_eq!(output_name, "get_current_weather");
@ -706,19 +957,8 @@ mod tests {
);
// No tool finish
for text in &tokens[11 + 17..] {
let events = chat_state.push(StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
for token in &tokens[11 + 17..] {
let events = chat_state.push(token.clone());
assert_eq!(events.len(), 0);
}
}

View File

@ -16,7 +16,7 @@ pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Erro
Ok(Local::now().format(&format_str).to_string())
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub(crate) struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,

View File

@ -52,7 +52,7 @@ pub struct Infer {
/// Request backend
backend: Arc<dyn Backend + Send + Sync>,
/// Chat template
chat_template: Option<ChatTemplate>,
pub(crate) chat_template: Option<ChatTemplate>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Backend health

View File

@ -1162,6 +1162,8 @@ pub(crate) async fn chat_completions(
logprobs,
..
} = chat.clone();
tracing::debug!("Got chat_template {:?}", infer.chat_template);
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters));
@ -1565,6 +1567,7 @@ pub async fn run(
)
}
Type::Cache(cache) => {
tracing::info!("Cache {cache:?}");
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
@ -1581,6 +1584,7 @@ pub async fn run(
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}");
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)