mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve grammar to include name and add tests
This commit is contained in:
parent
c38a7d7ddd
commit
4930de857d
@ -399,33 +399,23 @@ impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
|
|||||||
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
|
impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
|
||||||
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
|
fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
|
||||||
let (tokens, top_tokens) = value;
|
let (tokens, top_tokens) = value;
|
||||||
|
Self {
|
||||||
// Create an iterator that produces None for top_tokens once it's exhausted
|
content: tokens
|
||||||
let top_tokens_iter = top_tokens
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(Some)
|
.zip(top_tokens)
|
||||||
.chain(std::iter::repeat(None));
|
.map(|(t, top_t)| ChatCompletionLogprob {
|
||||||
|
|
||||||
let content = tokens
|
|
||||||
.into_iter()
|
|
||||||
.zip(top_tokens_iter)
|
|
||||||
.map(|(t, top_t_option)| ChatCompletionLogprob {
|
|
||||||
token: t.text,
|
token: t.text,
|
||||||
logprob: t.logprob,
|
logprob: t.logprob,
|
||||||
top_logprobs: match top_t_option {
|
top_logprobs: top_t
|
||||||
Some(top_t) => top_t
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|t| ChatCompletionTopLogprob {
|
.map(|t| ChatCompletionTopLogprob {
|
||||||
token: t.text,
|
token: t.text,
|
||||||
logprob: t.logprob,
|
logprob: t.logprob,
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
None => vec![], // Handle the case where there are no top tokens
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect(),
|
||||||
|
}
|
||||||
Self { content }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -727,26 +717,26 @@ mod deserialize_tool_choice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
||||||
pub struct Tools {
|
pub struct Tools {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
functions_map: FunctionsMap,
|
functions_map: FunctionsMap,
|
||||||
properties: Properties,
|
properties: Properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
struct FunctionsMap {
|
struct FunctionsMap {
|
||||||
#[serde(rename = "$functions")]
|
#[serde(rename = "$functions")]
|
||||||
functions: std::collections::HashMap<String, serde_json::Value>,
|
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
struct FunctionRef {
|
struct FunctionRef {
|
||||||
#[serde(rename = "$ref")]
|
#[serde(rename = "$ref")]
|
||||||
ref_path: String,
|
ref_path: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
struct Properties {
|
struct Properties {
|
||||||
#[serde(serialize_with = "serialize_function")]
|
#[serde(serialize_with = "serialize_function")]
|
||||||
function: Vec<FunctionRef>,
|
function: Vec<FunctionRef>,
|
||||||
@ -767,7 +757,8 @@ pub(crate) struct FunctionDefinition {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub parameters: serde_json::Value,
|
#[serde(alias = "parameters")]
|
||||||
|
pub arguments: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::health::Health;
|
use crate::health::Health;
|
||||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||||
@ -15,7 +15,7 @@ use crate::{
|
|||||||
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
||||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
use crate::{FunctionDefinition, ToolCall, ToolType};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
@ -29,7 +29,6 @@ use futures::Stream;
|
|||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::atomic::AtomicBool;
|
use std::sync::atomic::AtomicBool;
|
||||||
@ -766,6 +765,7 @@ async fn chat_completions(
|
|||||||
let logprobs = req.logprobs.unwrap_or(false);
|
let logprobs = req.logprobs.unwrap_or(false);
|
||||||
let seed = req.seed;
|
let seed = req.seed;
|
||||||
let stop = req.stop.unwrap_or_default();
|
let stop = req.stop.unwrap_or_default();
|
||||||
|
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||||
|
|
||||||
// apply chat template to flatten the request into a single input
|
// apply chat template to flatten the request into a single input
|
||||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||||
@ -783,47 +783,22 @@ async fn chat_completions(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) {
|
||||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
Ok(grammar) => grammar,
|
||||||
let tools_to_use = match tool_choice {
|
Err(err) => {
|
||||||
ToolType::FunctionName(name) => {
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
vec![req_tools
|
tracing::error!("{err}");
|
||||||
.iter()
|
return Err((
|
||||||
.find(|tool| tool.function.name == *name)
|
|
||||||
.ok_or_else(|| {
|
|
||||||
(
|
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
Json(ErrorResponse {
|
Json(ErrorResponse {
|
||||||
error: "Tool choice not found in tool names".to_string(),
|
error: err.to_string(),
|
||||||
error_type: "Tool not found".to_string(),
|
error_type: err.error_type().to_string(),
|
||||||
}),
|
}),
|
||||||
)
|
));
|
||||||
})?
|
|
||||||
.clone()]
|
|
||||||
}
|
}
|
||||||
ToolType::OneOf => req_tools.to_owned(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let functions: HashMap<String, Value> = tools_to_use
|
|
||||||
.iter()
|
|
||||||
.map(|tool| {
|
|
||||||
let func = tool.function.clone();
|
|
||||||
(func.name, func.parameters)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let tools = Tools {
|
|
||||||
functions_map: FunctionsMap { functions },
|
|
||||||
properties: Properties {
|
|
||||||
function: tools_to_use
|
|
||||||
.iter()
|
|
||||||
.map(|tool| FunctionRef {
|
|
||||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let grammar = if let Some(tools) = &tool_grammar {
|
||||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||||
(
|
(
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
@ -834,7 +809,7 @@ async fn chat_completions(
|
|||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||||
Some(GrammarType::Json(serde_json::json!(tools)))
|
Some(GrammarType::Json(serde_json::to_value(tools).unwrap()))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@ -860,7 +835,7 @@ async fn chat_completions(
|
|||||||
decoder_input_details: !stream,
|
decoder_input_details: !stream,
|
||||||
seed,
|
seed,
|
||||||
top_n_tokens: req.top_logprobs,
|
top_n_tokens: req.top_logprobs,
|
||||||
grammar: tool_grammar.clone(),
|
grammar,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -949,21 +924,23 @@ async fn chat_completions(
|
|||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: FunctionDefinition {
|
function: FunctionDefinition {
|
||||||
description: None,
|
description: None,
|
||||||
name: "tools".to_string(),
|
name: gen_text_value
|
||||||
parameters: gen_text_value.get("function").map_or_else(
|
.get("function")
|
||||||
|| {
|
.and_then(|f| f.get("_name"))
|
||||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
.and_then(|name| name.as_str())
|
||||||
(
|
.unwrap_or("default_function_name")
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
.to_string(),
|
||||||
Json(ErrorResponse {
|
// Serialize the JSON object obtained from "function" to an escaped JSON string
|
||||||
error: e.to_string(),
|
arguments: gen_text_value
|
||||||
error_type: "Input validation error".to_string(),
|
.get("function")
|
||||||
}),
|
.map(|f| {
|
||||||
)
|
let mut f_cloned = f.clone();
|
||||||
|
if let Value::Object(ref mut props) = f_cloned {
|
||||||
|
props.remove("_name");
|
||||||
|
}
|
||||||
|
f_cloned
|
||||||
})
|
})
|
||||||
},
|
.unwrap_or_default(),
|
||||||
|f| Ok(f.clone()),
|
|
||||||
)?,
|
|
||||||
},
|
},
|
||||||
}];
|
}];
|
||||||
(Some(tool_calls), None)
|
(Some(tool_calls), None)
|
||||||
@ -1539,6 +1516,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||||||
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
};
|
};
|
||||||
|
|
||||||
(
|
(
|
||||||
|
Loading…
Reference in New Issue
Block a user