mod backend;

use backend::{BackendError, LlamacppBackend, LlamacppConfig};
use clap::Parser;
use text_generation_router::{logging, server, usage_stats};
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tokio::sync::oneshot::error::RecvError;
use tracing::error;

/// Backend Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Name of the model to load.
    #[clap(long, env)]
    model_id: String,

    /// Revision of the model.
    #[clap(default_value = "main", long, env)]
    revision: String,

    /// Path to the GGUF model file to be used for inference.
    #[clap(long, env)]
    model_gguf: String, // TODO Option() with hf->gguf & quantize

    /// Context size for the model.
    #[clap(default_value = "4096", long, env)]
    n_ctx: u32,

    /// Number of threads to use for inference.
    #[clap(default_value = "1", long, env)]
    n_threads: i32,

    /// Whether to use memory mapping.
    #[clap(default_value = "true", long, env)]
    use_mmap: bool,

    /// Whether to use memory locking.
    #[clap(default_value = "false", long, env)]
    use_mlock: bool,

    /// Enable flash attention for faster inference. (EXPERIMENTAL)
    #[clap(default_value = "false", long, env)]
    flash_attention: bool,

    /// Number of tokenizer workers used for payload validation and truncation.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Maximum number of concurrent requests.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// Maximum allowed value of the `best_of` request parameter.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// Maximum number of stop sequences allowed per request.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// Maximum allowed value of the `top_n_tokens` request parameter.
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,

    /// Maximum number of input tokens allowed per request.
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,

    /// Maximum total tokens (input + output) allowed per request.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,

    // #[clap(default_value = "1.2", long, env)]
    // waiting_served_ratio: f32,
    // #[clap(default_value = "4096", long, env)]
    // max_batch_prefill_tokens: u32,
    // #[clap(long, env)]
    // max_batch_total_tokens: Option<u32>,
    // #[clap(default_value = "20", long, env)]
    // max_waiting_tokens: usize,
    // #[clap(long, env)]
    // max_batch_size: Option<usize>,

    /// The IP address to listen on.
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
#[clap(default_value = "3001", long, short, env)] port: u16, // #[clap(default_value = "/tmp/text-generation-server-0", long, env)] // master_shard_uds_path: String, // #[clap(long, env)] // tokenizer_name: String, // #[clap(long, env)] // tokenizer_config_path: Option, // #[clap(long, env, value_enum)] // trust_remote_code: bool, // #[clap(long, env)] // api_key: Option, #[clap(long, env)] json_output: bool, #[clap(long, env)] otlp_endpoint: Option, #[clap(default_value = "text-generation-inference.router", long, env)] otlp_service_name: String, #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] ngrok: bool, #[clap(long, env)] ngrok_authtoken: Option, #[clap(long, env)] ngrok_edge: Option, #[clap(long, env)] tokenizer_config_path: Option, #[clap(long, env, default_value_t = false)] disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, #[clap(default_value = "on", long, env)] usage_stats: usage_stats::UsageStatsLevel, #[clap(default_value = "2000000", long, env)] payload_limit: usize, } #[tokio::main] async fn main() -> Result<(), RouterError> { let args = Args::parse(); logging::init_logging( args.otlp_endpoint, args.otlp_service_name, args.json_output ); if args.max_input_tokens >= args.max_total_tokens { return Err(RouterError::ArgumentValidation( "`max_input_tokens` must be < `max_total_tokens`".to_string(), )); } // TODO: check if we use the same cache of Server // check if llamacpp is faster let tokenizer = { let token = std::env::var("HF_TOKEN") .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) .ok(); let params = FromPretrainedParameters { revision: args.revision.clone(), token: token, ..Default::default() }; Tokenizer::from_pretrained( args.model_id.clone(), Some(params) )? }; let (backend, ok) = LlamacppBackend::new( LlamacppConfig { model_gguf: args.model_gguf, n_ctx: args.n_ctx, n_threads: args.n_threads, use_mmap: args.use_mmap, use_mlock: args.use_mlock, flash_attention: args.flash_attention, }, tokenizer, ); ok.await??; server::run( backend, args.max_concurrent_requests, args.max_best_of, args.max_stop_sequences, args.max_top_n_tokens, args.max_input_tokens, args.max_total_tokens, args.validation_workers, None, // api_key args.model_id, // tokenizer_name args.tokenizer_config_path, Some(args.revision), false, // trust_remote_code args.hostname, args.port, args.cors_allow_origin, args.ngrok, args.ngrok_authtoken, args.ngrok_edge, args.disable_grammar_support, args.max_client_batch_size, args.usage_stats, args.payload_limit, ) .await?; Ok(()) } #[derive(Debug, Error)] enum RouterError { #[error("Argument validation error: {0}")] ArgumentValidation(String), #[error("Tokenizer error: {0}")] Tokenizer(#[from] tokenizers::Error), #[error("Backend error: {0}")] Backend(#[from] BackendError), #[error("WebServer error: {0}")] WebServer(#[from] server::WebServerError), #[error("Recv error: {0}")] RecvError(#[from] RecvError), }