diff --git a/router/src/lib.rs b/router/src/lib.rs index 2e412f1a..e2f21bc7 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -399,33 +399,23 @@ impl From<(Token, Vec)> for ChatCompletionLogprobs { impl From<(Vec, Vec>)> for ChatCompletionLogprobs { fn from(value: (Vec, Vec>)) -> Self { let (tokens, top_tokens) = value; - - // Create an iterator that produces None for top_tokens once it's exhausted - let top_tokens_iter = top_tokens - .into_iter() - .map(Some) - .chain(std::iter::repeat(None)); - - let content = tokens - .into_iter() - .zip(top_tokens_iter) - .map(|(t, top_t_option)| ChatCompletionLogprob { - token: t.text, - logprob: t.logprob, - top_logprobs: match top_t_option { - Some(top_t) => top_t + Self { + content: tokens + .into_iter() + .zip(top_tokens) + .map(|(t, top_t)| ChatCompletionLogprob { + token: t.text, + logprob: t.logprob, + top_logprobs: top_t .into_iter() .map(|t| ChatCompletionTopLogprob { token: t.text, logprob: t.logprob, }) .collect(), - None => vec![], // Handle the case where there are no top tokens - }, - }) - .collect(); - - Self { content } + }) + .collect(), + } } } @@ -727,26 +717,26 @@ mod deserialize_tool_choice { } } -#[derive(Debug, Deserialize, Serialize, ToSchema)] +#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] pub struct Tools { #[serde(flatten)] functions_map: FunctionsMap, properties: Properties, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] struct FunctionsMap { #[serde(rename = "$functions")] functions: std::collections::HashMap, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] struct FunctionRef { #[serde(rename = "$ref")] ref_path: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] struct Properties { #[serde(serialize_with = "serialize_function")] function: Vec, @@ -767,7 +757,8 @@ pub(crate) struct FunctionDefinition { #[serde(default)] pub description: Option, pub name: String, - pub parameters: serde_json::Value, + #[serde(alias = "parameters")] + pub arguments: serde_json::Value, } #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index b8f93514..17fcedd6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,7 +1,7 @@ use crate::config::Config; /// HTTP Server logic use crate::health::Health; -use crate::infer::{InferError, InferResponse, InferStreamResponse}; +use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, @@ -15,7 +15,7 @@ use crate::{ ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; +use crate::{FunctionDefinition, ToolCall, ToolType}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -29,7 +29,6 @@ use futures::Stream; use futures::TryStreamExt; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; -use std::collections::HashMap; use std::convert::Infallible; use std::net::SocketAddr; use std::sync::atomic::AtomicBool; @@ -766,6 +765,7 @@ async fn chat_completions( let logprobs = req.logprobs.unwrap_or(false); let seed = req.seed; let stop = req.stop.unwrap_or_default(); + let tool_prompt = req.tool_prompt.unwrap_or_default(); // apply chat template to flatten the request into a single input let mut inputs = match infer.apply_chat_template(req.messages) { @@ -783,47 +783,22 @@ async fn chat_completions( } }; - let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) { - let tool_prompt = req.tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .ok_or_else(|| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: "Tool choice not found in tool names".to_string(), - error_type: "Tool not found".to_string(), - }), - ) - })? - .clone()] - } - ToolType::OneOf => req_tools.to_owned(), - }; - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - (func.name, func.parameters) - }) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .collect(), - }, - }; + let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) { + Ok(grammar) => grammar, + Err(err) => { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: err.to_string(), + error_type: err.error_type().to_string(), + }), + )); + } + }; + let grammar = if let Some(tools) = &tool_grammar { let tools_str = serde_json::to_string(&tools).map_err(|e| { ( StatusCode::UNPROCESSABLE_ENTITY, @@ -834,7 +809,7 @@ async fn chat_completions( ) })?; inputs = format!("{inputs}{tool_prompt}{tools_str}"); - Some(GrammarType::Json(serde_json::json!(tools))) + Some(GrammarType::Json(serde_json::to_value(tools).unwrap())) } else { None }; @@ -860,7 +835,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: req.top_logprobs, - grammar: tool_grammar.clone(), + grammar, }, }; @@ -949,21 +924,23 @@ async fn chat_completions( r#type: "function".to_string(), function: FunctionDefinition { description: None, - name: "tools".to_string(), - parameters: gen_text_value.get("function").map_or_else( - || { - serde_json::from_str(&generation.generated_text).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - }) - }, - |f| Ok(f.clone()), - )?, + name: gen_text_value + .get("function") + .and_then(|f| f.get("_name")) + .and_then(|name| name.as_str()) + .unwrap_or("default_function_name") + .to_string(), + // Serialize the JSON object obtained from "function" to an escaped JSON string + arguments: gen_text_value + .get("function") + .map(|f| { + let mut f_cloned = f.clone(); + if let Value::Object(ref mut props) = f_cloned { + props.remove("_name"); + } + f_cloned + }) + .unwrap_or_default(), }, }]; (Some(tool_calls), None) @@ -1539,6 +1516,7 @@ impl From for (StatusCode, Json) { InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, }; (