feat: add --grammar-support cli flag and validation error

This commit is contained in:
drbh 2024-02-13 00:34:36 +00:00
parent 91a114a490
commit fe787d1361
4 changed files with 26 additions and 0 deletions

View File

@@ -382,6 +382,11 @@ struct Args {
#[clap(long, env)]
tokenizer_config_path: Option<String>,
/// Enable outlines grammar constrained generation
/// This is a feature that allows you to generate text that follows a specific grammar.
#[clap(long, env)]
grammar_support: bool,
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
@@ -1051,6 +1056,11 @@ fn spawn_webserver(
args.model_id,
];
// Grammar support
if args.grammar_support {
router_args.push("--grammar-support".to_string());
}
// Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string());

View File

@@ -75,6 +75,8 @@ struct Args {
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
grammar_support: bool,
}
#[tokio::main]
@@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
grammar_support,
} = args;
// Launch Tokio runtime
@@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge,
tokenizer_config,
messages_api_enabled,
grammar_support,
)
.await?;
Ok(())

View File

@@ -780,6 +780,7 @@ pub async fn run(
ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
messages_api_enabled: bool,
grammar_support: bool,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
@@ -841,6 +842,7 @@ pub async fn run(
max_top_n_tokens,
max_input_length,
max_total_tokens,
grammar_support,
);
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());

View File

@@ -19,6 +19,7 @@ pub struct Validation {
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
grammar_support: bool,
/// Channel to communicate with the background tokenization task
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}
@@ -32,6 +33,7 @@ impl Validation {
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
grammar_support: bool,
) -> Self {
// If we have a fast tokenizer
let sender = if let Some(tokenizer) = tokenizer {
@@ -66,6 +68,7 @@ impl Validation {
max_top_n_tokens,
max_input_length,
max_total_tokens,
grammar_support,
}
}
@@ -293,6 +296,11 @@ impl Validation {
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
// Ensure that grammar is not set if it's not supported
if !grammar.is_empty() && !self.grammar_support {
return Err(ValidationError::Grammar);
}
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
@@ -455,6 +463,8 @@ pub enum ValidationError {
StopSequence(usize, usize),
#[error("tokenizer error {0}")]
Tokenizer(String),
#[error("grammar is not supported")]
Grammar,
}
#[cfg(test)]