From 3a86afc713122081799fa0aa7a2f984fee68410b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 4 Dec 2024 18:45:28 +0100
Subject: [PATCH] Add a flag that enables users to get logprobs back.

---
 launcher/src/main.rs | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f3f31f66..610d6227 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -754,6 +754,14 @@ struct Args {
     /// Default is 2MB
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
+
+    /// Enables prefill logprobs
+    ///
+    /// Logprobs in the prompt are deactivated by default because they consume
+    /// a large amount of VRAM (especially for long prompts).
+    /// Using this flag allows users to ask for them again.
+    #[clap(long, env)]
+    enable_prefill_logprobs: bool,
 }
 
 #[derive(Debug)]
@@ -789,6 +797,7 @@ fn shard_manager(
     max_batch_size: Option<usize>,
     max_input_tokens: Option<usize>,
     lora_adapters: Option<String>,
+    enable_prefill_logprobs: bool,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
     log_level: LevelFilter,
@@ -938,6 +947,11 @@ fn shard_manager(
         envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
     }
 
+    // Logprobs
+    if enable_prefill_logprobs {
+        envs.push(("REQUEST_LOGPROBS".into(), "1".into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -1429,6 +1443,7 @@ fn spawn_shards(
         let rope_factor = args.rope_factor;
         let max_batch_size = args.max_batch_size;
         let lora_adapters = args.lora_adapters.clone();
+        let enable_prefill_logprobs = args.enable_prefill_logprobs;
        thread::spawn(move || {
            shard_manager(
                model_id,
@@ -1456,6 +1471,7 @@ fn spawn_shards(
                max_batch_size,
                max_input_tokens,
                lora_adapters,
+               enable_prefill_logprobs,
                otlp_endpoint,
                otlp_service_name,
                max_log_level,
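
Note (not part of the patch): the sketch below is a minimal, standalone illustration of how the new field is expected to behave, assuming clap's derive defaults with the env feature enabled; the DemoArgs type and main body are illustrative, not TGI code. With #[clap(long, env)], the field should surface as a --enable-prefill-logprobs CLI flag and an ENABLE_PREFILL_LOGPROBS environment variable, and the launcher then pushes REQUEST_LOGPROBS=1 into the environment of each shard it spawns.

    // Minimal sketch (not TGI code): mirrors the clap attribute added in the
    // patch and the envs.push pattern used in shard_manager.
    use clap::Parser;

    #[derive(Parser, Debug)]
    struct DemoArgs {
        /// Enables prefill logprobs
        #[clap(long, env)]
        enable_prefill_logprobs: bool,
    }

    fn main() {
        let args = DemoArgs::parse();
        let mut envs: Vec<(String, String)> = Vec::new();
        if args.enable_prefill_logprobs {
            // The launcher does the equivalent of this for every shard process.
            envs.push(("REQUEST_LOGPROBS".into(), "1".into()));
        }
        println!("shard env additions: {:?}", envs);
    }

With that derivation, launching with --enable-prefill-logprobs (or setting ENABLE_PREFILL_LOGPROBS) should make each shard see REQUEST_LOGPROBS=1; whether prompt logprobs are actually returned depends on the Python server honoring that variable, which this patch does not touch.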