diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 8367ef81..de2b3f64 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -382,6 +382,11 @@ struct Args {
     #[clap(long, env)]
     tokenizer_config_path: Option<String>,
 
+    /// Enable outlines grammar constrained generation
+    /// This is a feature that allows you to generate text that follows a specific grammar.
+    #[clap(long, env)]
+    grammar_support: bool,
+
     /// Display a lot of information about your runtime environment
     #[clap(long, short, action)]
     env: bool,
@@ -1051,6 +1056,11 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    // Grammar support
+    if args.grammar_support {
+        router_args.push("--grammar-support".to_string());
+    }
+
     // Tokenizer config path
     if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
         router_args.push("--tokenizer-config-path".to_string());
diff --git a/router/src/main.rs b/router/src/main.rs
index a1f8d97b..6bd86d58 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -75,6 +75,8 @@ struct Args {
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
     messages_api_enabled: bool,
+    #[clap(long, env, default_value_t = false)]
+    grammar_support: bool,
 }
 
 #[tokio::main]
@@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         messages_api_enabled,
+        grammar_support,
     } = args;
 
     // Launch Tokio runtime
@@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_edge,
         tokenizer_config,
         messages_api_enabled,
+        grammar_support,
     )
     .await?;
     Ok(())
diff --git a/router/src/server.rs b/router/src/server.rs
index 6e042c4e..bcf17f46 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -780,6 +780,7 @@ pub async fn run(
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
     messages_api_enabled: bool,
+    grammar_support: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -841,6 +842,7 @@ pub async fn run(
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        grammar_support,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
     let health_ext = Health::new(client.clone(), generation_health.clone());
diff --git a/router/src/validation.rs b/router/src/validation.rs
index a77995df..c97e878a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -19,6 +19,7 @@ pub struct Validation {
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
+    grammar_support: bool,
     /// Channel to communicate with the background tokenization task
     sender: Option<flume::Sender<TokenizerRequest>>,
 }
@@ -32,6 +33,7 @@ impl Validation {
         max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
+        grammar_support: bool,
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
@@ -66,6 +68,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            grammar_support,
         }
     }
 
@@ -293,6 +296,11 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
+        // Ensure that grammar is not set if it's not supported
+        if !grammar.is_empty() && !self.grammar_support {
+            return Err(ValidationError::Grammar);
+        }
+
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -455,6 +463,8 @@ pub enum ValidationError {
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
+    #[error("grammar is not supported")]
+    Grammar,
 }
 
 #[cfg(test)]