From c4258e40fea0bad3bc54208f61454a5b88274ce3 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 2 Aug 2024 23:35:28 +0000 Subject: [PATCH] feat: simplify prepare_chat_input logic and adjust start stop chars --- router/src/server.rs | 178 +++++++++++++++---------------------------- 1 file changed, 61 insertions(+), 117 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 35292410..7655182a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -23,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -144,61 +144,15 @@ async fn get_chat_tokenize( .. } = req; - if response_format.is_some() && tools.is_some() { - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: "Grammar and tools are mutually exclusive".to_string(), - error_type: "validation".to_string(), - }), - )); - } - - let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { - Ok(grammar) => grammar, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{}", err); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; - - let tools_grammar_prompt = tool_grammar.as_ref().map(|t| { - ( - GrammarType::Json(serde_json::json!(t)), - tool_prompt.unwrap_or_default(), - ) - }); - - let (tools_grammar_prompt, _grammar) = response_format - .map(|rf| (None, Some(rf))) - .unwrap_or_else(|| { - ( - tools_grammar_prompt.clone(), - tools_grammar_prompt.map(|(g, _)| g), - ) - }); - - let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { - Ok(inputs) => inputs, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{}", err); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; + let tool_prompt = tool_prompt.unwrap_or_default(); + let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + messages, + )?; let generate_request = GenerateRequest { inputs, @@ -233,8 +187,11 @@ async fn get_chat_tokenize( .iter() .zip(encoding.get_offsets()) .map(|(&id, &(start, stop))| { - let text: String = - String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); SimpleToken { id, text, @@ -1179,63 +1136,14 @@ async fn chat_completions( Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - - // response_format and tools are mutually exclusive - if response_format.is_some() && tools.as_ref().is_some() { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: "Grammar and tools are mutually exclusive".to_string(), - error_type: "grammar and tools".to_string(), - }), - )); - } - - // extract tool grammar if present - let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { - Ok(grammar) => grammar, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; - - // determine the appropriate arguments for apply_chat_template - let tools_grammar_prompt = tool_grammar - .as_ref() - .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); - - let (tools_grammar_prompt, grammar) = match response_format { - Some(response_format) => (None, Some(response_format)), - None => ( - tools_grammar_prompt.clone(), - tools_grammar_prompt.map(|(grammar, _)| grammar.clone()), - ), - }; - - // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { - Ok(inputs) => inputs, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; + let (inputs, grammar, tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + messages, + )?; // build the request passing some parameters let generate_request = GenerateRequest { @@ -1505,8 +1413,11 @@ async fn tokenize( .iter() .zip(encoding.get_offsets()) .map(|(&id, &(start, stop))| { - let text: String = - String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); SimpleToken { id, text, @@ -2478,3 +2389,36 @@ fn create_post_processor( Ok(post_processor) } + +type PreparedInput = (String, Option, Option); + +fn prepare_chat_input( + infer: &Infer, + response_format: Option, + tools: Option>, + tool_choice: ToolChoice, + tool_prompt: &str, + messages: Vec, +) -> Result { + if response_format.is_some() && tools.is_some() { + return Err(InferError::ToolError( + "Grammar and tools are mutually exclusive".into(), + )); + } + + if let Some(format) = response_format { + let inputs = infer.apply_chat_template(messages, None)?; + return Ok((inputs, Some(format), None)); + } + + // if tools are set, apply the tool grammar and then the chat template + let tool_grammar: Option = ToolGrammar::apply(tools, tool_choice)?; + let grammar = tool_grammar + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + let tools_grammar_prompt = tool_grammar + .as_ref() + .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); + let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?; + Ok((inputs, grammar, tool_grammar)) +}