Add a flag that enables users to get logprobs back.

This commit is contained in:
Nicolas Patry 2024-12-04 18:45:28 +01:00
parent f6998f84e9
commit 3a86afc713
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@ -754,6 +754,14 @@ struct Args {
/// Default is 2MB /// Default is 2MB
#[clap(default_value = "2000000", long, env)] #[clap(default_value = "2000000", long, env)]
payload_limit: usize, payload_limit: usize,
/// Enables prefill logprobs
///
/// Logprobs in the prompt are deactivated by default because they consume
/// a large amount of VRAM (especially for long prompts).
/// Using this flag reallows users to ask for them.
#[clap(long, env)]
enable_prefill_logprobs: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@ -789,6 +797,7 @@ fn shard_manager(
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_input_tokens: Option<usize>, max_input_tokens: Option<usize>,
lora_adapters: Option<String>, lora_adapters: Option<String>,
enable_prefill_logprobs: bool,
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
otlp_service_name: String, otlp_service_name: String,
log_level: LevelFilter, log_level: LevelFilter,
@ -938,6 +947,11 @@ fn shard_manager(
envs.push(("LORA_ADAPTERS".into(), lora_adapters.into())); envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
} }
// Logprobs
if enable_prefill_logprobs {
envs.push(("REQUEST_LOGPROBS".into(), "1".into()));
}
// If huggingface_hub_cache is some, pass it to the shard // If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache { if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -1429,6 +1443,7 @@ fn spawn_shards(
let rope_factor = args.rope_factor; let rope_factor = args.rope_factor;
let max_batch_size = args.max_batch_size; let max_batch_size = args.max_batch_size;
let lora_adapters = args.lora_adapters.clone(); let lora_adapters = args.lora_adapters.clone();
let enable_prefill_logprobs = args.enable_prefill_logprobs;
thread::spawn(move || { thread::spawn(move || {
shard_manager( shard_manager(
model_id, model_id,
@ -1456,6 +1471,7 @@ fn spawn_shards(
max_batch_size, max_batch_size,
max_input_tokens, max_input_tokens,
lora_adapters, lora_adapters,
enable_prefill_logprobs,
otlp_endpoint, otlp_endpoint,
otlp_service_name, otlp_service_name,
max_log_level, max_log_level,