From 02e4b9ab32dbe08992447c23537b9fbd530abb36 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Fri, 24 Jan 2025 10:41:07 +0100
Subject: [PATCH] backend(vllm): plug in the tokio server and CLI

---
 Cargo.lock                  |   1 +
 backends/vllm/Cargo.toml    |   3 +-
 backends/vllm/src/errors.rs |  12 +++-
 backends/vllm/src/lib.rs    |   1 +
 backends/vllm/src/main.rs   | 113 ++++++++++++++++++++++++++++++++----
 5 files changed, 116 insertions(+), 14 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 65cb0561..80ea70bd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4448,6 +4448,7 @@ name = "text-generation-backends-vllm"
 version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
+ "clap 4.5.21",
  "pyo3",
  "text-generation-router",
  "thiserror 2.0.11",
diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml
index bd1a21a1..0ab22b47 100644
--- a/backends/vllm/Cargo.toml
+++ b/backends/vllm/Cargo.toml
@@ -6,9 +6,10 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+async-trait = "0.1.83"
+clap = { version = "4.5.21", features = ["derive"] }
 pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 thiserror = "2.0"
 tokio = { version = "1.43", features = ["full"] }
 tokio-stream = "0.1"
-async-trait = "0.1.83"
diff --git a/backends/vllm/src/errors.rs b/backends/vllm/src/errors.rs
index fa7d4414..aa008190 100644
--- a/backends/vllm/src/errors.rs
+++ b/backends/vllm/src/errors.rs
@@ -1,10 +1,14 @@
 use pyo3::PyErr;
+use text_generation_router::server::WebServerError;
 use thiserror::Error;
 
 #[derive(Debug, Error)]
 pub enum VllmBackendError {
-    #[error("{0}")]
+    #[error("[Python] {0}")]
     Python(PyErr),
+
+    #[error("[WebServer] {0}")]
+    WebServer(WebServerError),
 }
 
 impl From<PyErr> for VllmBackendError {
@@ -12,3 +16,9 @@ impl From<PyErr> for VllmBackendError {
         Self::Python(value)
     }
 }
+
+impl From<WebServerError> for VllmBackendError {
+    fn from(value: WebServerError) -> Self {
+        Self::WebServer(value)
+    }
+}
diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs
index 5ab448fd..37d5eb25 100644
--- a/backends/vllm/src/lib.rs
+++ b/backends/vllm/src/lib.rs
@@ -4,3 +4,4 @@ mod errors;
 
 pub use backend::VllmBackend;
 pub use engine::{EngineArgs, LlmEngine};
+pub use errors::VllmBackendError;
diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs
index 66597247..20b7efc9 100644
--- a/backends/vllm/src/main.rs
+++ b/backends/vllm/src/main.rs
@@ -1,17 +1,106 @@
-use text_generation_backends_vllm::{EngineArgs, LlmEngine};
+use clap::Parser;
+use text_generation_backends_vllm::{EngineArgs, VllmBackend, VllmBackendError};
+use text_generation_router::{server, usage_stats};
+
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(long, env)]
+    max_input_tokens: Option<usize>,
+    #[clap(long, env)]
+    max_total_tokens: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    api_key: Option<String>,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
+}
+
+impl Into<EngineArgs> for &Args {
+    fn into(self) -> EngineArgs {
+        EngineArgs {
+            model: self.tokenizer_name.clone(),
+            pipeline_parallel_size: 1, // TODO
+            tensor_parallel_size: 1,   // TODO
+        }
+    }
+}
 
 #[tokio::main]
-async fn main() -> Result<(), ()> {
-    let args = EngineArgs {
-        model: String::from("meta-llama/Llama-3.2-1B-Instruct"),
-        pipeline_parallel_size: 1,
-        tensor_parallel_size: 1,
-    };
-
-    match LlmEngine::from_engine_args(args) {
-        Ok(_) => println!("Engine successfully allocated"),
-        Err(err) => println!("Got an error: {}", err),
-    }
+async fn main() -> Result<(), VllmBackendError> {
+    let args = Args::parse();
+    let backend = VllmBackend::from_engine_args((&args).into())?;
+    server::run(
+        backend,
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens.unwrap_or(1024), // TODO
+        args.max_total_tokens.unwrap_or(2048), // TODO
+        args.validation_workers,
+        args.api_key,
+        args.tokenizer_name,
+        args.tokenizer_config_path,
+        args.revision,
+        args.trust_remote_code,
+        args.hostname,
+        args.port,
+        args.cors_allow_origin,
+        false,
+        None,
+        None,
+        args.disable_grammar_support,
+        args.max_batch_size.unwrap_or(16),
+        args.usage_stats,
+        args.payload_limit,
+    )
+    .await?;
 
     Ok(())
 }
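[Note, not part of the patch] The new CLI follows clap's derive API: each field of `Args` is exposed both as a `--flag` and as an environment variable (the field name upper-cased, e.g. MAX_INPUT_TOKENS), and the string `default_value`s are parsed into the field's type. Below is a minimal, self-contained Rust sketch of that pattern, illustrative only and not the backend's code; the `DemoArgs` struct is hypothetical, the field names and defaults are borrowed from the patch, and it assumes clap is built with the "derive" and "env" features.

use clap::Parser;

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct DemoArgs {
    /// Port the HTTP server binds to (`--port` / PORT).
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    /// Model or tokenizer to load (`--tokenizer-name` / TOKENIZER_NAME).
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    /// Optional cap on input tokens; `None` means "fall back to the backend default".
    #[clap(long, env)]
    max_input_tokens: Option<usize>,
}

fn main() {
    // `DemoArgs::parse()` would read the real process arguments and the
    // corresponding environment variables; `parse_from` keeps this example
    // self-contained and independent of the surrounding environment.
    let args = DemoArgs::parse_from(["demo", "--port", "8080"]);
    println!("{args:?}");
}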