feat: improve grammar to include name and add tests

This commit is contained in:
drbh 2024-04-02 01:28:21 +00:00
parent c38a7d7ddd
commit 4930de857d
2 changed files with 55 additions and 86 deletions

View File

@ -399,33 +399,23 @@ impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs { impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self { fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
let (tokens, top_tokens) = value; let (tokens, top_tokens) = value;
Self {
// Create an iterator that produces None for top_tokens once it's exhausted content: tokens
let top_tokens_iter = top_tokens
.into_iter() .into_iter()
.map(Some) .zip(top_tokens)
.chain(std::iter::repeat(None)); .map(|(t, top_t)| ChatCompletionLogprob {
let content = tokens
.into_iter()
.zip(top_tokens_iter)
.map(|(t, top_t_option)| ChatCompletionLogprob {
token: t.text, token: t.text,
logprob: t.logprob, logprob: t.logprob,
top_logprobs: match top_t_option { top_logprobs: top_t
Some(top_t) => top_t
.into_iter() .into_iter()
.map(|t| ChatCompletionTopLogprob { .map(|t| ChatCompletionTopLogprob {
token: t.text, token: t.text,
logprob: t.logprob, logprob: t.logprob,
}) })
.collect(), .collect(),
None => vec![], // Handle the case where there are no top tokens
},
}) })
.collect(); .collect(),
}
Self { content }
} }
} }
@ -727,26 +717,26 @@ mod deserialize_tool_choice {
} }
} }
#[derive(Debug, Deserialize, Serialize, ToSchema)] #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools { pub struct Tools {
#[serde(flatten)] #[serde(flatten)]
functions_map: FunctionsMap, functions_map: FunctionsMap,
properties: Properties, properties: Properties,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionsMap { struct FunctionsMap {
#[serde(rename = "$functions")] #[serde(rename = "$functions")]
functions: std::collections::HashMap<String, serde_json::Value>, functions: std::collections::HashMap<String, serde_json::Value>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionRef { struct FunctionRef {
#[serde(rename = "$ref")] #[serde(rename = "$ref")]
ref_path: String, ref_path: String,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Properties { struct Properties {
#[serde(serialize_with = "serialize_function")] #[serde(serialize_with = "serialize_function")]
function: Vec<FunctionRef>, function: Vec<FunctionRef>,
@ -767,7 +757,8 @@ pub(crate) struct FunctionDefinition {
#[serde(default)] #[serde(default)]
pub description: Option<String>, pub description: Option<String>,
pub name: String, pub name: String,
pub parameters: serde_json::Value, #[serde(alias = "parameters")]
pub arguments: serde_json::Value,
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]

View File

@ -1,7 +1,7 @@
use crate::config::Config; use crate::config::Config;
/// HTTP Server logic /// HTTP Server logic
use crate::health::Health; use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@ -15,7 +15,7 @@ use crate::{
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
}; };
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; use crate::{FunctionDefinition, ToolCall, ToolType};
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
@ -29,7 +29,6 @@ use futures::Stream;
use futures::TryStreamExt; use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
@ -766,6 +765,7 @@ async fn chat_completions(
let logprobs = req.logprobs.unwrap_or(false); let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed; let seed = req.seed;
let stop = req.stop.unwrap_or_default(); let stop = req.stop.unwrap_or_default();
let tool_prompt = req.tool_prompt.unwrap_or_default();
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(req.messages) { let mut inputs = match infer.apply_chat_template(req.messages) {
@ -783,47 +783,22 @@ async fn chat_completions(
} }
}; };
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) { let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) {
let tool_prompt = req.tool_prompt.unwrap_or_default(); Ok(grammar) => grammar,
let tools_to_use = match tool_choice { Err(err) => {
ToolType::FunctionName(name) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation");
vec![req_tools tracing::error!("{err}");
.iter() return Err((
.find(|tool| tool.function.name == *name)
.ok_or_else(|| {
(
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse { Json(ErrorResponse {
error: "Tool choice not found in tool names".to_string(), error: err.to_string(),
error_type: "Tool not found".to_string(), error_type: err.error_type().to_string(),
}), }),
) ));
})?
.clone()]
} }
ToolType::OneOf => req_tools.to_owned(),
};
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
(func.name, func.parameters)
})
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
}; };
let grammar = if let Some(tools) = &tool_grammar {
let tools_str = serde_json::to_string(&tools).map_err(|e| { let tools_str = serde_json::to_string(&tools).map_err(|e| {
( (
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
@ -834,7 +809,7 @@ async fn chat_completions(
) )
})?; })?;
inputs = format!("{inputs}{tool_prompt}{tools_str}"); inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::json!(tools))) Some(GrammarType::Json(serde_json::to_value(tools).unwrap()))
} else { } else {
None None
}; };
@ -860,7 +835,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: req.top_logprobs, top_n_tokens: req.top_logprobs,
grammar: tool_grammar.clone(), grammar,
}, },
}; };
@ -949,21 +924,23 @@ async fn chat_completions(
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionDefinition {
description: None, description: None,
name: "tools".to_string(), name: gen_text_value
parameters: gen_text_value.get("function").map_or_else( .get("function")
|| { .and_then(|f| f.get("_name"))
serde_json::from_str(&generation.generated_text).map_err(|e| { .and_then(|name| name.as_str())
( .unwrap_or("default_function_name")
StatusCode::UNPROCESSABLE_ENTITY, .to_string(),
Json(ErrorResponse { // Serialize the JSON object obtained from "function" to an escaped JSON string
error: e.to_string(), arguments: gen_text_value
error_type: "Input validation error".to_string(), .get("function")
}), .map(|f| {
) let mut f_cloned = f.clone();
if let Value::Object(ref mut props) = f_cloned {
props.remove("_name");
}
f_cloned
}) })
}, .unwrap_or_default(),
|f| Ok(f.clone()),
)?,
}, },
}]; }];
(Some(tool_calls), None) (Some(tool_calls), None)
@ -1539,6 +1516,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
}; };
( (