diff --git a/router/src/server.rs b/router/src/server.rs
index dcbaa2ad..d718513f 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -115,6 +115,87 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/templated",
+    request_body = ChatRequest,
+    responses((status = 200, description = "Templated Chat Request", body = Value))
+)]
+async fn get_templated(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<ChatRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::counter!("tgi_request_count").increment(1);
+
+    let ChatRequest {
+        messages,
+        response_format,
+        tools,
+        tool_choice,
+        tool_prompt,
+        ..
+    } = req;
+
+    if response_format.is_some() && tools.is_some() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Grammar and tools are mutually exclusive".to_string(),
+                error_type: "validation".to_string(),
+            }),
+        ));
+    }
+
+    let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
+        Ok(grammar) => grammar,
+        Err(err) => {
+            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
+            tracing::error!("{}", err);
+            return Err((
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                    error_type: err.error_type().to_string(),
+                }),
+            ));
+        }
+    };
+
+    let tools_grammar_prompt = tool_grammar.as_ref().map(|t| {
+        (
+            GrammarType::Json(serde_json::json!(t)),
+            tool_prompt.unwrap_or_default(),
+        )
+    });
+
+    let (tools_grammar_prompt, _grammar) = response_format
+        .map(|rf| (None, Some(rf)))
+        .unwrap_or_else(|| {
+            (
+                tools_grammar_prompt.clone(),
+                tools_grammar_prompt.map(|(g, _)| g),
+            )
+        });
+
+    let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
+        Ok(inputs) => inputs,
+        Err(err) => {
+            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
+            tracing::error!("{}", err);
+            return Err((
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                    error_type: err.error_type().to_string(),
+                }),
+            ));
+        }
+    };
+
+    Ok((HeaderMap::new(), Json(inputs)).into_response())
+}
+
 #[utoipa::path(
     get,
     tag = "Text Generation Inference",
@@ -2036,6 +2117,7 @@ async fn start(
     }
     let info_routes = Router::new()
         .route("/", get(health))
+        .route("/templated", post(get_templated))
         .route("/info", get(get_model_info))
         .route("/health", get(health))
         .route("/ping", get(health))
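
For context, a minimal client-side sketch of how the new `/templated` route could be exercised once the router is running. This is not part of the diff: it assumes the router's default port (3000), a `reqwest` dependency with the `blocking` and `json` features plus `serde_json`, and a request body whose fields mirror the `ChatRequest` fields destructured in the handler above.

```rust
// Hypothetical usage sketch (not part of the commit): POST a chat request to
// the new /templated route and print the rendered prompt. Per the handler, the
// endpoint returns the output of `infer.apply_chat_template`, not a generation.
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Field names assumed to follow the OpenAI-style chat schema used by ChatRequest.
    let body = json!({
        "model": "tgi",
        "messages": [
            { "role": "user", "content": "What is Deep Learning?" }
        ]
    });

    // Assumes a TGI router listening on localhost:3000.
    let templated = reqwest::blocking::Client::new()
        .post("http://localhost:3000/templated")
        .json(&body)
        .send()?
        .text()?;

    println!("templated prompt: {templated}");
    Ok(())
}
```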