backend(vllm): plug in the tokio server and CLI

Morgan Funtowicz 2025-01-24 10:41:07 +01:00
parent bd2ec03d53
commit 02e4b9ab32
5 changed files with 116 additions and 14 deletions

Cargo.lock (generated)

@@ -4448,6 +4448,7 @@ name = "text-generation-backends-vllm"
 version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
+ "clap 4.5.21",
  "pyo3",
  "text-generation-router",
  "thiserror 2.0.11",

Cargo.toml

@@ -6,9 +6,10 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+async-trait = "0.1.83"
+clap = { version = "4.5.21", features = ["derive"] }
 pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 thiserror = "2.0"
 tokio = { version = "1.43", features = ["full"] }
 tokio-stream = "0.1"
-async-trait = "0.1.83"

src/errors.rs

@@ -1,10 +1,14 @@
 use pyo3::PyErr;
+use text_generation_router::server::WebServerError;
 use thiserror::Error;
 
 #[derive(Debug, Error)]
 pub enum VllmBackendError {
-    #[error("{0}")]
+    #[error("[Python] {0}")]
     Python(PyErr),
+    #[error("[WebServer] {0}")]
+    WebServer(WebServerError),
 }
 
 impl From<PyErr> for VllmBackendError {
@@ -12,3 +16,9 @@ impl From<PyErr> for VllmBackendError {
         Self::Python(value)
     }
 }
+
+impl From<WebServerError> for VllmBackendError {
+    fn from(value: WebServerError) -> Self {
+        Self::WebServer(value)
+    }
+}
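
Aside, not part of the commit: these two From impls are what let the new main() below bubble both pyo3 and web-server failures up with the ? operator. A minimal sketch, assuming two hypothetical helpers that fail with PyErr and WebServerError respectively:

use pyo3::PyErr;
use text_generation_backends_vllm::VllmBackendError;
use text_generation_router::server::WebServerError;

// Hypothetical helpers, only here to show the signatures involved.
fn call_into_python() -> Result<(), PyErr> {
    Ok(())
}

fn start_web_server() -> Result<(), WebServerError> {
    Ok(())
}

// `?` converts each error through the matching From impl above,
// so no explicit map_err is needed.
fn run_backend() -> Result<(), VllmBackendError> {
    call_into_python()?;  // PyErr -> VllmBackendError::Python
    start_web_server()?;  // WebServerError -> VllmBackendError::WebServer
    Ok(())
}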

src/lib.rs

@@ -4,3 +4,4 @@ mod errors;
 
+pub use backend::VllmBackend;
 pub use engine::{EngineArgs, LlmEngine};
 pub use errors::VllmBackendError;

src/main.rs

@@ -1,17 +1,106 @@
-use text_generation_backends_vllm::{EngineArgs, LlmEngine};
+use clap::Parser;
+use text_generation_backends_vllm::{EngineArgs, VllmBackend, VllmBackendError};
+use text_generation_router::{server, usage_stats};
+
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(long, env)]
+    max_input_tokens: Option<usize>,
+    #[clap(long, env)]
+    max_total_tokens: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    api_key: Option<String>,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
+}
+
+impl Into<EngineArgs> for &Args {
+    fn into(self) -> EngineArgs {
+        EngineArgs {
+            model: self.tokenizer_name.clone(),
+            pipeline_parallel_size: 1, // TODO
+            tensor_parallel_size: 1,   // TODO
+        }
+    }
+}
 
 #[tokio::main]
-async fn main() -> Result<(), ()> {
-    let args = EngineArgs {
-        model: String::from("meta-llama/Llama-3.2-1B-Instruct"),
-        pipeline_parallel_size: 1,
-        tensor_parallel_size: 1,
-    };
-
-    match LlmEngine::from_engine_args(args) {
-        Ok(_) => println!("Engine successfully allocated"),
-        Err(err) => println!("Got an error: {}", err),
-    }
+async fn main() -> Result<(), VllmBackendError> {
+    let args = Args::parse();
+    let backend = VllmBackend::from_engine_args((&args).into())?;
+
+    server::run(
+        backend,
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens.unwrap_or(1024), // TODO
+        args.max_total_tokens.unwrap_or(2048), // TODO
+        args.validation_workers,
+        args.api_key,
+        args.tokenizer_name,
+        args.tokenizer_config_path,
+        args.revision,
+        args.trust_remote_code,
+        args.hostname,
+        args.port,
+        args.cors_allow_origin,
+        false,
+        None,
+        None,
+        args.disable_grammar_support,
+        args.max_batch_size.unwrap_or(16),
+        args.usage_stats,
+        args.payload_limit,
+    )
+    .await?;
 
     Ok(())
 }
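
One small design note, not from the commit: implementing Into directly works, but the conventional direction is From<&Args> for EngineArgs, which provides the same (&args).into() call through the standard library's blanket impl. A minimal sketch with simplified stand-in types (field names mirror the diff; the types here are assumptions):

// Stand-in types for the sketch only; the real Args and EngineArgs are defined above.
struct Args {
    tokenizer_name: String,
}

struct EngineArgs {
    model: String,
    pipeline_parallel_size: u32, // type assumed for the sketch
    tensor_parallel_size: u32,   // type assumed for the sketch
}

impl From<&Args> for EngineArgs {
    fn from(args: &Args) -> Self {
        EngineArgs {
            model: args.tokenizer_name.clone(),
            pipeline_parallel_size: 1, // TODO (as in the commit)
            tensor_parallel_size: 1,   // TODO (as in the commit)
        }
    }
}

fn main() {
    let args = Args {
        tokenizer_name: String::from("bigscience/bloom"),
    };
    // From<&Args> also provides Into, so the call site in the commit stays unchanged.
    let engine_args: EngineArgs = (&args).into();
    assert_eq!(engine_args.model, "bigscience/bloom");
}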