Mirror of https://github.com/huggingface/text-generation-inference.git
feat: support fill in the middle templates
commit 6a0f737cb6
parent 0fb864ef44
@@ -1,8 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
-    Message, PrefillToken, Queue, Token,
+    ChatTemplateInputs, Entry, FillInMiddleInputs, GenerateRequest, GenerateStreamResponse,
+    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
 };
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
@@ -33,6 +33,8 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Chat template
     chat_template: Option<ChatTemplate>,
+    /// Fill in middle template
+    fill_in_middle_template: Option<FillInMiddleTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
 }
@@ -88,6 +90,10 @@ impl Infer {
             .chat_template
             .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
 
+        let fill_in_middle_template = tokenizer_config
+            .fill_in_middle_template
+            .map(FillInMiddleTemplate::new);
+
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
@@ -96,6 +102,7 @@ impl Infer {
             queue,
             shared,
             chat_template,
+            fill_in_middle_template,
             limit_concurrent_requests: semaphore,
         }
     }
@@ -186,6 +193,25 @@ impl Infer {
         })
     }
 
+    /// Apply the fill in the middle template to the request
+    #[instrument(skip_all)]
+    pub(crate) fn apply_fill_in_middle_template(
+        &self,
+        prompt: String,
+        prefix: Option<String>,
+        suffix: Option<String>,
+    ) -> Result<String, InferError> {
+        self.fill_in_middle_template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .apply(prompt, prefix, suffix)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                e
+            })
+    }
+
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -342,6 +368,40 @@ impl ChatTemplate {
     }
 }
 
+#[derive(Clone)]
+struct FillInMiddleTemplate {
+    template: Template<'static, 'static>,
+}
+
+impl FillInMiddleTemplate {
+    fn new(template: String) -> Self {
+        let mut env = Box::new(Environment::new());
+        let template_str = template.into_boxed_str();
+        env.add_function("raise_exception", raise_exception);
+        // leaking env and template_str as read-only, static resources for performance.
+        let template = Box::leak(env)
+            .template_from_str(Box::leak(template_str))
+            .unwrap();
+
+        Self { template }
+    }
+
+    fn apply(
+        &self,
+        prompt: String,
+        prefix: Option<String>,
+        suffix: Option<String>,
+    ) -> Result<String, InferError> {
+        self.template
+            .render(FillInMiddleInputs {
+                prefix: prefix.as_deref(),
+                prompt: prompt.as_str(),
+                suffix: suffix.as_deref(),
+            })
+            .map_err(InferError::TemplateError)
+    }
+}
+
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
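
For illustration, a minimal sketch of how such a fill-in-the-middle template renders with minijinja outside the router. The sentinel tokens and the template string below are hypothetical stand-ins (each model ships its own template); the context keys mirror FillInMiddleInputs.

use minijinja::{context, Environment};

fn main() {
    // Hypothetical FIM template with CodeLlama-style sentinel tokens; the real
    // template string comes from the model's tokenizer config, not this commit.
    let template = "<PRE>{% if prefix %}{{ prefix }}{% endif %}{{ prompt }}<SUF>{% if suffix %}{{ suffix }}{% endif %}<MID>";

    let env = Environment::new();
    let rendered = env
        .render_str(
            template,
            // Mirrors FillInMiddleInputs: `prefix` is omitted (None in the router),
            // `prompt` is the text before the cursor, `suffix` the text after it.
            context! {
                prompt => "def add(a, b): ",
                suffix => " print(add(1, 2))",
            },
        )
        .unwrap();

    println!("{rendered}");
    // <PRE>def add(a, b): <SUF> print(add(1, 2))<MID>
}
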
@@ -51,6 +51,7 @@ pub struct HubModelInfo {
 #[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
+    pub fill_in_middle_template: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
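
A minimal sketch of how the new optional key would be picked up from a model's tokenizer_config.json. The JSON fragment and the stand-in struct below are illustrative only; the real HubTokenizerConfig also carries chat_template, bos_token, and eos_token with their custom deserializers.

use serde::Deserialize;

// Illustrative stand-in for the relevant slice of HubTokenizerConfig.
#[derive(Deserialize, Debug)]
struct FimConfig {
    fill_in_middle_template: Option<String>,
}

fn main() {
    // Hypothetical tokenizer_config.json fragment; the template string is model-specific.
    let raw = r#"{ "fill_in_middle_template": "<PRE>{{ prompt }}<SUF>{{ suffix }}<MID>" }"#;
    let cfg: FimConfig = serde_json::from_str(raw).unwrap();
    assert!(cfg.fill_in_middle_template.is_some());

    // A config without the key deserializes to None, so FIM support stays
    // optional per model.
    let cfg_missing: FimConfig = serde_json::from_str("{}").unwrap();
    assert!(cfg_missing.fill_in_middle_template.is_none());
}
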
@@ -615,6 +616,13 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }
 
+#[derive(Clone, Serialize, Deserialize)]
+pub(crate) struct FillInMiddleInputs<'a> {
+    prefix: Option<&'a str>,
+    prompt: &'a str,
+    suffix: Option<&'a str>,
+}
+
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
@@ -570,15 +570,28 @@ async fn completions(
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
 
     let max_new_tokens = req.max_tokens.or(Some(100));
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
-    let suffix = req.suffix.unwrap_or_default();
 
+    let inputs = match infer.apply_fill_in_middle_template(req.prompt, None, req.suffix) {
+        Ok(inputs) => inputs,
+        Err(err) => {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            return Err((
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                    error_type: err.error_type().to_string(),
+                }),
+            ));
+        }
+    };
+
     // build the request passing some parameters
     let generate_request = GenerateRequest {
-        inputs: req.prompt.to_string(),
+        inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
             temperature: req.temperature,
@@ -685,7 +698,7 @@ async fn completions(
             finish_reason: details.finish_reason.to_string(),
             index: 0,
             logprobs: None,
-            text: generation.generated_text + &suffix,
+            text: generation.generated_text,
         }],
         usage: Usage {
             prompt_tokens: details.prefill.len() as u32,
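
End to end, a completions request carrying a suffix would now be rendered through the FIM template before generation instead of having the suffix appended to the output text. A hedged client sketch follows, assuming a local TGI instance, an OpenAI-style /v1/completions route, and reqwest's blocking client with the json feature; the URL, model name, and field set are assumptions, not part of this commit.

use serde_json::{json, Value};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Prompt is the text before the cursor, suffix the text after it.
    let body = json!({
        "model": "tgi",
        "prompt": "def add(a, b):",
        "suffix": "print(add(1, 2))",
        "max_tokens": 32
    });

    let resp: Value = reqwest::blocking::Client::new()
        .post("http://localhost:8080/v1/completions")
        .json(&body)
        .send()?
        .error_for_status()?
        .json()?;

    // With a fill_in_middle_template configured for the model, the router renders
    // prompt and suffix into a single FIM input before generation.
    println!("{}", serde_json::to_string_pretty(&resp)?);
    Ok(())
}
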