Mirror of https://github.com/huggingface/text-generation-inference.git
feat: deprecate suffix and completion template
commit 5f8526235a
parent de421dc53e
@@ -1,8 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, CompletionTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse,
-    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
+    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
+    Message, PrefillToken, Queue, Token,
 };
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
@@ -33,8 +33,6 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Chat template
     chat_template: Option<ChatTemplate>,
-    /// Completion template
-    completion_template: Option<CompletionTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
 }
@@ -90,10 +88,6 @@ impl Infer {
             .chat_template
             .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));

-        let completion_template = tokenizer_config
-            .completion_template
-            .map(CompletionTemplate::new);
-
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

@@ -102,7 +96,6 @@ impl Infer {
             queue,
             shared,
             chat_template,
-            completion_template,
             limit_concurrent_requests: semaphore,
         }
     }
@@ -193,24 +186,6 @@ impl Infer {
             })
     }

-    /// Apply the completion template to the request
-    #[instrument(skip_all)]
-    pub(crate) fn apply_completion_template(
-        &self,
-        prompt: String,
-        suffix: Option<String>,
-    ) -> Result<String, InferError> {
-        self.completion_template
-            .as_ref()
-            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(prompt, suffix)
-            .map_err(|e| {
-                metrics::increment_counter!("tgi_request_failure", "err" => "template");
-                tracing::error!("{e}");
-                e
-            })
-    }
-
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -367,34 +342,6 @@ impl ChatTemplate {
     }
 }

-#[derive(Clone)]
-struct CompletionTemplate {
-    template: Template<'static, 'static>,
-}
-
-impl CompletionTemplate {
-    fn new(template: String) -> Self {
-        let mut env = Box::new(Environment::new());
-        let template_str = template.into_boxed_str();
-        env.add_function("raise_exception", raise_exception);
-        // leaking env and template_str as read-only, static resources for performance.
-        let template = Box::leak(env)
-            .template_from_str(Box::leak(template_str))
-            .unwrap();
-
-        Self { template }
-    }
-
-    fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
-        self.template
-            .render(CompletionTemplateInputs {
-                prompt: prompt.as_str(),
-                suffix: suffix.as_deref(),
-            })
-            .map_err(InferError::TemplateError)
-    }
-}
-
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
@@ -618,12 +618,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }

-#[derive(Clone, Serialize, Deserialize)]
-pub(crate) struct CompletionTemplateInputs<'a> {
-    prompt: &'a str,
-    suffix: Option<&'a str>,
-}
-
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
@@ -579,24 +579,22 @@ async fn completions(
     let max_new_tokens = req.max_tokens.or(Some(100));
     let seed = req.seed;

-    let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
-        Ok(inputs) => inputs,
-        Err(err) => {
-            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-            tracing::error!("{err}");
-            return Err((
-                StatusCode::UNPROCESSABLE_ENTITY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                    error_type: err.error_type().to_string(),
-                }),
-            ));
-        }
-    };
+    // if suffix is present throw an error
+    if req.suffix.is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
+                    .to_string(),
+                error_type: "suffix not supported".to_string(),
+            }),
+        ));
+    }

     // build the request passing some parameters
     let generate_request = GenerateRequest {
-        inputs: inputs.to_string(),
+        inputs: req.prompt.to_string(),
         parameters: GenerateParameters {
             best_of: None,
             temperature: req.temperature,
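
The replacement error message tells callers to fold any suffix into the prompt themselves before hitting the completions route. Below is a minimal client-side sketch of that idea, assuming a Code Llama-style fill-in-the-middle token layout; the helper name and the exact infill tokens are illustrative assumptions and not part of this commit.

// Hypothetical client-side helper: combine prompt and optional suffix into a
// single input, since the router no longer renders a completion template.
// The "<PRE> ... <SUF> ... <MID>" layout is a Code Llama-style assumption;
// the right format depends on the model being served.
fn preprocess_prompt(prompt: String, suffix: Option<String>) -> String {
    match suffix {
        Some(suffix) => format!("<PRE> {prompt} <SUF>{suffix} <MID>"),
        None => prompt,
    }
}

fn main() {
    // Send `inputs` as the request's `prompt` field and omit `suffix` entirely.
    let inputs = preprocess_prompt(
        "fn add(a: i32, b: i32) -> i32 {".to_string(),
        Some("}".to_string()),
    );
    println!("{inputs}");
}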