feat: improve grammar to include name and add tests

This commit is contained in:
drbh 2024-04-02 01:28:21 +00:00
parent c38a7d7ddd
commit 4930de857d
2 changed files with 55 additions and 86 deletions

View File

@ -399,33 +399,23 @@ impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
let (tokens, top_tokens) = value;
// Create an iterator that produces None for top_tokens once it's exhausted
let top_tokens_iter = top_tokens
.into_iter()
.map(Some)
.chain(std::iter::repeat(None));
let content = tokens
.into_iter()
.zip(top_tokens_iter)
.map(|(t, top_t_option)| ChatCompletionLogprob {
token: t.text,
logprob: t.logprob,
top_logprobs: match top_t_option {
Some(top_t) => top_t
Self {
content: tokens
.into_iter()
.zip(top_tokens)
.map(|(t, top_t)| ChatCompletionLogprob {
token: t.text,
logprob: t.logprob,
top_logprobs: top_t
.into_iter()
.map(|t| ChatCompletionTopLogprob {
token: t.text,
logprob: t.logprob,
})
.collect(),
None => vec![], // Handle the case where there are no top tokens
},
})
.collect();
Self { content }
})
.collect(),
}
}
}
@ -727,26 +717,26 @@ mod deserialize_tool_choice {
}
}
#[derive(Debug, Deserialize, Serialize, ToSchema)]
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools {
#[serde(flatten)]
functions_map: FunctionsMap,
properties: Properties,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionsMap {
#[serde(rename = "$functions")]
functions: std::collections::HashMap<String, serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionRef {
#[serde(rename = "$ref")]
ref_path: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Properties {
#[serde(serialize_with = "serialize_function")]
function: Vec<FunctionRef>,
@ -767,7 +757,8 @@ pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
pub parameters: serde_json::Value,
#[serde(alias = "parameters")]
pub arguments: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]

View File

@ -1,7 +1,7 @@
use crate::config::Config;
/// HTTP Server logic
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@ -15,7 +15,7 @@ use crate::{
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use crate::{FunctionDefinition, ToolCall, ToolType};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@ -29,7 +29,6 @@ use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
@ -766,6 +765,7 @@ async fn chat_completions(
let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed;
let stop = req.stop.unwrap_or_default();
let tool_prompt = req.tool_prompt.unwrap_or_default();
// apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(req.messages) {
@ -783,47 +783,22 @@ async fn chat_completions(
}
};
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
let tool_prompt = req.tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.ok_or_else(|| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Tool choice not found in tool names".to_string(),
error_type: "Tool not found".to_string(),
}),
)
})?
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
(func.name, func.parameters)
})
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};
let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) {
Ok(grammar) => grammar,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
let grammar = if let Some(tools) = &tool_grammar {
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
@ -834,7 +809,7 @@ async fn chat_completions(
)
})?;
inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::json!(tools)))
Some(GrammarType::Json(serde_json::to_value(tools).unwrap()))
} else {
None
};
@ -860,7 +835,7 @@ async fn chat_completions(
decoder_input_details: !stream,
seed,
top_n_tokens: req.top_logprobs,
grammar: tool_grammar.clone(),
grammar,
},
};
@ -949,21 +924,23 @@ async fn chat_completions(
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name: "tools".to_string(),
parameters: gen_text_value.get("function").map_or_else(
|| {
serde_json::from_str(&generation.generated_text).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})
},
|f| Ok(f.clone()),
)?,
name: gen_text_value
.get("function")
.and_then(|f| f.get("_name"))
.and_then(|name| name.as_str())
.unwrap_or("default_function_name")
.to_string(),
// Serialize the JSON object obtained from "function" to an escaped JSON string
arguments: gen_text_value
.get("function")
.map(|f| {
let mut f_cloned = f.clone();
if let Value::Object(ref mut props) = f_cloned {
props.remove("_name");
}
f_cloned
})
.unwrap_or_default(),
},
}];
(Some(tool_calls), None)
@ -1539,6 +1516,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
};
(