mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: simplify prepare_chat_input logic and adjust start stop chars
This commit is contained in:
parent
40658f4e84
commit
c4258e40fe
@ -23,7 +23,7 @@ use crate::{
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
||||
VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||
@ -144,61 +144,15 @@ async fn get_chat_tokenize(
|
||||
..
|
||||
} = req;
|
||||
|
||||
if response_format.is_some() && tools.is_some() {
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Grammar and tools are mutually exclusive".to_string(),
|
||||
error_type: "validation".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{}", err);
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let tools_grammar_prompt = tool_grammar.as_ref().map(|t| {
|
||||
(
|
||||
GrammarType::Json(serde_json::json!(t)),
|
||||
tool_prompt.unwrap_or_default(),
|
||||
)
|
||||
});
|
||||
|
||||
let (tools_grammar_prompt, _grammar) = response_format
|
||||
.map(|rf| (None, Some(rf)))
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
tools_grammar_prompt.clone(),
|
||||
tools_grammar_prompt.map(|(g, _)| g),
|
||||
)
|
||||
});
|
||||
|
||||
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{}", err);
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
let generate_request = GenerateRequest {
|
||||
inputs,
|
||||
@ -233,8 +187,11 @@ async fn get_chat_tokenize(
|
||||
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text: String =
|
||||
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
|
||||
let text = input
|
||||
.chars()
|
||||
.skip(start)
|
||||
.take(stop - start)
|
||||
.collect::<String>();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
|
||||
@ -1179,63 +1136,14 @@ async fn chat_completions(
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
|
||||
// response_format and tools are mutually exclusive
|
||||
if response_format.is_some() && tools.as_ref().is_some() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Grammar and tools are mutually exclusive".to_string(),
|
||||
error_type: "grammar and tools".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
// extract tool grammar if present
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// determine the appropriate arguments for apply_chat_template
|
||||
let tools_grammar_prompt = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
|
||||
|
||||
let (tools_grammar_prompt, grammar) = match response_format {
|
||||
Some(response_format) => (None, Some(response_format)),
|
||||
None => (
|
||||
tools_grammar_prompt.clone(),
|
||||
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
|
||||
),
|
||||
};
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
let (inputs, grammar, tool_grammar) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
@ -1505,8 +1413,11 @@ async fn tokenize(
|
||||
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text: String =
|
||||
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
|
||||
let text = input
|
||||
.chars()
|
||||
.skip(start)
|
||||
.take(stop - start)
|
||||
.collect::<String>();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
|
||||
@ -2478,3 +2389,36 @@ fn create_post_processor(
|
||||
|
||||
Ok(post_processor)
|
||||
}
|
||||
|
||||
type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
|
||||
|
||||
fn prepare_chat_input(
|
||||
infer: &Infer,
|
||||
response_format: Option<GrammarType>,
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: ToolChoice,
|
||||
tool_prompt: &str,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<PreparedInput, InferError> {
|
||||
if response_format.is_some() && tools.is_some() {
|
||||
return Err(InferError::ToolError(
|
||||
"Grammar and tools are mutually exclusive".into(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(format) = response_format {
|
||||
let inputs = infer.apply_chat_template(messages, None)?;
|
||||
return Ok((inputs, Some(format), None));
|
||||
}
|
||||
|
||||
// if tools are set, apply the tool grammar and then the chat template
|
||||
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
|
||||
let grammar = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||
let tools_grammar_prompt = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
|
||||
let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?;
|
||||
Ok((inputs, grammar, tool_grammar))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user