Update args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-05 10:10:38 +00:00
parent e007529590
commit d3a772a8dd
No known key found for this signature in database

View File

@ -20,7 +20,7 @@ struct Args {
#[clap(default_value = "main", long, env)] #[clap(default_value = "main", long, env)]
revision: String, revision: String,
/// Path to the GGUF model file to be used for inference. /// Path to the GGUF model file for inference.
#[clap(long, env)] #[clap(long, env)]
model_gguf: String, // TODO Option() with hf->gguf & quantize model_gguf: String, // TODO Option() with hf->gguf & quantize
@ -48,15 +48,15 @@ struct Args {
#[clap(default_value = "-1.0", long, env)] #[clap(default_value = "-1.0", long, env)]
defrag_threshold: f32, defrag_threshold: f32,
/// Setup NUMA optimizations. /// Enable NUMA optimizations.
#[clap(default_value = "disabled", value_enum, long, env)] #[clap(default_value = "disabled", value_enum, long, env)]
numa: LlamacppNuma, numa: LlamacppNuma,
/// Whether to use memory mapping. /// Use memory mapping for the model.
#[clap(default_value = "true", long, env)] #[clap(default_value = "true", long, env)]
use_mmap: bool, use_mmap: bool,
/// Whether to use memory locking. /// Use memory locking to prevent swapping.
#[clap(default_value = "false", long, env)] #[clap(default_value = "false", long, env)]
use_mlock: bool, use_mlock: bool,
@ -68,95 +68,95 @@ struct Args {
#[clap(default_value = "true", long, env)] #[clap(default_value = "true", long, env)]
flash_attention: bool, flash_attention: bool,
/// Use data type for K cache. /// Data type used for K cache.
#[clap(default_value = "f16", value_enum, long, env)] #[clap(default_value = "f16", value_enum, long, env)]
type_k: LlamacppGGMLType, type_k: LlamacppGGMLType,
/// Use data type for V cache. /// Data type used for V cache.
#[clap(default_value = "f16", value_enum, long, env)] #[clap(default_value = "f16", value_enum, long, env)]
type_v: LlamacppGGMLType, type_v: LlamacppGGMLType,
/// TODO /// Number of tokenizer workers used for payload validation and truncation.
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
validation_workers: usize, validation_workers: usize,
/// Maximum amount of concurrent requests.
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
/// Maximum number of input tokens allowed per request. /// Maximum number of input tokens per request.
#[clap(default_value = "1024", long, env)] #[clap(default_value = "1024", long, env)]
max_input_tokens: usize, max_input_tokens: usize,
/// Maximum total tokens (input + output) allowed per request. /// Maximum total tokens (input + output) per request.
#[clap(default_value = "2048", long, env)] #[clap(default_value = "2048", long, env)]
max_total_tokens: usize, max_total_tokens: usize,
// #[clap(default_value = "1.2", long, env)] /// Maximum number of tokens in a batch.
// waiting_served_ratio: f32,
// #[clap(default_value = "4096", long, env)]
// max_batch_prefill_tokens: u32,
/// Maximum number of tokens that can be submitted within a batch
#[clap(default_value = "4096", long, env)] #[clap(default_value = "4096", long, env)]
max_batch_total_tokens: usize, max_batch_total_tokens: usize,
/// Maximum number of tokens within a batch /// Maximum number of tokens in a physical batch.
#[clap(long, env)] #[clap(long, env)]
max_physical_batch_total_tokens: Option<usize>, max_physical_batch_total_tokens: Option<usize>,
// #[clap(default_value = "20", long, env)] /// Maximum number of requests per batch.
// max_waiting_tokens: usize,
/// Maximum number of requests per batch
#[clap(default_value = "1", long, env)] #[clap(default_value = "1", long, env)]
max_batch_size: usize, max_batch_size: usize,
/// The IP address to listen on /// IP address to listen on.
#[clap(default_value = "0.0.0.0", long, env)] #[clap(default_value = "0.0.0.0", long, env)]
hostname: String, hostname: String,
/// The port to listen on. /// Port to listen on.
#[clap(default_value = "3001", long, short, env)] #[clap(default_value = "3001", long, short, env)]
port: u16, port: u16,
// #[clap(default_value = "/tmp/text-generation-server-0", long, env)] /// Enable JSON output format.
// master_shard_uds_path: String,
// #[clap(long, env)]
// tokenizer_name: String,
// #[clap(long, env)]
// tokenizer_config_path: Option<String>,
// #[clap(long, env, value_enum)]
// trust_remote_code: bool,
// #[clap(long, env)]
// api_key: Option<String>,
#[clap(long, env)] #[clap(long, env)]
json_output: bool, json_output: bool,
/// OTLP endpoint for telemetry data.
#[clap(long, env)] #[clap(long, env)]
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
/// Service name for OTLP telemetry.
#[clap(default_value = "text-generation-inference.router", long, env)] #[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String, otlp_service_name: String,
/// Allowed origins for CORS.
#[clap(long, env)] #[clap(long, env)]
cors_allow_origin: Option<Vec<String>>, cors_allow_origin: Option<Vec<String>>,
/// Enable Ngrok tunneling.
#[clap(long, env)] #[clap(long, env)]
ngrok: bool, ngrok: bool,
/// Ngrok authentication token.
#[clap(long, env)] #[clap(long, env)]
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
/// Ngrok edge to use for tunneling.
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
/// Path to the tokenizer configuration file.
#[clap(long, env)] #[clap(long, env)]
tokenizer_config_path: Option<String>, tokenizer_config_path: Option<String>,
/// Disable grammar support.
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
disable_grammar_support: bool, disable_grammar_support: bool,
/// Maximum number of inputs per request.
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
/// Level of usage statistics collection.
#[clap(default_value = "on", long, env)] #[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel, usage_stats: usage_stats::UsageStatsLevel,
/// Maximum payload size limit in bytes.
#[clap(default_value = "2000000", long, env)] #[clap(default_value = "2000000", long, env)]
payload_limit: usize, payload_limit: usize,
} }
@ -257,9 +257,9 @@ async fn main() -> Result<(), RouterError> {
server::run( server::run(
backend, backend,
args.max_concurrent_requests, args.max_concurrent_requests,
args.max_best_of, 0, // max_best_of
args.max_stop_sequences, 0, // max_stop_sequences
args.max_top_n_tokens, 0, // max_top_n_tokens
args.max_input_tokens, args.max_input_tokens,
args.max_total_tokens, args.max_total_tokens,
args.validation_workers, args.validation_workers,