Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: add --grammar-support cli flag and validation error
commit fe787d1361 (parent 91a114a490)
@@ -382,6 +382,11 @@ struct Args {
     #[clap(long, env)]
     tokenizer_config_path: Option<String>,
 
+    /// Enable outlines grammar constrained generation
+    /// This is a feature that allows you to generate text that follows a specific grammar.
+    #[clap(long, env)]
+    grammar_support: bool,
+
     /// Display a lot of information about your runtime environment
     #[clap(long, short, action)]
     env: bool,
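For context (not part of the commit): with clap's derive API, a `bool` field annotated `#[clap(long, env)]` becomes an opt-in switch, so the launcher gains a `--grammar-support` flag and, by clap's naming convention, a `GRAMMAR_SUPPORT` environment variable, both defaulting to false. A minimal self-contained sketch of that pattern, with an illustrative `main`:

// Sketch of the clap pattern used in the hunk above; requires clap with
// the "derive" and "env" features. The field and doc comment mirror the diff.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Enable outlines grammar constrained generation
    #[clap(long, env)]
    grammar_support: bool,
}

fn main() {
    // Passing `--grammar-support` (or setting GRAMMAR_SUPPORT=true in the
    // environment) flips this to true; omitting both leaves it false.
    let args = Args::parse();
    println!("grammar support: {}", args.grammar_support);
}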
@@ -1051,6 +1056,11 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    // Grammar support
+    if args.grammar_support {
+        router_args.push("--grammar-support".to_string());
+    }
+
     // Tokenizer config path
     if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
         router_args.push("--tokenizer-config-path".to_string());
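This launcher change follows the file's existing pattern: boolean launcher flags are forwarded to the router subprocess by pushing the bare flag onto its argument vector. A standalone sketch of that forwarding; the binary name and the extra `--port` argument here are illustrative assumptions, not taken from the diff:

// Illustrative sketch of the launcher-to-router forwarding pattern.
use std::process::Command;

fn main() {
    let grammar_support = true; // in the launcher this comes from the parsed Args

    let mut router_args: Vec<String> = vec!["--port".to_string(), "3000".to_string()];
    if grammar_support {
        // A bool flag is forwarded by presence alone; no value is attached.
        router_args.push("--grammar-support".to_string());
    }

    // The launcher then spawns the router with the accumulated arguments.
    let status = Command::new("text-generation-router")
        .args(&router_args)
        .status()
        .expect("failed to spawn the router");
    println!("router exited with {status}");
}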
@@ -75,6 +75,8 @@ struct Args {
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
     messages_api_enabled: bool,
+    #[clap(long, env, default_value_t = false)]
+    grammar_support: bool,
 }
 
 #[tokio::main]
@@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         messages_api_enabled,
+        grammar_support,
     } = args;
 
     // Launch Tokio runtime
@@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_edge,
         tokenizer_config,
         messages_api_enabled,
+        grammar_support,
     )
     .await?;
     Ok(())
@@ -780,6 +780,7 @@ pub async fn run(
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
     messages_api_enabled: bool,
+    grammar_support: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -841,6 +842,7 @@ pub async fn run(
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        grammar_support,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
     let health_ext = Health::new(client.clone(), generation_health.clone());
@@ -19,6 +19,7 @@ pub struct Validation {
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
+    grammar_support: bool,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }
@@ -32,6 +33,7 @@ impl Validation {
         max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
+        grammar_support: bool,
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
@@ -66,6 +68,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            grammar_support,
         }
     }
 
@@ -293,6 +296,11 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
+        // Ensure that grammar is not set if it's not supported
+        if !grammar.is_empty() && !self.grammar_support {
+            return Err(ValidationError::Grammar);
+        }
+
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
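The validation gate itself is small: a request that supplies a grammar is rejected up front unless the server was launched with grammar support. A self-contained sketch of that logic; `check_grammar` and the trimmed-down types are hypothetical, since in the diff the check is inlined into the larger validate path:

// Standalone sketch of the validation gate, under assumed simplified types.
#[derive(Debug)]
enum ValidationError {
    Grammar,
}

struct Validation {
    grammar_support: bool,
}

impl Validation {
    fn check_grammar(&self, grammar: &str) -> Result<(), ValidationError> {
        // As in the diff above, an empty string means "no grammar requested".
        if !grammar.is_empty() && !self.grammar_support {
            return Err(ValidationError::Grammar);
        }
        Ok(())
    }
}

fn main() {
    let validation = Validation { grammar_support: false };
    assert!(validation.check_grammar("").is_ok()); // no grammar: always fine
    assert!(validation.check_grammar("root ::= \"yes\" | \"no\"").is_err());
}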
@@ -455,6 +463,8 @@ pub enum ValidationError {
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
+    #[error("grammar is not supported")]
+    Grammar,
 }
 
 #[cfg(test)]
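The new variant follows the enum's existing thiserror pattern: the `#[error("grammar is not supported")]` attribute derives the `Display` message that the router ultimately returns to the client. A trimmed sketch of that pattern; the real enum has many more variants:

// Minimal sketch of the thiserror pattern used by ValidationError;
// requires the thiserror crate.
use thiserror::Error;

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("tokenizer error {0}")]
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
}

fn main() {
    // #[error(...)] derives Display, so this prints: grammar is not supported
    println!("{}", ValidationError::Grammar);
}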