mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: add --grammar-support cli flag and validation error
This commit is contained in:
parent
91a114a490
commit
fe787d1361
@ -382,6 +382,11 @@ struct Args {
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
|
||||
/// Enable outlines grammar constrained generation
|
||||
/// This is a feature that allows you to generate text that follows a specific grammar.
|
||||
#[clap(long, env)]
|
||||
grammar_support: bool,
|
||||
|
||||
/// Display a lot of information about your runtime environment
|
||||
#[clap(long, short, action)]
|
||||
env: bool,
|
||||
@ -1051,6 +1056,11 @@ fn spawn_webserver(
|
||||
args.model_id,
|
||||
];
|
||||
|
||||
// Grammar support
|
||||
if args.grammar_support {
|
||||
router_args.push("--grammar-support".to_string());
|
||||
}
|
||||
|
||||
// Tokenizer config path
|
||||
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
|
||||
router_args.push("--tokenizer-config-path".to_string());
|
||||
|
@ -75,6 +75,8 @@ struct Args {
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
grammar_support: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
grammar_support,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
ngrok_edge,
|
||||
tokenizer_config,
|
||||
messages_api_enabled,
|
||||
grammar_support,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
@ -780,6 +780,7 @@ pub async fn run(
|
||||
ngrok_edge: Option<String>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
messages_api_enabled: bool,
|
||||
grammar_support: bool,
|
||||
) -> Result<(), axum::BoxError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -841,6 +842,7 @@ pub async fn run(
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
grammar_support,
|
||||
);
|
||||
let generation_health = Arc::new(AtomicBool::new(false));
|
||||
let health_ext = Health::new(client.clone(), generation_health.clone());
|
||||
|
@ -19,6 +19,7 @@ pub struct Validation {
|
||||
max_top_n_tokens: u32,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
grammar_support: bool,
|
||||
/// Channel to communicate with the background tokenization task
|
||||
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
|
||||
}
|
||||
@ -32,6 +33,7 @@ impl Validation {
|
||||
max_top_n_tokens: u32,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
grammar_support: bool,
|
||||
) -> Self {
|
||||
// If we have a fast tokenizer
|
||||
let sender = if let Some(tokenizer) = tokenizer {
|
||||
@ -66,6 +68,7 @@ impl Validation {
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
grammar_support,
|
||||
}
|
||||
}
|
||||
|
||||
@ -293,6 +296,11 @@ impl Validation {
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
// Ensure that grammar is not set if it's not supported
|
||||
if !grammar.is_empty() && !self.grammar_support {
|
||||
return Err(ValidationError::Grammar);
|
||||
}
|
||||
|
||||
let parameters = NextTokenChooserParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
@ -455,6 +463,8 @@ pub enum ValidationError {
|
||||
StopSequence(usize, usize),
|
||||
#[error("tokenizer error {0}")]
|
||||
Tokenizer(String),
|
||||
#[error("grammar is not supported")]
|
||||
Grammar,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
Loading…
Reference in New Issue
Block a user