feat: update default prompt and other small refactors

This commit is contained in:
drbh 2024-04-04 01:10:06 +00:00
parent 106c9ce8e5
commit bb73acc1a9
2 changed files with 8 additions and 7 deletions

View File

@ -669,7 +669,7 @@ pub(crate) struct ChatRequest {
#[serde(default = "default_tool_prompt")] #[serde(default = "default_tool_prompt")]
#[schema( #[schema(
nullable = true, nullable = true,
example = "\"Based on the conversation, please choose the most appropriate tool to use: \"" example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
)] )]
pub tool_prompt: Option<String>, pub tool_prompt: Option<String>,

View File

@ -767,9 +767,9 @@ async fn chat_completions(
let stop = req.stop.unwrap_or_default(); let stop = req.stop.unwrap_or_default();
let tool_prompt = req.tool_prompt.unwrap_or_default(); let tool_prompt = req.tool_prompt.unwrap_or_default();
// apply chat template to flatten the request into a single input // extract tool grammar if present
let mut inputs = match infer.apply_chat_template(req.messages) { let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) {
Ok(inputs) => inputs, Ok(grammar) => grammar,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
@ -783,8 +783,9 @@ async fn chat_completions(
} }
}; };
let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) { // apply chat template to flatten the request into a single input
Ok(grammar) => grammar, let mut inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
@ -809,7 +810,7 @@ async fn chat_completions(
) )
})?; })?;
inputs = format!("{inputs}{tool_prompt}{tools_str}"); inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::to_value(tools).unwrap())) Some(GrammarType::Json(serde_json::json!(tools)))
} else { } else {
None None
}; };