fix: prefer name completion over fill in middle

This commit is contained in:
drbh 2024-02-20 14:42:31 -05:00
parent 10621a30be
commit 2d79e7744a
3 changed files with 20 additions and 16 deletions

View File

@@ -1,7 +1,7 @@
/// Batching and inference logic /// Batching and inference logic
use crate::validation::{Validation, ValidationError}; use crate::validation::{Validation, ValidationError};
use crate::{ use crate::{
ChatTemplateInputs, Entry, FillInMiddleInputs, GenerateRequest, GenerateStreamResponse, ChatTemplateInputs, Entry, CompletionTemplateInputs, GenerateRequest, GenerateStreamResponse,
HubTokenizerConfig, Message, PrefillToken, Queue, Token, HubTokenizerConfig, Message, PrefillToken, Queue, Token,
}; };
use futures::future::try_join_all; use futures::future::try_join_all;
@@ -33,8 +33,8 @@ pub struct Infer {
shared: Arc<Shared>, shared: Arc<Shared>,
/// Chat template /// Chat template
chat_template: Option<ChatTemplate>, chat_template: Option<ChatTemplate>,
/// Fill in middle template /// Completion template
fill_in_middle_template: Option<FillInMiddleTemplate>, completion_template: Option<CompletionTemplate>,
/// Inference limit /// Inference limit
limit_concurrent_requests: Arc<Semaphore>, limit_concurrent_requests: Arc<Semaphore>,
} }
@@ -90,9 +90,9 @@ impl Infer {
.chat_template .chat_template
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)); .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
let fill_in_middle_template = tokenizer_config let completion_template = tokenizer_config
.fill_in_middle_template .completion_template
.map(FillInMiddleTemplate::new); .map(CompletionTemplate::new);
// Inference limit with a semaphore // Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@@ -102,7 +102,7 @@ impl Infer {
queue, queue,
shared, shared,
chat_template, chat_template,
fill_in_middle_template, completion_template,
limit_concurrent_requests: semaphore, limit_concurrent_requests: semaphore,
} }
} }
@@ -193,15 +193,15 @@ impl Infer {
}) })
} }
/// Apply the fill in the middle template to the request /// Apply the completion template to the request
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) fn apply_fill_in_middle_template( pub(crate) fn apply_completion_template(
&self, &self,
prompt: String, prompt: String,
prefix: Option<String>, prefix: Option<String>,
suffix: Option<String>, suffix: Option<String>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
self.fill_in_middle_template self.completion_template
.as_ref() .as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(prompt, prefix, suffix) .apply(prompt, prefix, suffix)
@@ -369,11 +369,11 @@ impl ChatTemplate {
} }
#[derive(Clone)] #[derive(Clone)]
struct FillInMiddleTemplate { struct CompletionTemplate {
template: Template<'static, 'static>, template: Template<'static, 'static>,
} }
impl FillInMiddleTemplate { impl CompletionTemplate {
fn new(template: String) -> Self { fn new(template: String) -> Self {
let mut env = Box::new(Environment::new()); let mut env = Box::new(Environment::new());
let template_str = template.into_boxed_str(); let template_str = template.into_boxed_str();
@@ -393,7 +393,7 @@ impl FillInMiddleTemplate {
suffix: Option<String>, suffix: Option<String>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
self.template self.template
.render(FillInMiddleInputs { .render(CompletionTemplateInputs {
prefix: prefix.as_deref(), prefix: prefix.as_deref(),
prompt: prompt.as_str(), prompt: prompt.as_str(),
suffix: suffix.as_deref(), suffix: suffix.as_deref(),

View File

@@ -51,7 +51,7 @@ pub struct HubModelInfo {
#[derive(Clone, Deserialize, Default)] #[derive(Clone, Deserialize, Default)]
pub struct HubTokenizerConfig { pub struct HubTokenizerConfig {
pub chat_template: Option<String>, pub chat_template: Option<String>,
pub fill_in_middle_template: Option<String>, pub completion_template: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")] #[serde(deserialize_with = "token_serde::deserialize")]
pub bos_token: Option<String>, pub bos_token: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")] #[serde(deserialize_with = "token_serde::deserialize")]
@@ -617,7 +617,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub(crate) struct FillInMiddleInputs<'a> { pub(crate) struct CompletionTemplateInputs<'a> {
prefix: Option<&'a str>, prefix: Option<&'a str>,
prompt: &'a str, prompt: &'a str,
suffix: Option<&'a str>, suffix: Option<&'a str>,

View File

@@ -577,7 +577,7 @@ async fn completions(
let stream = req.stream.unwrap_or_default(); let stream = req.stream.unwrap_or_default();
let seed = req.seed; let seed = req.seed;
let inputs = match infer.apply_fill_in_middle_template(req.prompt, None, req.suffix) { let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1067,6 +1067,7 @@ pub async fn run(
generate, generate,
generate_stream, generate_stream,
chat_completions, chat_completions,
completions,
tokenize, tokenize,
metrics, metrics,
), ),
@@ -1081,6 +1082,9 @@ pub async fn run(
ChatCompletionDelta, ChatCompletionDelta,
ChatCompletionChunk, ChatCompletionChunk,
ChatCompletion, ChatCompletion,
CompletionRequest,
CompletionComplete,
CompletionCompleteChunk,
GenerateParameters, GenerateParameters,
PrefillToken, PrefillToken,
Token, Token,