diff --git a/proto/generate.proto b/proto/generate.proto
index 1c252599..0490029f 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -55,7 +55,6 @@ enum GrammarType {
     GRAMMAR_TYPE_NONE = 0;
     GRAMMAR_TYPE_JSON = 1;
     GRAMMAR_TYPE_REGEX = 2;
-    GRAMMAR_TYPE_OPTIONAL_JSON = 3;
 }

 message NextTokenChooserParameters {
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 472b7d66..42405327 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -812,23 +812,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -877,28 +881,33 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "Hi again!".to_string(),
+                    content: Some("Hi again!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -952,23 +961,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1006,23 +1019,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
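Reviewer note: the test churn above is mechanical. `content` becomes `Option<String>` because an OpenAI-style assistant message that carries tool calls may have no text at all, so plain turns now wrap their text in `Some(...)`. A minimal sketch of the two shapes, assuming the `Option<String>`/`Option<ToolCall>` generics reconstructed in this diff:

    // Sketch only; mirrors the updated Message struct in router/src/lib.rs.
    let plain_turn = Message {
        role: "user".to_string(),
        content: Some("What is Deep Learning?".to_string()),
        name: None,
        tool_calls: None,
    };
    // An assistant turn that invoked a tool can instead carry content: None
    // together with tool_calls: Some(...).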
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 66213975..5dfa8c7d 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -358,10 +358,11 @@ impl ChatCompletion {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        output: String,
+        output: Option<String>,
         created: u64,
         details: Details,
         return_logprobs: bool,
+        tool_calls: Option<ToolCall>,
     ) -> Self {
         Self {
             id: String::new(),
@@ -375,6 +376,7 @@
                     role: "assistant".into(),
                     content: output,
                     name: None,
+                    tool_calls,
                 },
                 logprobs: return_logprobs
                     .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@@ -527,10 +529,61 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     pub tools: Option<Vec<Tool>>,

+    /// A prompt to be appended before the tools
+    #[serde(default = "default_tool_prompt")]
+    pub tool_prompt: Option<String>,
+
     /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
-    pub tool_choice: Option<String>,
+    #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
+    pub tool_choice: Option<ToolType>,
+}
+
+fn default_tool_prompt() -> Option<String> {
+    Some(
+        "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
+    )
+}
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+enum ToolType {
+    FunctionName(String),
+    OneOf,
+}
+
+/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
+mod deserialize_tool_choice {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => match s.as_str() {
+                "none" => Ok(None),
+                "auto" => Ok(Some(ToolType::OneOf)),
+                _ => Ok(Some(ToolType::FunctionName(s))),
+            },
+            Value::Object(map) => {
+                if let Some(content) = map
+                    .get("function")
+                    .and_then(|v| v.get("name"))
+                    .and_then(|v| v.as_str())
+                {
+                    Ok(Some(ToolType::FunctionName(content.to_string())))
+                } else {
+                    Err(de::Error::custom("function key not found in tool choice"))
+                }
+            }
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
 }

 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
@@ -575,7 +628,8 @@ impl FunctionRef {

 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct Function {
-    pub description: String,
+    #[serde(default)]
+    pub description: Option<String>,
     pub name: String,
     pub parameters: serde_json::Value,
 }
@@ -597,15 +651,24 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }

+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ToolCall {
+    pub id: u32,
+    pub r#type: String,
+    pub function: Function,
+}
+
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
     pub role: String,
     #[schema(example = "My name is David and I")]
-    pub content: String,
+    pub content: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
     pub name: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<ToolCall>,
 }

 #[derive(Clone, Debug, Deserialize, ToSchema)]
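Reviewer note: `deserialize_tool_choice` accepts both the OpenAI string forms and the object form. A quick illustration of the mapping; the `Probe` wrapper is hypothetical, stands in for `ChatRequest`, and assumes it sits in the same module as `deserialize_tool_choice`:

    // Hypothetical wrapper, for illustration only.
    #[derive(serde::Deserialize)]
    struct Probe {
        #[serde(default, deserialize_with = "deserialize_tool_choice::deserialize")]
        tool_choice: Option<ToolType>,
    }

    // "auto" selects any tool, "none" disables tools, and a bare string or an
    // OpenAI-style object pins a named function:
    let auto: Probe = serde_json::from_str(r#"{"tool_choice": "auto"}"#).unwrap();
    let none: Probe = serde_json::from_str(r#"{"tool_choice": "none"}"#).unwrap();
    let named: Probe = serde_json::from_str(
        r#"{"tool_choice": {"type": "function", "function": {"name": "get_weather"}}}"#,
    )
    .unwrap();
    assert!(matches!(auto.tool_choice, Some(ToolType::OneOf)));
    assert!(none.tool_choice.is_none());
    assert!(matches!(named.tool_choice, Some(ToolType::FunctionName(n)) if n == "get_weather"));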
diff --git a/router/src/server.rs b/router/src/server.rs
index 2d4944f0..8d639624 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -10,7 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
-use crate::{FunctionRef, Tools};
+use crate::{Function, FunctionRef, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -583,6 +583,16 @@ async fn chat_completions(
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;

+    if stream && req.tools.is_some() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Tools are not supported with stream".to_string(),
+                error_type: "Input validation error".to_string(),
+            }),
+        ));
+    }
+
     // apply chat template to flatten the request into a single input
     let mut inputs = match infer.apply_chat_template(req.messages) {
         Ok(inputs) => inputs,
@@ -599,53 +609,35 @@ async fn chat_completions(
         }
     };

-    // if theres a tools object, we need to decompose it and use the function name as the key
-    // and the parameters as the value in the "$functions" object.
-    let grammar = if let Some(ref req_tools) = &req.tools {
-        // get the tool_choice if there is one
-        let tool_choice = &req.tool_choice;
-        let tools_to_use = if let Some(tool_choice) = tool_choice {
-            // get the tool based on the tool_choice
-            let tool = req_tools
-                .iter()
-                .find(|tool| tool.function.name == *tool_choice)
-                .ok_or_else(|| {
-                    (
-                        StatusCode::UNPROCESSABLE_ENTITY,
-                        Json(ErrorResponse {
-                            error: "Input validation error".to_string(),
-                            error_type: "Input validation error".to_string(),
-                        }),
-                    )
-                })?;
-            vec![tool.clone()]
-        } else {
-            req_tools.clone()
-        };
-
-        let functions: HashMap<String, Value> = {
-            let mut tools = HashMap::new();
-            for tool in &tools_to_use {
-                let func = tool.function.clone();
-                let name = func.name;
-                let parameters = match func.parameters.as_object() {
-                    Some(parameters) => parameters.clone(),
-                    None => {
-                        return Err((
+    let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
+        let tool_prompt = req.tool_prompt.unwrap_or_default();
+        let tools_to_use = match tool_choice {
+            ToolType::FunctionName(name) => {
+                vec![req_tools
+                    .iter()
+                    .find(|tool| tool.function.name == *name)
+                    .ok_or_else(|| {
+                        (
                             StatusCode::UNPROCESSABLE_ENTITY,
                             Json(ErrorResponse {
                                 error: "Input validation error".to_string(),
                                 error_type: "Input validation error".to_string(),
                             }),
-                        ))
-                    }
-                };
-
-                tools.insert(name, Value::Object(parameters));
+                        )
+                    })?
+                    .clone()]
             }
-            tools
+            ToolType::OneOf => req_tools.to_owned(),
         };
+        let functions: HashMap<String, serde_json::Value> = tools_to_use
+            .iter()
+            .map(|tool| {
+                let func = tool.function.clone();
+                (func.name, func.parameters)
+            })
+            .collect();
+
         let tools = Tools {
             function: functions,
             any_of: tools_to_use
@@ -654,7 +646,6 @@
                 .collect(),
         };

-        // update the input
         let tools_str = serde_json::to_string(&tools).map_err(|e| {
             (
                 StatusCode::UNPROCESSABLE_ENTITY,
                 Json(ErrorResponse {
                     error: e.to_string(),
                     error_type: "Input validation error".to_string(),
                 }),
             )
         })?;
-
-        let tool_prompt =
-            "Based on the conversation, please choose the most appropriate tool to use:"
-                .to_string();
-        inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");
-
+        inputs = format!("{inputs}{tool_prompt}{tools_str}");
         Some(GrammarType::Json(tools.into()))
     } else {
         None
@@ -696,7 +682,7 @@
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
-            grammar,
+            grammar: tool_grammar.clone(),
         },
     };

@@ -760,14 +746,41 @@
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();

+        let (tool_calls, output) = if tool_grammar.is_some() {
+            // gen_text should be valid json
+            let gen_text_value: Value =
+                serde_json::from_str(&generation.generated_text).map_err(|e| {
+                    (
+                        StatusCode::UNPROCESSABLE_ENTITY,
+                        Json(ErrorResponse {
+                            error: e.to_string(),
+                            error_type: "Input validation error".to_string(),
+                        }),
+                    )
+                })?;
+
+            let tool_call = Some(ToolCall {
+                id: 0,
+                r#type: "function".to_string(),
+                function: Function {
+                    description: None,
+                    name: "tools".to_string(),
+                    parameters: gen_text_value.get("function").unwrap().clone(),
+                },
+            });
+            (tool_call, None)
+        } else {
+            (None, Some(generation.generated_text))
+        };
         // build the complete response object with the full text
         let response = ChatCompletion::new(
             model_id,
             system_fingerprint,
-            generation.generated_text,
+            output,
             current_time,
             generation.details.unwrap(),
             logprobs,
+            tool_calls,
         );

         // wrap generation inside a Vec to match api-inference
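Reviewer note: the non-streaming path assumes the grammar-constrained output parses as a JSON object with a top-level "function" key; `gen_text_value.get("function").unwrap()` panics otherwise, which may be worth a follow-up. A sketch of the expected round trip (the inner field names are invented for the example; in practice they come from whatever tool schema was supplied):

    // What a grammar-constrained completion is expected to look like:
    let generated = r#"{"function": {"_name": "get_weather", "location": "Paris"}}"#;
    let value: serde_json::Value = serde_json::from_str(generated).unwrap();
    // The same lookup the handler performs before building the ToolCall:
    let parameters = value.get("function").unwrap().clone();
    assert_eq!(parameters["location"], "Paris");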
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index c0ccd83f..40f31ce2 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -513,9 +513,6 @@ class GrammarLogitProcessor(LogitsProcessor):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
             schema = build_regex_from_object(schema)
-        elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON:
-            # TODO: use a better method to handle optional grammars
-            schema = f"({build_regex_from_object(schema)})|.*"
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
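Taken together, a request that exercises the new fields might look like the sketch below (illustrative values throughout; note the handler rejects `stream: true` whenever `tools` is set):

    // Hypothetical request body; names and values are made up for illustration.
    let body = serde_json::json!({
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather", // description is now optional
                "parameters": {
                    "type": "object",
                    "properties": { "location": { "type": "string" } },
                    "required": ["location"]
                }
            }
        }],
        "tool_prompt": "Pick the tool that best answers the user: ", // optional, has a default
        "tool_choice": "auto", // or "none", a function name, or an OpenAI-style object
        "stream": false
    });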