Mirror of https://github.com/huggingface/text-generation-inference.git

backend(vllm): plug in the tokio server and CLI

commit 02e4b9ab32 (parent bd2ec03d53)
Cargo.lock (generated)

@@ -4448,6 +4448,7 @@ name = "text-generation-backends-vllm"
 version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
+ "clap 4.5.21",
  "pyo3",
  "text-generation-router",
  "thiserror 2.0.11",
Cargo.toml (text-generation-backends-vllm)

@@ -6,9 +6,10 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+async-trait = "0.1.83"
+clap = { version = "4.5.21", features = ["derive"] }
 pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 thiserror = "2.0"
 tokio = { version = "1.43", features = ["full"] }
 tokio-stream = "0.1"
-async-trait = "0.1.83"
errors.rs (text-generation-backends-vllm)

@@ -1,10 +1,14 @@
 use pyo3::PyErr;
+use text_generation_router::server::WebServerError;
 use thiserror::Error;
 
 #[derive(Debug, Error)]
 pub enum VllmBackendError {
-    #[error("{0}")]
+    #[error("[Python] {0}")]
     Python(PyErr),
+
+    #[error("[WebServer] {0}")]
+    WebServer(WebServerError),
 }
 
 impl From<PyErr> for VllmBackendError {
@@ -12,3 +16,9 @@ impl From<PyErr> for VllmBackendError {
         Self::Python(value)
     }
 }
+
+impl From<WebServerError> for VllmBackendError {
+    fn from(value: WebServerError) -> Self {
+        Self::WebServer(value)
+    }
+}
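Note (not part of the diff): with both From impls in place, the ? operator converts either error source into VllmBackendError, which is what lets the CLI's main() below return Result<(), VllmBackendError>. A minimal sketch, assuming it lives inside the same crate so the re-exported VllmBackendError is in scope:

    // Sketch only: both error sources funnel into VllmBackendError via `?`.
    use crate::VllmBackendError; // re-exported from errors.rs in lib.rs below
    use pyo3::PyErr;
    use text_generation_router::server::WebServerError;

    fn propagate(
        py: Result<(), PyErr>,
        web: Result<(), WebServerError>,
    ) -> Result<(), VllmBackendError> {
        py?;  // PyErr -> VllmBackendError::Python
        web?; // WebServerError -> VllmBackendError::WebServer
        Ok(())
    }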
lib.rs (text-generation-backends-vllm)

@@ -4,3 +4,4 @@ mod errors;
 
 pub use backend::VllmBackend;
 pub use engine::{EngineArgs, LlmEngine};
+pub use errors::VllmBackendError;
main.rs (text-generation-backends-vllm)

@@ -1,17 +1,106 @@
-use text_generation_backends_vllm::{EngineArgs, LlmEngine};
+use clap::Parser;
+use text_generation_backends_vllm::{EngineArgs, VllmBackend, VllmBackendError};
+use text_generation_router::{server, usage_stats};
 
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(long, env)]
+    max_input_tokens: Option<usize>,
+    #[clap(long, env)]
+    max_total_tokens: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    api_key: Option<String>,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
+}
+
+impl Into<EngineArgs> for &Args {
+    fn into(self) -> EngineArgs {
+        EngineArgs {
+            model: self.tokenizer_name.clone(),
+            pipeline_parallel_size: 1, // TODO
+            tensor_parallel_size: 1,   // TODO
+        }
+    }
+}
+
 #[tokio::main]
-async fn main() -> Result<(), ()> {
-    let args = EngineArgs {
-        model: String::from("meta-llama/Llama-3.2-1B-Instruct"),
-        pipeline_parallel_size: 1,
-        tensor_parallel_size: 1,
-    };
-
-    match LlmEngine::from_engine_args(args) {
-        Ok(_) => println!("Engine successfully allocated"),
-        Err(err) => println!("Got an error: {}", err),
-    }
-
+async fn main() -> Result<(), VllmBackendError> {
+    let args = Args::parse();
+    let backend = VllmBackend::from_engine_args((&args).into())?;
+
+    server::run(
+        backend,
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens.unwrap_or(1024), // TODO
+        args.max_total_tokens.unwrap_or(2048), // TODO
+        args.validation_workers,
+        args.api_key,
+        args.tokenizer_name,
+        args.tokenizer_config_path,
+        args.revision,
+        args.trust_remote_code,
+        args.hostname,
+        args.port,
+        args.cors_allow_origin,
+        false,
+        None,
+        None,
+        args.disable_grammar_support,
+        args.max_batch_size.unwrap_or(16),
+        args.usage_stats,
+        args.payload_limit,
+    )
+    .await?;
 
     Ok(())
 }
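For context (not part of the commit): every field on Args is declared with `long, env`, so each option can be supplied either as a kebab-case flag or, since clap derives the variable name from the field, as an upper-snake-case environment variable (e.g. MAX_CONCURRENT_REQUESTS); unset options fall back to the default_value in the attribute. A minimal sketch of exercising the derived parser, with a hypothetical binary name:

    // Sketch only; `Args` is the struct from the main.rs hunk above.
    use clap::Parser;

    fn parse_demo() {
        let args = Args::parse_from([
            "text-generation-backends-vllm", // hypothetical binary name
            "--tokenizer-name", "meta-llama/Llama-3.2-1B-Instruct",
            "--port", "8080",
        ]);
        assert_eq!(args.port, 8080);
        assert_eq!(args.max_concurrent_requests, 128); // default from the derive
    }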