mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 04:22:08 +00:00
Fixing some corner cases.
This commit is contained in:
parent
0b710f9671
commit
3e731a7c2f
@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
[workspace.dependencies]
|
||||
base64 = "0.22.0"
|
||||
tokenizers = { version = "0.20.0", features = ["http"] }
|
||||
hf-hub = { version = "0.4.1", features = ["tokio"] }
|
||||
hf-hub = { version = "0.4.2", features = ["tokio"] }
|
||||
metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
minijinja = { version = "2.2.0", features = ["json"] }
|
||||
|
@ -151,6 +151,7 @@ fn create_event_from_stream_token(
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum StreamState {
|
||||
/// Before the tool call has been parsed
|
||||
Buffering,
|
||||
@ -200,6 +201,7 @@ impl ChatState {
|
||||
pub fn push(&mut self, mut stream_token: StreamResponse) -> Vec<CompletionType> {
|
||||
let mut events = vec![];
|
||||
let token_text = &stream_token.token.text;
|
||||
println!("Got {token_text:?} - State {:?}", self.state);
|
||||
match self.state {
|
||||
StreamState::Buffering => {
|
||||
self.text.push_str(token_text);
|
||||
@ -223,9 +225,9 @@ impl ChatState {
|
||||
// XXX Caution, here we do not postfix the quote, so that the current output
|
||||
// Is necessarily finished with quotes for us to be able to parse.
|
||||
let partial = &self.text;
|
||||
let partial = partial.trim_end();
|
||||
let partial = partial.trim_end_matches(',');
|
||||
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
|
||||
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
|
||||
// This can be no_tool before the content has been emitted
|
||||
if call.function._name != "no_tool" {
|
||||
stream_token.token.text = "{".to_string();
|
||||
let chat_complete = create_event_from_stream_token(
|
||||
@ -279,30 +281,35 @@ impl ChatState {
|
||||
StreamState::NoToolFinish => {}
|
||||
StreamState::NoTool => {
|
||||
self.text.push_str(token_text);
|
||||
if token_text.contains("\"") || token_text.contains("}") {
|
||||
let total_text = &self.text;
|
||||
let total_text = total_text.trim_end();
|
||||
let total_text = total_text.trim_end_matches('}');
|
||||
let total_text = total_text.trim_end();
|
||||
let total_text = total_text.trim_end_matches('"');
|
||||
if let Ok(value) =
|
||||
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", total_text))
|
||||
{
|
||||
if !value.function.content.is_empty() {
|
||||
let text = token_text.trim_end();
|
||||
let text = text.trim_end_matches('}');
|
||||
let mut text = text.trim_end();
|
||||
if token_text.contains("\"") {
|
||||
let mut text = self
|
||||
.text
|
||||
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
|
||||
// Trim once
|
||||
if text.ends_with("\"") {
|
||||
// Verify we have actually trimmed something
|
||||
// The opposite can happen if the model is outputting inline JSON.
|
||||
text = &text[..text.len() - 1];
|
||||
if let Ok(_value) =
|
||||
serde_json::from_str::<NoTool>(&format!("{}\"}}}}", text))
|
||||
{
|
||||
let mut text = token_text
|
||||
.trim_end_matches(|c: char| c.is_whitespace() || c == '}');
|
||||
// Effectively trim_end_match('"', 1)
|
||||
// because we do not want to eventually trim finishing escaped quotes
|
||||
// {{"\"Something\""}}
|
||||
if text.ends_with("\"") {
|
||||
text = &text[..text.len() - 1];
|
||||
}
|
||||
println!("Detected end of content {text:?}");
|
||||
stream_token.token.text = text.to_string();
|
||||
self.state = StreamState::NoToolFinish;
|
||||
println!("NNew state {:?}", self.state);
|
||||
}
|
||||
}
|
||||
}
|
||||
// This escaping is usually inline json escaping and we can therefore remove it.
|
||||
stream_token.token.text = stream_token.token.text.replace("\\", "");
|
||||
let chat_complete = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
self.logprobs,
|
||||
@ -372,6 +379,52 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn get_text_content(event: &CompletionType) -> &String {
|
||||
match event {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||
assert_eq!(choices.len(), 1);
|
||||
if let ChatCompletionChoice {
|
||||
delta: ChatCompletionDelta::Chat(TextMessage { content, .. }),
|
||||
..
|
||||
} = &choices[0]
|
||||
{
|
||||
content
|
||||
} else {
|
||||
panic!("Expected plain message");
|
||||
}
|
||||
}
|
||||
_ => panic!("Unexpected chunk"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
|
||||
match event {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||
assert_eq!(choices.len(), 1);
|
||||
if let ChatCompletionChoice {
|
||||
delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
|
||||
..
|
||||
} = &choices[0]
|
||||
{
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
let DeltaToolCall {
|
||||
index,
|
||||
id,
|
||||
r#type,
|
||||
function,
|
||||
} = &tool_calls[0];
|
||||
assert_eq!(*index, 0);
|
||||
assert_eq!(id, "");
|
||||
assert_eq!(r#type, "function");
|
||||
(function.name.as_ref(), &function.arguments)
|
||||
} else {
|
||||
panic!("Expected plain message");
|
||||
}
|
||||
}
|
||||
_ => panic!("Unexpected chunk"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream() {
|
||||
let mut chat_state = ChatState::new(
|
||||
@ -518,6 +571,83 @@ mod tests {
|
||||
"}".to_string(),
|
||||
"}".to_string(),
|
||||
];
|
||||
let tokens: Vec<_> = tokens
|
||||
.into_iter()
|
||||
.map(|text| StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: text.to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initial ignored output
|
||||
for token in &tokens[..14] {
|
||||
let events = chat_state.push(token.clone());
|
||||
assert_eq!(events.len(), 0);
|
||||
}
|
||||
|
||||
// No tool output
|
||||
let mut output = String::new();
|
||||
for token in &tokens[14..14 + 7] {
|
||||
let events = chat_state.push(token.clone());
|
||||
assert_eq!(events.len(), 1);
|
||||
let content = get_text_content(&events[0]);
|
||||
output.push_str(content);
|
||||
}
|
||||
|
||||
assert_eq!(output, "I am a helpful assistant!");
|
||||
|
||||
// No tool finish
|
||||
for token in &tokens[14 + 7..] {
|
||||
let events = chat_state.push(token.clone());
|
||||
assert_eq!(events.len(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream_tool_no_tool_many_quotes() {
|
||||
let mut chat_state = ChatState::new(
|
||||
true,
|
||||
StreamOptions {
|
||||
include_usage: true,
|
||||
},
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
"{\"".to_string(),
|
||||
"function".to_string(),
|
||||
"\":".to_string(),
|
||||
" {\"".to_string(),
|
||||
"_".to_string(),
|
||||
"name".to_string(),
|
||||
"\":".to_string(),
|
||||
" \"".to_string(),
|
||||
"no".to_string(),
|
||||
"_tool".to_string(),
|
||||
"\",".to_string(),
|
||||
" \"".to_string(),
|
||||
"content".to_string(),
|
||||
"\":".to_string(),
|
||||
" \"".to_string(), // Token 14
|
||||
"I".to_string(), // Event 1
|
||||
" am".to_string(), // Event 2
|
||||
" a".to_string(), // Event 3
|
||||
" helpful".to_string(), // Event 4
|
||||
" assistant".to_string(), // Event 5
|
||||
"!\\\"\"".to_string(), // Extra inside the string quote that would get removed
|
||||
"}".to_string(),
|
||||
"}".to_string(),
|
||||
];
|
||||
|
||||
// Initial ignored output
|
||||
for text in &tokens[..14] {
|
||||
@ -569,7 +699,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(output, "I am a helpful assistant!");
|
||||
assert_eq!(output, "I am a helpful assistant!\"");
|
||||
|
||||
// No tool finish
|
||||
for text in &tokens[14 + 7..] {
|
||||
@ -589,6 +719,157 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// Streams a `no_tool` answer whose content is itself an escaped JSON object.
///
/// The test asserts three phases:
/// 1. the first 14 tokens (the `{"function": {"_name": "no_tool", "content":`
///    prefix) produce no events;
/// 2. each of the next 12 tokens produces exactly one text event, and the
///    concatenated content equals `{"a":2,"b": 1}`;
/// 3. the remaining closing brace produces no events.
#[test]
fn test_chat_stream_tool_no_tool_inline_json() {
    let mut chat_state = ChatState::new(
        true,
        StreamOptions {
            include_usage: true,
        },
        "fingerprint".to_string(),
        "model_id".to_string(),
        false,
    );

    let pieces = [
        // Prefix: 14 tokens, no events expected.
        "{\"",
        "function",
        "\":",
        " {\"",
        "_",
        "name",
        "\":",
        " \"",
        "no",
        "_tool",
        "\",",
        " \"",
        "content",
        "\":",
        // Content: 12 tokens, one event expected for each.
        " \"", // Opening quote of the content string
        "{\\\"",
        "a",
        "\\\":",
        "2",
        ",\\",
        "\"",
        "b",
        "\\\": ",
        "1",
        "}",   // Closing brace of the inline JSON, still inside the string
        "\"}", // Closing quote of the content string and first closing brace
        // Tail: no events expected.
        "}",
    ];
    let tokens: Vec<StreamResponse> = pieces
        .iter()
        .map(|piece| StreamResponse {
            generated_text: None,
            token: Token {
                id: 42,
                text: (*piece).to_string(),
                logprob: 0.0,
                special: false,
            },
            top_tokens: vec![],
            index: 0,
            details: None,
        })
        .collect();

    let (ignored, rest) = tokens.split_at(14);
    let (streamed, trailing) = rest.split_at(12);

    // Initial ignored output
    for token in ignored {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0);
    }

    // No tool output
    let mut output = String::new();
    for token in streamed {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 1, "Current text is {output:?}");
        output.push_str(get_text_content(&events[0]));
    }
    assert_eq!(output, "{\"a\":2,\"b\": 1}");

    // No tool finish
    for token in trailing {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0, "Extra events {events:?}");
    }
}
|
||||
|
||||
/// Streams a `no_tool` answer whose content string is empty.
///
/// The test asserts three phases:
/// 1. the first 13 tokens (up to and including the `content` key) produce no
///    events;
/// 2. each of the next 2 tokens (`":"` and `"}`) produces exactly one text
///    event, and the concatenated content is the empty string;
/// 3. the remaining closing brace produces no events.
#[test]
fn test_chat_stream_tool_no_tool_empty() {
    let mut chat_state = ChatState::new(
        true,
        StreamOptions {
            include_usage: true,
        },
        "fingerprint".to_string(),
        "model_id".to_string(),
        false,
    );

    let pieces = [
        // Prefix: 13 tokens, no events expected.
        "{\"",
        "function",
        "\":",
        " {\"",
        "_",
        "name",
        "\":",
        " \"",
        "no",
        "_tool",
        "\",",
        " \"",
        "content",
        // Content: 2 tokens, one (empty) event expected for each.
        "\":\"", // Token 13: key separator plus opening quote of the content
        "\"}",   // Token 14: closing quote plus first closing brace
        // Tail: no events expected.
        "}",
    ];
    let tokens: Vec<StreamResponse> = pieces
        .iter()
        .map(|piece| StreamResponse {
            generated_text: None,
            token: Token {
                id: 42,
                text: (*piece).to_string(),
                logprob: 0.0,
                special: false,
            },
            top_tokens: vec![],
            index: 0,
            details: None,
        })
        .collect();

    let (ignored, rest) = tokens.split_at(13);
    let (streamed, trailing) = rest.split_at(2);

    // Initial ignored output
    for token in ignored {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0);
    }

    // No tool output
    let mut output = String::new();
    for token in streamed {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 1, "Current text is {output:?}");
        output.push_str(get_text_content(&events[0]));
    }
    assert_eq!(output, "");

    // No tool finish
    for token in trailing {
        let events = chat_state.push(token.clone());
        assert_eq!(events.len(), 0, "Extra events {events:?}");
    }
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream_tool_get_weather() {
|
||||
let mut chat_state = ChatState::new(
|
||||
@ -633,10 +914,9 @@ mod tests {
|
||||
"elsius".to_string(), // Event 17
|
||||
"\"}}".to_string(), // Event 18 retained (trailing brace removed)
|
||||
];
|
||||
|
||||
// Initial ignored output
|
||||
for text in &tokens[..11] {
|
||||
let events = chat_state.push(StreamResponse {
|
||||
let tokens: Vec<_> = tokens
|
||||
.into_iter()
|
||||
.map(|text| StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
@ -647,56 +927,27 @@ mod tests {
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
});
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initial ignored output
|
||||
for token in &tokens[..11] {
|
||||
let events = chat_state.push(token.clone());
|
||||
assert_eq!(events.len(), 0, "{events:?}");
|
||||
}
|
||||
|
||||
// No tool output
|
||||
let mut output = String::new();
|
||||
let mut output_name = String::new();
|
||||
for text in &tokens[11..11 + 17] {
|
||||
let events = chat_state.push(StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: text.to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
});
|
||||
for token in &tokens[11..11 + 17] {
|
||||
let events = chat_state.push(token.clone());
|
||||
assert_eq!(events.len(), 1);
|
||||
match &events[0] {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||
assert_eq!(choices.len(), 1);
|
||||
if let ChatCompletionChoice {
|
||||
delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
|
||||
..
|
||||
} = &choices[0]
|
||||
{
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
let DeltaToolCall {
|
||||
index,
|
||||
id,
|
||||
r#type,
|
||||
function,
|
||||
} = &tool_calls[0];
|
||||
assert_eq!(*index, 0);
|
||||
assert_eq!(id, "");
|
||||
assert_eq!(r#type, "function");
|
||||
if let Some(name) = &function.name {
|
||||
assert_eq!(name, "get_current_weather");
|
||||
output_name.push_str(&name);
|
||||
}
|
||||
output.push_str(&function.arguments);
|
||||
} else {
|
||||
panic!("Expected plain message");
|
||||
}
|
||||
}
|
||||
_ => panic!("Unexpected chunk"),
|
||||
let (name, arguments) = get_tool_call_content(&events[0]);
|
||||
if let Some(name) = name {
|
||||
assert_eq!(name, "get_current_weather");
|
||||
output_name.push_str(&name);
|
||||
}
|
||||
output.push_str(arguments);
|
||||
}
|
||||
|
||||
assert_eq!(output_name, "get_current_weather");
|
||||
@ -706,19 +957,8 @@ mod tests {
|
||||
);
|
||||
|
||||
// No tool finish
|
||||
for text in &tokens[11 + 17..] {
|
||||
let events = chat_state.push(StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: text.to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
});
|
||||
for token in &tokens[11 + 17..] {
|
||||
let events = chat_state.push(token.clone());
|
||||
assert_eq!(events.len(), 0);
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Erro
|
||||
Ok(Local::now().format(&format_str).to_string())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
|
@ -52,7 +52,7 @@ pub struct Infer {
|
||||
/// Request backend
|
||||
backend: Arc<dyn Backend + Send + Sync>,
|
||||
/// Chat template
|
||||
chat_template: Option<ChatTemplate>,
|
||||
pub(crate) chat_template: Option<ChatTemplate>,
|
||||
/// Inference limit
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
/// Backend health
|
||||
|
@ -1162,6 +1162,8 @@ pub(crate) async fn chat_completions(
|
||||
logprobs,
|
||||
..
|
||||
} = chat.clone();
|
||||
|
||||
tracing::debug!("Got chat_template {:?}", infer.chat_template);
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.try_into_generate(&infer)?;
|
||||
span.record("parameters", format!("{:?}", generate_request.parameters));
|
||||
@ -1565,6 +1567,7 @@ pub async fn run(
|
||||
)
|
||||
}
|
||||
Type::Cache(cache) => {
|
||||
tracing::info!("Cache {cache:?}");
|
||||
let repo = cache.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
@ -1581,6 +1584,7 @@ pub async fn run(
|
||||
};
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}");
|
||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||
{
|
||||
HubTokenizerConfig::from_file(filename)
|
||||
|
Loading…
Reference in New Issue
Block a user