mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-28 21:42:06 +00:00
Removing a lot of NO_TOOL shenanigans.
This commit is contained in:
parent
cb92acf280
commit
03fe626a95
@ -19,14 +19,18 @@ struct Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(test, derive(Debug))]
|
#[cfg_attr(test, derive(Debug))]
|
||||||
pub(crate) enum ChatEvent{
|
pub(crate) enum ChatEvent {
|
||||||
NoTool,
|
NoTool,
|
||||||
Events(Vec<CompletionType>)
|
Events(Vec<CompletionType>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn parse_output(
|
#[cfg_attr(test, derive(Debug))]
|
||||||
generated_text: &str,
|
pub(crate) enum ChatChoice {
|
||||||
) -> Result<(Option<Vec<crate::ToolCall>>, Option<String>), InferError> {
|
NoTool,
|
||||||
|
ToolCalls(Vec<crate::ToolCall>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
|
||||||
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
|
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
|
||||||
InferError::ToolError(format!(
|
InferError::ToolError(format!(
|
||||||
"Failed to parse generated text: {} {:?}",
|
"Failed to parse generated text: {} {:?}",
|
||||||
@ -38,16 +42,7 @@ pub(crate) fn parse_output(
|
|||||||
match &name[..] {
|
match &name[..] {
|
||||||
"no_tool" => {
|
"no_tool" => {
|
||||||
// parse the content message
|
// parse the content message
|
||||||
let content_message = call
|
Ok(ChatChoice::NoTool)
|
||||||
.function
|
|
||||||
.arguments
|
|
||||||
.get("content")
|
|
||||||
.and_then(Value::as_str)
|
|
||||||
.ok_or_else(|| {
|
|
||||||
InferError::ToolError("No `content` found in generated text".to_string())
|
|
||||||
})?
|
|
||||||
.to_string();
|
|
||||||
Ok((None, Some(content_message)))
|
|
||||||
}
|
}
|
||||||
name => {
|
name => {
|
||||||
let tool_calls = vec![crate::ToolCall {
|
let tool_calls = vec![crate::ToolCall {
|
||||||
@ -63,7 +58,7 @@ pub(crate) fn parse_output(
|
|||||||
})?,
|
})?,
|
||||||
},
|
},
|
||||||
}];
|
}];
|
||||||
Ok((Some(tool_calls), None))
|
Ok(ChatChoice::ToolCalls(tool_calls))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -194,8 +189,10 @@ impl ChatState {
|
|||||||
match self.state {
|
match self.state {
|
||||||
StreamState::Buffering => {
|
StreamState::Buffering => {
|
||||||
self.text.push_str(token_text);
|
self.text.push_str(token_text);
|
||||||
|
tracing::info!("Current text {:?}", self.text);
|
||||||
let partial = &self.text;
|
let partial = &self.text;
|
||||||
let partial = partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',');
|
let partial =
|
||||||
|
partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');
|
||||||
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
|
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
|
||||||
// This can be no_tool before the content has been emitted
|
// This can be no_tool before the content has been emitted
|
||||||
if call.function._name != "no_tool" {
|
if call.function._name != "no_tool" {
|
||||||
@ -212,7 +209,7 @@ impl ChatState {
|
|||||||
|
|
||||||
events.push(chat_complete);
|
events.push(chat_complete);
|
||||||
self.state = StreamState::Tool;
|
self.state = StreamState::Tool;
|
||||||
}else{
|
} else {
|
||||||
return ChatEvent::NoTool;
|
return ChatEvent::NoTool;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -362,7 +359,7 @@ mod tests {
|
|||||||
index: 0,
|
index: 0,
|
||||||
details: None,
|
details: None,
|
||||||
});
|
});
|
||||||
if let ChatEvent::Events(events) = events{
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 1);
|
assert_eq!(events.len(), 1);
|
||||||
match &events[0] {
|
match &events[0] {
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||||
@ -382,7 +379,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
_ => panic!("Unexpected chunk"),
|
_ => panic!("Unexpected chunk"),
|
||||||
}
|
}
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -417,7 +414,7 @@ mod tests {
|
|||||||
finish_reason: FinishReason::Length,
|
finish_reason: FinishReason::Length,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
if let ChatEvent::Events(events) = events{
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 2);
|
assert_eq!(events.len(), 2);
|
||||||
match &events[0] {
|
match &events[0] {
|
||||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||||
@ -451,7 +448,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
_ => panic!("Unexpected chunk"),
|
_ => panic!("Unexpected chunk"),
|
||||||
}
|
}
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -513,18 +510,18 @@ mod tests {
|
|||||||
// Initial ignored output
|
// Initial ignored output
|
||||||
for token in &tokens[..10] {
|
for token in &tokens[..10] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
if let ChatEvent::Events(events) = events{
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 0, "{events:?}");
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No tool output
|
// No tool output
|
||||||
let events = chat_state.push(tokens[10].clone());
|
let events = chat_state.push(tokens[10].clone());
|
||||||
if let ChatEvent::NoTool = events{
|
if let ChatEvent::NoTool = events {
|
||||||
assert!(true);
|
assert!(true);
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -579,18 +576,18 @@ mod tests {
|
|||||||
// Initial ignored output
|
// Initial ignored output
|
||||||
for token in &tokens[..10] {
|
for token in &tokens[..10] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
if let ChatEvent::Events(events) = events{
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 0, "{events:?}");
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No tool output
|
// No tool output
|
||||||
let events = chat_state.push(tokens[10].clone());
|
let events = chat_state.push(tokens[10].clone());
|
||||||
if let ChatEvent::NoTool = events{
|
if let ChatEvent::NoTool = events {
|
||||||
assert!(true);
|
assert!(true);
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -659,9 +656,9 @@ mod tests {
|
|||||||
// Initial ignored output
|
// Initial ignored output
|
||||||
for token in &tokens[..11] {
|
for token in &tokens[..11] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
if let ChatEvent::Events(events) = events{
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 0, "{events:?}");
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -671,7 +668,7 @@ mod tests {
|
|||||||
let mut output_name = String::new();
|
let mut output_name = String::new();
|
||||||
for token in &tokens[11..11 + 17] {
|
for token in &tokens[11..11 + 17] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
if let ChatEvent::Events(events) = events{
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 1);
|
assert_eq!(events.len(), 1);
|
||||||
let (name, arguments) = get_tool_call_content(&events[0]);
|
let (name, arguments) = get_tool_call_content(&events[0]);
|
||||||
if let Some(name) = name {
|
if let Some(name) = name {
|
||||||
@ -679,7 +676,7 @@ mod tests {
|
|||||||
output_name.push_str(&name);
|
output_name.push_str(&name);
|
||||||
}
|
}
|
||||||
output.push_str(arguments);
|
output.push_str(arguments);
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -693,9 +690,9 @@ mod tests {
|
|||||||
// No tool finish
|
// No tool finish
|
||||||
for token in &tokens[11 + 17..] {
|
for token in &tokens[11 + 17..] {
|
||||||
let events = chat_state.push(token.clone());
|
let events = chat_state.push(token.clone());
|
||||||
if let ChatEvent::Events(events) = events{
|
if let ChatEvent::Events(events) = events {
|
||||||
assert_eq!(events.len(), 0, "{events:?}");
|
assert_eq!(events.len(), 0, "{events:?}");
|
||||||
}else{
|
} else {
|
||||||
panic!("Expected chat events");
|
panic!("Expected chat events");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,13 +40,13 @@ impl ToolGrammar {
|
|||||||
),
|
),
|
||||||
arguments: json!({
|
arguments: json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
// "properties": {
|
||||||
"content": {
|
// "content": {
|
||||||
"type": "string",
|
// "type": "string",
|
||||||
"description": "The response content",
|
// "description": "The response content",
|
||||||
}
|
// }
|
||||||
},
|
// },
|
||||||
"required": ["content"]
|
// "required": ["content"]
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::chat::{ChatState, ChatEvent};
|
use crate::chat::{ChatChoice, ChatEvent, ChatState};
|
||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
||||||
@ -1178,8 +1178,13 @@ pub(crate) async fn chat_completions(
|
|||||||
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
// switch on stream
|
// switch on stream
|
||||||
if stream {
|
if stream {
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) = generate_stream_internal(
|
||||||
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
|
infer.clone(),
|
||||||
|
compute_type.clone(),
|
||||||
|
Json(generate_request),
|
||||||
|
span.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let response_stream = async_stream::stream! {
|
let response_stream = async_stream::stream! {
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
@ -1194,10 +1199,10 @@ pub(crate) async fn chat_completions(
|
|||||||
chat.response_format = None;
|
chat.response_format = None;
|
||||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
chat.clone().try_into_generate(&infer).unwrap();
|
chat.clone().try_into_generate(&infer).unwrap();
|
||||||
assert_eq!(using_tools, false);
|
assert!(!using_tools);
|
||||||
let (_headers, response_stream2) =
|
let (_headers, response_stream2) =
|
||||||
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
|
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
|
||||||
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs.clone(), id.clone());
|
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
|
||||||
response_stream = Box::pin(response_stream2);
|
response_stream = Box::pin(response_stream2);
|
||||||
}
|
}
|
||||||
ChatEvent::Events(events) => {
|
ChatEvent::Events(events) => {
|
||||||
@ -1219,8 +1224,13 @@ pub(crate) async fn chat_completions(
|
|||||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
Ok((headers, sse).into_response())
|
Ok((headers, sse).into_response())
|
||||||
} else {
|
} else {
|
||||||
let (headers, input_length, Json(generation)) =
|
let (mut headers, mut input_length, Json(generation)) = generate_internal(
|
||||||
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
Extension(infer.clone()),
|
||||||
|
compute_type.clone(),
|
||||||
|
Json(generate_request),
|
||||||
|
span.clone(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
@ -1228,7 +1238,26 @@ pub(crate) async fn chat_completions(
|
|||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let (tool_calls, output) = if using_tools {
|
let (tool_calls, output) = if using_tools {
|
||||||
crate::chat::parse_output(&generation.generated_text)?
|
match crate::chat::parse_output(&generation.generated_text)? {
|
||||||
|
ChatChoice::NoTool => {
|
||||||
|
chat.tools = None;
|
||||||
|
chat.response_format = None;
|
||||||
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
|
chat.clone().try_into_generate(&infer)?;
|
||||||
|
assert!(!using_tools);
|
||||||
|
let (headers_final, input_length_final, Json(generation)) = generate_internal(
|
||||||
|
Extension(infer),
|
||||||
|
compute_type,
|
||||||
|
Json(generate_request),
|
||||||
|
span,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
headers = headers_final;
|
||||||
|
input_length = input_length_final;
|
||||||
|
(None, Some(generation.generated_text))
|
||||||
|
}
|
||||||
|
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
(None, Some(generation.generated_text))
|
(None, Some(generation.generated_text))
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user