mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve grammar to include name and add tests
This commit is contained in:
parent
c38a7d7ddd
commit
4930de857d
@ -399,33 +399,23 @@ impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
|
||||
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
|
||||
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
|
||||
let (tokens, top_tokens) = value;
|
||||
|
||||
// Create an iterator that produces None for top_tokens once it's exhausted
|
||||
let top_tokens_iter = top_tokens
|
||||
Self {
|
||||
content: tokens
|
||||
.into_iter()
|
||||
.map(Some)
|
||||
.chain(std::iter::repeat(None));
|
||||
|
||||
let content = tokens
|
||||
.into_iter()
|
||||
.zip(top_tokens_iter)
|
||||
.map(|(t, top_t_option)| ChatCompletionLogprob {
|
||||
.zip(top_tokens)
|
||||
.map(|(t, top_t)| ChatCompletionLogprob {
|
||||
token: t.text,
|
||||
logprob: t.logprob,
|
||||
top_logprobs: match top_t_option {
|
||||
Some(top_t) => top_t
|
||||
top_logprobs: top_t
|
||||
.into_iter()
|
||||
.map(|t| ChatCompletionTopLogprob {
|
||||
token: t.text,
|
||||
logprob: t.logprob,
|
||||
})
|
||||
.collect(),
|
||||
None => vec![], // Handle the case where there are no top tokens
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self { content }
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -727,26 +717,26 @@ mod deserialize_tool_choice {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
||||
pub struct Tools {
|
||||
#[serde(flatten)]
|
||||
functions_map: FunctionsMap,
|
||||
properties: Properties,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct FunctionsMap {
|
||||
#[serde(rename = "$functions")]
|
||||
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct FunctionRef {
|
||||
#[serde(rename = "$ref")]
|
||||
ref_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct Properties {
|
||||
#[serde(serialize_with = "serialize_function")]
|
||||
function: Vec<FunctionRef>,
|
||||
@ -767,7 +757,8 @@ pub(crate) struct FunctionDefinition {
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
pub parameters: serde_json::Value,
|
||||
#[serde(alias = "parameters")]
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::config::Config;
|
||||
/// HTTP Server logic
|
||||
use crate::health::Health;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
@ -15,7 +15,7 @@ use crate::{
|
||||
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
@ -29,7 +29,6 @@ use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
@ -766,6 +765,7 @@ async fn chat_completions(
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
let stop = req.stop.unwrap_or_default();
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
@ -783,47 +783,22 @@ async fn chat_completions(
|
||||
}
|
||||
};
|
||||
|
||||
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Tool choice not found in tool names".to_string(),
|
||||
error_type: "Tool not found".to_string(),
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
)
|
||||
})?
|
||||
.clone()]
|
||||
));
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
(func.name, func.parameters)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
let grammar = if let Some(tools) = &tool_grammar {
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
@ -834,7 +809,7 @@ async fn chat_completions(
|
||||
)
|
||||
})?;
|
||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||
Some(GrammarType::Json(serde_json::json!(tools)))
|
||||
Some(GrammarType::Json(serde_json::to_value(tools).unwrap()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -860,7 +835,7 @@ async fn chat_completions(
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: req.top_logprobs,
|
||||
grammar: tool_grammar.clone(),
|
||||
grammar,
|
||||
},
|
||||
};
|
||||
|
||||
@ -949,21 +924,23 @@ async fn chat_completions(
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name: "tools".to_string(),
|
||||
parameters: gen_text_value.get("function").map_or_else(
|
||||
|| {
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
name: gen_text_value
|
||||
.get("function")
|
||||
.and_then(|f| f.get("_name"))
|
||||
.and_then(|name| name.as_str())
|
||||
.unwrap_or("default_function_name")
|
||||
.to_string(),
|
||||
// Serialize the JSON object obtained from "function" to an escaped JSON string
|
||||
arguments: gen_text_value
|
||||
.get("function")
|
||||
.map(|f| {
|
||||
let mut f_cloned = f.clone();
|
||||
if let Value::Object(ref mut props) = f_cloned {
|
||||
props.remove("_name");
|
||||
}
|
||||
f_cloned
|
||||
})
|
||||
},
|
||||
|f| Ok(f.clone()),
|
||||
)?,
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
}];
|
||||
(Some(tool_calls), None)
|
||||
@ -1539,6 +1516,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
};
|
||||
|
||||
(
|
||||
|
Loading…
Reference in New Issue
Block a user