text-generation-inference/backends/llamacpp/src/main.rs
Adrien Gallouët 95e221eece
Add llamacpp backend
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-02-04 13:32:56 +00:00

mod backend;

use backend::{LlamacppConfig, LlamacppBackend, BackendError};
use clap::Parser;
use text_generation_router::{logging, server, usage_stats};
use thiserror::Error;
use tokenizers::{Tokenizer, FromPretrainedParameters};
use tokio::sync::oneshot::error::RecvError;
use tracing::error;

/// Backend Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Name of the model to load.
    #[clap(long, env)]
    model_id: String,
    /// Revision of the model.
    #[clap(default_value = "main", long, env)]
    revision: String,
    /// Path to the GGUF model file to be used for inference.
    #[clap(long, env)]
    model_gguf: String, // TODO Option() with hf->gguf & quantize
    /// Context size for the model.
    #[clap(default_value = "4096", long, env)]
    n_ctx: u32,
    /// Number of threads to use for inference.
    #[clap(default_value = "1", long, env)]
    n_threads: i32,
    /// Whether to use memory mapping.
    #[clap(default_value = "true", long, env)]
    use_mmap: bool,
    /// Whether to use memory locking.
    #[clap(default_value = "false", long, env)]
    use_mlock: bool,
    /// Enable flash attention for faster inference. (EXPERIMENTAL)
    #[clap(default_value = "false", long, env)]
    flash_attention: bool,
    /// Number of tokenizer workers used for payload validation and truncation.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
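    /// Maximum number of concurrent requests.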
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
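    /// Maximum `best_of` value a client can request.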
#[clap(default_value = "2", long, env)]
max_best_of: usize,
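    /// Maximum number of stop sequences a client can provide.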
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
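    /// Maximum `top_n_tokens` value a client can request.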
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
    /// Maximum number of input tokens allowed per request.
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    /// Maximum total tokens (input + output) allowed per request.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    // #[clap(default_value = "1.2", long, env)]
    // waiting_served_ratio: f32,
    // #[clap(default_value = "4096", long, env)]
    // max_batch_prefill_tokens: u32,
    // #[clap(long, env)]
    // max_batch_total_tokens: Option<u32>,
    // #[clap(default_value = "20", long, env)]
    // max_waiting_tokens: usize,
    // #[clap(long, env)]
    // max_batch_size: Option<usize>,
    /// The IP address to listen on.
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    /// The port to listen on.
    #[clap(default_value = "3001", long, short, env)]
    port: u16,
    // #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    // master_shard_uds_path: String,
    // #[clap(long, env)]
    // tokenizer_name: String,
    // #[clap(long, env)]
    // tokenizer_config_path: Option<String>,
    // #[clap(long, env, value_enum)]
    // trust_remote_code: bool,
    // #[clap(long, env)]
    // api_key: Option<String>,
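    /// Output logs in JSON format.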
    #[clap(long, env)]
    json_output: bool,
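    /// OTLP endpoint to export OpenTelemetry traces to.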
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
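    /// Service name to use in OpenTelemetry traces.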
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
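    /// Origins allowed for CORS.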
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
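    /// Enable ngrok tunneling.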
    #[clap(long, env)]
    ngrok: bool,
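    /// ngrok authentication token.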
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
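    /// ngrok edge to use.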
    #[clap(long, env)]
    ngrok_edge: Option<String>,
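    /// Path to the tokenizer config file.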
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
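    /// Disable grammar support.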
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
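    /// Maximum number of inputs a client can send in a single request.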
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
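    /// Level of anonymous usage statistics collection.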
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
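    /// Maximum payload size in bytes.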
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
    let args = Args::parse();

    logging::init_logging(
        args.otlp_endpoint,
        args.otlp_service_name,
        args.json_output,
    );

    if args.max_input_tokens >= args.max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }

    // TODO: check if we use the same cache as the server,
    //       and check if llamacpp is faster
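    // Load the tokenizer from the Hugging Face Hub, using HF_TOKEN
    // (or HUGGING_FACE_HUB_TOKEN) if set.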
    let tokenizer = {
        let token = std::env::var("HF_TOKEN")
            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
            .ok();
        let params = FromPretrainedParameters {
            revision: args.revision.clone(),
            token,
            ..Default::default()
        };
        Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
    };

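    // Create the llama.cpp backend; awaiting `ok` waits for the backend
    // to report whether initialization succeeded.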
    let (backend, ok) = LlamacppBackend::new(
        LlamacppConfig {
            model_gguf: args.model_gguf,
            n_ctx: args.n_ctx,
            n_threads: args.n_threads,
            use_mmap: args.use_mmap,
            use_mlock: args.use_mlock,
            flash_attention: args.flash_attention,
        },
        tokenizer,
    );
    ok.await??;

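    // Start the router HTTP server with the configured limits.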
    server::run(
        backend,
        args.max_concurrent_requests,
        args.max_best_of,
        args.max_stop_sequences,
        args.max_top_n_tokens,
        args.max_input_tokens,
        args.max_total_tokens,
        args.validation_workers,
        None, // api_key
        args.model_id, // tokenizer_name
        args.tokenizer_config_path,
        Some(args.revision),
        false, // trust_remote_code
        args.hostname,
        args.port,
        args.cors_allow_origin,
        args.ngrok,
        args.ngrok_authtoken,
        args.ngrok_edge,
        args.disable_grammar_support,
        args.max_client_batch_size,
        args.usage_stats,
        args.payload_limit,
    )
    .await?;
    Ok(())
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] tokenizers::Error),
    #[error("Backend error: {0}")]
    Backend(#[from] BackendError),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Recv error: {0}")]
    RecvError(#[from] RecvError),
}