mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-25 03:52:08 +00:00
backend(vllm): plug in the tokio server and CLI
This commit is contained in:
parent bd2ec03d53
commit 02e4b9ab32
Cargo.lock (generated)
@@ -4448,6 +4448,7 @@ name = "text-generation-backends-vllm"
 version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
+ "clap 4.5.21",
  "pyo3",
  "text-generation-router",
  "thiserror 2.0.11",
@@ -6,9 +6,10 @@ authors.workspace = true
 homepage.workspace = true

 [dependencies]
+async-trait = "0.1.83"
+clap = { version = "4.5.21", features = ["derive"] }
 pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 thiserror = "2.0"
 tokio = { version = "1.43", features = ["full"] }
 tokio-stream = "0.1"
-async-trait = "0.1.83"
@@ -1,10 +1,14 @@
 use pyo3::PyErr;
+use text_generation_router::server::WebServerError;
 use thiserror::Error;

 #[derive(Debug, Error)]
 pub enum VllmBackendError {
-    #[error("{0}")]
+    #[error("[Python] {0}")]
     Python(PyErr),
+
+    #[error("[WebServer] {0}")]
+    WebServer(WebServerError),
 }

 impl From<PyErr> for VllmBackendError {
@@ -12,3 +16,9 @@ impl From<PyErr> for VllmBackendError {
         Self::Python(value)
     }
 }
+
+impl From<WebServerError> for VllmBackendError {
+    fn from(value: WebServerError) -> Self {
+        Self::WebServer(value)
+    }
+}
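With both From impls in place, backend code that returns Result<_, VllmBackendError> can propagate pyo3 and web-server failures with `?` and have them land in the matching variant. A minimal sketch, assuming only what the diff shows; the helper functions are hypothetical stand-ins, not part of this commit:

use pyo3::PyErr;
use text_generation_backends_vllm::VllmBackendError;
use text_generation_router::server::WebServerError;

// Hypothetical stand-ins for real pyo3 / web-server calls.
fn init_python_engine() -> Result<(), PyErr> {
    Ok(())
}

fn serve() -> Result<(), WebServerError> {
    Ok(())
}

fn launch() -> Result<(), VllmBackendError> {
    init_python_engine()?; // PyErr -> VllmBackendError::Python via From<PyErr>
    serve()?; // WebServerError -> VllmBackendError::WebServer via From<WebServerError>
    Ok(())
}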
@@ -4,3 +4,4 @@ mod errors;

+pub use backend::VllmBackend;
 pub use engine::{EngineArgs, LlmEngine};
 pub use errors::VllmBackendError;
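The new re-export lets consumers reach VllmBackend from the crate root, alongside the existing EngineArgs, LlmEngine and VllmBackendError exports. A small sketch of that surface, assuming from_engine_args returns Result<VllmBackend, VllmBackendError>, consistent with its use behind `?` in main.rs below:

use text_generation_backends_vllm::{EngineArgs, VllmBackend, VllmBackendError};

// Assumed signature; illustrative only.
fn build_backend(engine_args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
    VllmBackend::from_engine_args(engine_args)
}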
@@ -1,17 +1,106 @@
-use text_generation_backends_vllm::{EngineArgs, LlmEngine};
+use clap::Parser;
+use text_generation_backends_vllm::{EngineArgs, VllmBackend, VllmBackendError};
+use text_generation_router::{server, usage_stats};

+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(long, env)]
+    max_input_tokens: Option<usize>,
+    #[clap(long, env)]
+    max_total_tokens: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    api_key: Option<String>,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
+}
+
+impl Into<EngineArgs> for &Args {
+    fn into(self) -> EngineArgs {
+        EngineArgs {
+            model: self.tokenizer_name.clone(),
+            pipeline_parallel_size: 1, // TODO
+            tensor_parallel_size: 1,   // TODO
+        }
+    }
+}
+
 #[tokio::main]
-async fn main() -> Result<(), ()> {
-    let args = EngineArgs {
-        model: String::from("meta-llama/Llama-3.2-1B-Instruct"),
-        pipeline_parallel_size: 1,
-        tensor_parallel_size: 1,
-    };
-
-    match LlmEngine::from_engine_args(args) {
-        Ok(_) => println!("Engine successfully allocated"),
-        Err(err) => println!("Got an error: {}", err),
-    }
-
+async fn main() -> Result<(), VllmBackendError> {
+    let args = Args::parse();
+    let backend = VllmBackend::from_engine_args((&args).into())?;
+
+    server::run(
+        backend,
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens.unwrap_or(1024), // TODO
+        args.max_total_tokens.unwrap_or(2048), // TODO
+        args.validation_workers,
+        args.api_key,
+        args.tokenizer_name,
+        args.tokenizer_config_path,
+        args.revision,
+        args.trust_remote_code,
+        args.hostname,
+        args.port,
+        args.cors_allow_origin,
+        false,
+        None,
+        None,
+        args.disable_grammar_support,
+        args.max_batch_size.unwrap_or(16),
+        args.usage_stats,
+        args.payload_limit,
+    )
+    .await?;
     Ok(())
 }
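Parsing and converting the CLI flags can be exercised on its own; a minimal sketch, assuming the Args struct, EngineArgs, and the Into<EngineArgs> impl above are in scope (the binary name and flag values are arbitrary examples, and engine_args_from_cli is a hypothetical helper):

use clap::Parser;

fn engine_args_from_cli() {
    // parse_from takes an explicit argv, with the binary name first.
    let args = Args::parse_from([
        "text-generation-backends-vllm",
        "--tokenizer-name",
        "meta-llama/Llama-3.2-1B-Instruct",
        "--port",
        "8080",
    ]);

    // The Into impl maps the tokenizer name onto EngineArgs.model and pins
    // both parallelism degrees to 1, as in the commit.
    let engine_args: EngineArgs = (&args).into();
    assert_eq!(engine_args.model, "meta-llama/Llama-3.2-1B-Instruct");
    assert_eq!(engine_args.tensor_parallel_size, 1);
}

A side note on the conversion: the more idiomatic direction is impl From<&Args> for EngineArgs, which yields the Into impl for free through the standard blanket implementation; the (&args).into() call site in main would not change.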