mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: prefer name completion over fill in middle
This commit is contained in:
parent
10621a30be
commit
2d79e7744a
@ -1,7 +1,7 @@
|
|||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::validation::{Validation, ValidationError};
|
use crate::validation::{Validation, ValidationError};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, Entry, FillInMiddleInputs, GenerateRequest, GenerateStreamResponse,
|
ChatTemplateInputs, Entry, CompletionTemplateInputs, GenerateRequest, GenerateStreamResponse,
|
||||||
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
|
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
|
||||||
};
|
};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
@ -33,8 +33,8 @@ pub struct Infer {
|
|||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
/// Chat template
|
/// Chat template
|
||||||
chat_template: Option<ChatTemplate>,
|
chat_template: Option<ChatTemplate>,
|
||||||
/// Fill in middle template
|
/// Completion template
|
||||||
fill_in_middle_template: Option<FillInMiddleTemplate>,
|
completion_template: Option<CompletionTemplate>,
|
||||||
/// Inference limit
|
/// Inference limit
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
}
|
}
|
||||||
@ -90,9 +90,9 @@ impl Infer {
|
|||||||
.chat_template
|
.chat_template
|
||||||
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||||
|
|
||||||
let fill_in_middle_template = tokenizer_config
|
let completion_template = tokenizer_config
|
||||||
.fill_in_middle_template
|
.completion_template
|
||||||
.map(FillInMiddleTemplate::new);
|
.map(CompletionTemplate::new);
|
||||||
|
|
||||||
// Inference limit with a semaphore
|
// Inference limit with a semaphore
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
@ -102,7 +102,7 @@ impl Infer {
|
|||||||
queue,
|
queue,
|
||||||
shared,
|
shared,
|
||||||
chat_template,
|
chat_template,
|
||||||
fill_in_middle_template,
|
completion_template,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -193,15 +193,15 @@ impl Infer {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply the fill in the middle template to the request
|
/// Apply the completion template to the request
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) fn apply_fill_in_middle_template(
|
pub(crate) fn apply_completion_template(
|
||||||
&self,
|
&self,
|
||||||
prompt: String,
|
prompt: String,
|
||||||
prefix: Option<String>,
|
prefix: Option<String>,
|
||||||
suffix: Option<String>,
|
suffix: Option<String>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
self.fill_in_middle_template
|
self.completion_template
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
.apply(prompt, prefix, suffix)
|
.apply(prompt, prefix, suffix)
|
||||||
@ -369,11 +369,11 @@ impl ChatTemplate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct FillInMiddleTemplate {
|
struct CompletionTemplate {
|
||||||
template: Template<'static, 'static>,
|
template: Template<'static, 'static>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FillInMiddleTemplate {
|
impl CompletionTemplate {
|
||||||
fn new(template: String) -> Self {
|
fn new(template: String) -> Self {
|
||||||
let mut env = Box::new(Environment::new());
|
let mut env = Box::new(Environment::new());
|
||||||
let template_str = template.into_boxed_str();
|
let template_str = template.into_boxed_str();
|
||||||
@ -393,7 +393,7 @@ impl FillInMiddleTemplate {
|
|||||||
suffix: Option<String>,
|
suffix: Option<String>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
self.template
|
self.template
|
||||||
.render(FillInMiddleInputs {
|
.render(CompletionTemplateInputs {
|
||||||
prefix: prefix.as_deref(),
|
prefix: prefix.as_deref(),
|
||||||
prompt: prompt.as_str(),
|
prompt: prompt.as_str(),
|
||||||
suffix: suffix.as_deref(),
|
suffix: suffix.as_deref(),
|
||||||
|
@ -51,7 +51,7 @@ pub struct HubModelInfo {
|
|||||||
#[derive(Clone, Deserialize, Default)]
|
#[derive(Clone, Deserialize, Default)]
|
||||||
pub struct HubTokenizerConfig {
|
pub struct HubTokenizerConfig {
|
||||||
pub chat_template: Option<String>,
|
pub chat_template: Option<String>,
|
||||||
pub fill_in_middle_template: Option<String>,
|
pub completion_template: Option<String>,
|
||||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||||
pub bos_token: Option<String>,
|
pub bos_token: Option<String>,
|
||||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||||
@ -617,7 +617,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
pub(crate) struct FillInMiddleInputs<'a> {
|
pub(crate) struct CompletionTemplateInputs<'a> {
|
||||||
prefix: Option<&'a str>,
|
prefix: Option<&'a str>,
|
||||||
prompt: &'a str,
|
prompt: &'a str,
|
||||||
suffix: Option<&'a str>,
|
suffix: Option<&'a str>,
|
||||||
|
@ -577,7 +577,7 @@ async fn completions(
|
|||||||
let stream = req.stream.unwrap_or_default();
|
let stream = req.stream.unwrap_or_default();
|
||||||
let seed = req.seed;
|
let seed = req.seed;
|
||||||
|
|
||||||
let inputs = match infer.apply_fill_in_middle_template(req.prompt, None, req.suffix) {
|
let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) {
|
||||||
Ok(inputs) => inputs,
|
Ok(inputs) => inputs,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
@ -1067,6 +1067,7 @@ pub async fn run(
|
|||||||
generate,
|
generate,
|
||||||
generate_stream,
|
generate_stream,
|
||||||
chat_completions,
|
chat_completions,
|
||||||
|
completions,
|
||||||
tokenize,
|
tokenize,
|
||||||
metrics,
|
metrics,
|
||||||
),
|
),
|
||||||
@ -1081,6 +1082,9 @@ pub async fn run(
|
|||||||
ChatCompletionDelta,
|
ChatCompletionDelta,
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
ChatCompletion,
|
ChatCompletion,
|
||||||
|
CompletionRequest,
|
||||||
|
CompletionComplete,
|
||||||
|
CompletionCompleteChunk,
|
||||||
GenerateParameters,
|
GenerateParameters,
|
||||||
PrefillToken,
|
PrefillToken,
|
||||||
Token,
|
Token,
|
||||||
|
Loading…
Reference in New Issue
Block a user