feat: deprecate suffix and completion template

drbh 2024-02-26 18:22:28 +00:00
parent de421dc53e
commit 5f8526235a
3 changed files with 15 additions and 76 deletions

View File

@@ -1,8 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, CompletionTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse,
-    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
+    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
+    Message, PrefillToken, Queue, Token,
 };
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
@@ -33,8 +33,6 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Chat template
     chat_template: Option<ChatTemplate>,
-    /// Completion template
-    completion_template: Option<CompletionTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
 }
@@ -90,10 +88,6 @@ impl Infer {
             .chat_template
             .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));

-        let completion_template = tokenizer_config
-            .completion_template
-            .map(CompletionTemplate::new);
-
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@@ -102,7 +96,6 @@ impl Infer {
             queue,
             shared,
             chat_template,
-            completion_template,
             limit_concurrent_requests: semaphore,
         }
     }
@@ -193,24 +186,6 @@ impl Infer {
             })
     }

-    /// Apply the completion template to the request
-    #[instrument(skip_all)]
-    pub(crate) fn apply_completion_template(
-        &self,
-        prompt: String,
-        suffix: Option<String>,
-    ) -> Result<String, InferError> {
-        self.completion_template
-            .as_ref()
-            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(prompt, suffix)
-            .map_err(|e| {
-                metrics::increment_counter!("tgi_request_failure", "err" => "template");
-                tracing::error!("{e}");
-                e
-            })
-    }
-
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -367,34 +342,6 @@ impl ChatTemplate {
     }
 }

-#[derive(Clone)]
-struct CompletionTemplate {
-    template: Template<'static, 'static>,
-}
-
-impl CompletionTemplate {
-    fn new(template: String) -> Self {
-        let mut env = Box::new(Environment::new());
-        let template_str = template.into_boxed_str();
-        env.add_function("raise_exception", raise_exception);
-        // leaking env and template_str as read-only, static resources for performance.
-        let template = Box::leak(env)
-            .template_from_str(Box::leak(template_str))
-            .unwrap();
-        Self { template }
-    }
-
-    fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
-        self.template
-            .render(CompletionTemplateInputs {
-                prompt: prompt.as_str(),
-                suffix: suffix.as_deref(),
-            })
-            .map_err(InferError::TemplateError)
-    }
-}
-
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
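
Note: the removed CompletionTemplate compiled its minijinja template once and
kept it for the life of the process by leaking the environment and the
template source to obtain 'static references. A minimal sketch of that
pattern, mirroring the deleted code (the function name is illustrative):

    use minijinja::{Environment, Template};

    fn compile_static_template(source: String) -> Template<'static, 'static> {
        let env = Box::new(Environment::new());
        // Box::leak trades a one-time allocation for a &'static borrow, so
        // the compiled template can be stored in a long-lived struct; the
        // memory is intentionally never freed.
        let source: &'static str = Box::leak(source.into_boxed_str());
        Box::leak(env).template_from_str(source).unwrap()
    }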

View File

@@ -618,12 +618,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }

-#[derive(Clone, Serialize, Deserialize)]
-pub(crate) struct CompletionTemplateInputs<'a> {
-    prompt: &'a str,
-    suffix: Option<&'a str>,
-}
-
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]

View File

@@ -579,24 +579,22 @@ async fn completions(
     let max_new_tokens = req.max_tokens.or(Some(100));
     let seed = req.seed;

-    let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
-        Ok(inputs) => inputs,
-        Err(err) => {
-            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-            tracing::error!("{err}");
-            return Err((
-                StatusCode::UNPROCESSABLE_ENTITY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                    error_type: err.error_type().to_string(),
-                }),
-            ));
-        }
-    };
+    // if suffix is present throw an error
+    if req.suffix.is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
+                    .to_string(),
+                error_type: "suffix not supported".to_string(),
+            }),
+        ));
+    }

     // build the request passing some parameters
     let generate_request = GenerateRequest {
-        inputs: inputs.to_string(),
+        inputs: req.prompt.to_string(),
         parameters: GenerateParameters {
             best_of: None,
             temperature: req.temperature,
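
Note: the new error message points callers at preprocessing the prompt instead
of passing `suffix`. A hypothetical client-side helper, assuming a
fill-in-the-middle model; the `<PRE>`/`<SUF>`/`<MID>` markers below follow the
Code Llama infilling format and differ per model:

    /// Fold an OpenAI-style `suffix` into the prompt before calling the API.
    /// The infill markers are model-specific, not part of the TGI API.
    fn preprocess_prompt(prompt: &str, suffix: Option<&str>) -> String {
        match suffix {
            // Ask the model to generate the span between prompt and suffix.
            Some(suffix) => format!("<PRE> {prompt} <SUF>{suffix} <MID>"),
            // No suffix: send the prompt through unchanged.
            None => prompt.to_string(),
        }
    }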