feat: implement a templated endpoint for visibility into chat requests

This commit is contained in:
drbh 2024-07-30 13:06:52 +00:00
parent 7451041ecd
commit 62d7be3727

View File

@ -115,6 +115,87 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0) Json(info.0)
} }
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/templated",
request_body = ChatRequest,
responses((status = 200, description = "Templated Chat Request", body = Value))
)]
async fn get_templated(
Extension(infer): Extension<Infer>,
Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
messages,
response_format,
tools,
tool_choice,
tool_prompt,
..
} = req;
if response_format.is_some() && tools.is_some() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Grammar and tools are mutually exclusive".to_string(),
error_type: "validation".to_string(),
}),
));
}
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{}", err);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
let tools_grammar_prompt = tool_grammar.as_ref().map(|t| {
(
GrammarType::Json(serde_json::json!(t)),
tool_prompt.unwrap_or_default(),
)
});
let (tools_grammar_prompt, _grammar) = response_format
.map(|rf| (None, Some(rf)))
.unwrap_or_else(|| {
(
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(g, _)| g),
)
});
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{}", err);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
Ok((HeaderMap::new(), Json(inputs)).into_response())
}
#[utoipa::path( #[utoipa::path(
get, get,
tag = "Text Generation Inference", tag = "Text Generation Inference",
@ -2036,6 +2117,7 @@ async fn start(
} }
let info_routes = Router::new() let info_routes = Router::new()
.route("/", get(health)) .route("/", get(health))
.route("/templated", post(get_templated))
.route("/info", get(get_model_info)) .route("/info", get(get_model_info))
.route("/health", get(health)) .route("/health", get(health))
.route("/ping", get(health)) .route("/ping", get(health))