feat: simplify prepare_chat_input logic and adjust start stop chars

This commit is contained in:
drbh 2024-08-02 23:35:28 +00:00
parent 40658f4e84
commit c4258e40fe

View File

@ -23,7 +23,7 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse, VertexResponse,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@ -144,61 +144,15 @@ async fn get_chat_tokenize(
.. ..
} = req; } = req;
if response_format.is_some() && tools.is_some() { let tool_prompt = tool_prompt.unwrap_or_default();
return Err(( let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
StatusCode::UNPROCESSABLE_ENTITY, &infer,
Json(ErrorResponse { response_format,
error: "Grammar and tools are mutually exclusive".to_string(), tools,
error_type: "validation".to_string(), tool_choice,
}), &tool_prompt,
)); messages,
} )?;
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{}", err);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
let tools_grammar_prompt = tool_grammar.as_ref().map(|t| {
(
GrammarType::Json(serde_json::json!(t)),
tool_prompt.unwrap_or_default(),
)
});
let (tools_grammar_prompt, _grammar) = response_format
.map(|rf| (None, Some(rf)))
.unwrap_or_else(|| {
(
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(g, _)| g),
)
});
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{}", err);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs, inputs,
@ -233,8 +187,11 @@ async fn get_chat_tokenize(
.iter() .iter()
.zip(encoding.get_offsets()) .zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| { .map(|(&id, &(start, stop))| {
let text: String = let text = input
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); .chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken { SimpleToken {
id, id,
text, text,
@ -1179,63 +1136,14 @@ async fn chat_completions(
Some(temperature) if temperature == 0.0 => (false, None), Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other), other => (true, other),
}; };
let (inputs, grammar, tool_grammar) = prepare_chat_input(
// response_format and tools are mutually exclusive &infer,
if response_format.is_some() && tools.as_ref().is_some() { response_format,
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tools,
return Err(( tool_choice,
StatusCode::UNPROCESSABLE_ENTITY, &tool_prompt,
Json(ErrorResponse { messages,
error: "Grammar and tools are mutually exclusive".to_string(), )?;
error_type: "grammar and tools".to_string(),
}),
));
}
// extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
// determine the appropriate arguments for apply_chat_template
let tools_grammar_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
let (tools_grammar_prompt, grammar) = match response_format {
Some(response_format) => (None, Some(response_format)),
None => (
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
),
};
// apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
@ -1505,8 +1413,11 @@ async fn tokenize(
.iter() .iter()
.zip(encoding.get_offsets()) .zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| { .map(|(&id, &(start, stop))| {
let text: String = let text = input
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); .chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken { SimpleToken {
id, id,
text, text,
@ -2478,3 +2389,36 @@ fn create_post_processor(
Ok(post_processor) Ok(post_processor)
} }
type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
fn prepare_chat_input(
infer: &Infer,
response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
tool_prompt: &str,
messages: Vec<Message>,
) -> Result<PreparedInput, InferError> {
if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError(
"Grammar and tools are mutually exclusive".into(),
));
}
if let Some(format) = response_format {
let inputs = infer.apply_chat_template(messages, None)?;
return Ok((inputs, Some(format), None));
}
// if tools are set, apply the tool grammar and then the chat template
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
let grammar = tool_grammar
.as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t)));
let tools_grammar_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?;
Ok((inputs, grammar, tool_grammar))
}