From cfd22726c914285c9f9036695f9e80112d854c70 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Tue, 21 Jan 2025 23:37:56 +0100 Subject: [PATCH 1/9] backend(vllm): initial commit --- Cargo.lock | 6 +++++- Cargo.toml | 7 +++++-- backends/vllm/Cargo.toml | 10 ++++++++++ backends/vllm/src/main.rs | 6 ++++++ 4 files changed, 26 insertions(+), 3 deletions(-) create mode 100644 backends/vllm/Cargo.toml create mode 100644 backends/vllm/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index e63d1540..0059976b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -4443,6 +4443,10 @@ dependencies = [ "tracing", ] +[[package]] +name = "text-generation-backends-vllm" +version = "3.0.2-dev0" + [[package]] name = "text-generation-benchmark" version = "3.0.2-dev0" diff --git a/Cargo.toml b/Cargo.toml index 9f49c9ab..4183614e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,15 +5,18 @@ members = [ "backends/v3", "backends/grpc-metadata", "backends/trtllm", + "backends/vllm", "launcher", - "router" + "router", ] + default-members = [ "benchmark", "backends/v2", "backends/v3", "backends/grpc-metadata", # "backends/trtllm", + # "backends/vllm", "launcher", "router" ] @@ -33,7 +36,7 @@ metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } minijinja = { version = "2.2.0", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } -pyo3 = { version = "0.22.2", features = ["auto-initialize"] } +pyo3 = { version = "0.23", features = ["auto-initialize"] } [profile.release] incremental = true diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml new file mode 100644 index 00000000..b738745e --- /dev/null +++ b/backends/vllm/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "text-generation-backends-vllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +pyo3 = "0.23" +pyo3-asyncio = "0.20" \ No newline at end of file diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs new file mode 100644 index 00000000..fd54d8b1 --- /dev/null +++ b/backends/vllm/src/main.rs @@ -0,0 +1,6 @@ +use pyo3::prelude::*; + +#[pyo3_asyncio::tokio::main(flavor = "multi_thread")] +async fn main() { + println!("Hello, world!"); +} From bd2ec03d532c90c78c3cd242321cc10e9253d3de Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 22 Jan 2025 22:15:33 +0100 Subject: [PATCH 2/9] backend(vllm): statically allocate LLMEngine --- Cargo.lock | 84 ++++++++++++++++++++++----------- Cargo.toml | 2 +- backends/vllm/Cargo.toml | 8 +++- backends/vllm/src/backend.rs | 32 +++++++++++++ backends/vllm/src/engine.rs | 90 ++++++++++++++++++++++++++++++++++++ backends/vllm/src/errors.rs | 14 ++++++ backends/vllm/src/lib.rs | 6 +++ backends/vllm/src/main.rs | 19 ++++++-- router/Cargo.toml | 18 ++++---- 9 files changed, 229 insertions(+), 44 deletions(-) create mode 100644 backends/vllm/src/backend.rs create mode 100644 backends/vllm/src/engine.rs create mode 100644 backends/vllm/src/errors.rs create mode 100644 backends/vllm/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 0059976b..65cb0561 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -579,7 +579,7 @@ dependencies = [ "semver", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -1601,7 +1601,7 @@ dependencies = [ "reqwest 0.11.27", "serde", "serde_json", 
- "thiserror", + "thiserror 1.0.69", "tokio", "ureq", ] @@ -2034,7 +2034,7 @@ checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17" dependencies = [ "opentelemetry 0.20.0", "opentelemetry-otlp", - "thiserror", + "thiserror 1.0.69", "tracing", "tracing-opentelemetry 0.21.0", ] @@ -2187,9 +2187,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.164" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libfuzzer-sys" @@ -2370,7 +2370,7 @@ dependencies = [ "metrics", "metrics-util", "quanta", - "thiserror", + "thiserror 1.0.69", "tokio", "tracing", ] @@ -2501,7 +2501,7 @@ dependencies = [ "futures", "pin-project", "rand", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-util", "tracing", @@ -2553,7 +2553,7 @@ dependencies = [ "rustls-pemfile", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-retry", "tokio-util", @@ -2857,7 +2857,7 @@ dependencies = [ "js-sys", "once_cell", "pin-project-lite", - "thiserror", + "thiserror 1.0.69", "urlencoding", ] @@ -2875,7 +2875,7 @@ dependencies = [ "opentelemetry_api", "opentelemetry_sdk 0.20.0", "prost 0.11.9", - "thiserror", + "thiserror 1.0.69", "tokio", "tonic 0.9.2", ] @@ -2913,7 +2913,7 @@ dependencies = [ "js-sys", "once_cell", "pin-project-lite", - "thiserror", + "thiserror 1.0.69", "urlencoding", ] @@ -2935,7 +2935,7 @@ dependencies = [ "rand", "regex", "serde_json", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-stream", ] @@ -2957,7 +2957,7 @@ dependencies = [ "ordered-float 4.5.0", "percent-encoding", "rand", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -3494,7 +3494,7 @@ dependencies = [ "rand_chacha", "simd_helpers", "system-deps", - "thiserror", + "thiserror 1.0.69", "v_frame", "wasm-bindgen", ] @@ -3571,7 +3571,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -4436,7 +4436,7 @@ dependencies = [ "pkg-config", "pyo3", "text-generation-router", - "thiserror", + "thiserror 1.0.69", "tokenizers", "tokio", "tokio-stream", @@ -4446,6 +4446,14 @@ dependencies = [ [[package]] name = "text-generation-backends-vllm" version = "3.0.2-dev0" +dependencies = [ + "async-trait", + "pyo3", + "text-generation-router", + "thiserror 2.0.11", + "tokio", + "tokio-stream", +] [[package]] name = "text-generation-benchmark" @@ -4460,7 +4468,7 @@ dependencies = [ "serde_json", "tabled", "text-generation-client", - "thiserror", + "thiserror 1.0.69", "tokenizers", "tokio", "tracing", @@ -4477,7 +4485,7 @@ dependencies = [ "grpc-metadata", "prost 0.12.6", "prost-build", - "thiserror", + "thiserror 1.0.69", "tokio", "tonic 0.10.2", "tonic-build", @@ -4500,7 +4508,7 @@ dependencies = [ "reqwest 0.11.27", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "tracing", "tracing-subscriber", "vergen", @@ -4542,7 +4550,7 @@ dependencies = [ "serde", "serde_json", "sysinfo", - "thiserror", + "thiserror 1.0.69", "tokenizers", "tokio", "tokio-stream", @@ -4591,7 +4599,7 @@ dependencies = [ "serde_json", "slotmap", "text-generation-router", - "thiserror", + "thiserror 1.0.69", "tokenizers", "tokio", "tokio-stream", @@ -4642,7 +4650,7 @@ dependencies = [ "serde_json", "slotmap", 
"text-generation-router", - "thiserror", + "thiserror 1.0.69", "tokenizers", "tokio", "tokio-stream", @@ -4672,7 +4680,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +dependencies = [ + "thiserror-impl 2.0.11", ] [[package]] @@ -4686,6 +4703,17 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "thiserror-impl" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -4787,7 +4815,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror", + "thiserror 1.0.69", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", @@ -4795,9 +4823,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.41.1" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -4823,9 +4851,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 4183614e..170026a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } minijinja = { version = "2.2.0", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } -pyo3 = { version = "0.23", features = ["auto-initialize"] } +pyo3 = { version = "0.22", features = ["auto-initialize"] } [profile.release] incremental = true diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml index b738745e..bd1a21a1 100644 --- a/backends/vllm/Cargo.toml +++ b/backends/vllm/Cargo.toml @@ -6,5 +6,9 @@ authors.workspace = true homepage.workspace = true [dependencies] -pyo3 = "0.23" -pyo3-asyncio = "0.20" \ No newline at end of file +pyo3 = { workspace = true } +text-generation-router = { path = "../../router" } +thiserror = "2.0" +tokio = { version = "1.43", features = ["full"] } +tokio-stream = "0.1" +async-trait = "0.1.83" diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs new file mode 100644 index 00000000..6d49c268 --- /dev/null +++ b/backends/vllm/src/backend.rs @@ -0,0 +1,32 @@ +use crate::errors::VllmBackendError; +use crate::{EngineArgs, LlmEngine}; +use async_trait::async_trait; +use text_generation_router::infer::{Backend, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use tokio_stream::wrappers::UnboundedReceiverStream; + +pub struct VllmBackend { + engine: LlmEngine, +} + +impl VllmBackend { + pub fn from_engine_args(args: EngineArgs) -> 
Result { + Ok(Self { + engine: LlmEngine::from_engine_args(args)?, + }) + } +} + +#[async_trait] +impl Backend for VllmBackend { + fn schedule( + &self, + _request: ValidGenerateRequest, + ) -> Result>, InferError> { + todo!() + } + + async fn health(&self, _current_health: bool) -> bool { + true + } +} diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs new file mode 100644 index 00000000..1debe4c5 --- /dev/null +++ b/backends/vllm/src/engine.rs @@ -0,0 +1,90 @@ +use crate::errors::VllmBackendError; +use pyo3::prelude::*; +use pyo3::types::{IntoPyDict, PyDict, PyList}; + +pub struct EngineArgs { + pub model: String, + pub pipeline_parallel_size: u32, + pub tensor_parallel_size: u32, +} + +impl IntoPyDict for EngineArgs { + fn into_py_dict_bound(self, py: Python<'_>) -> Bound<'_, PyDict> { + PyDict::from_sequence_bound( + PyList::new_bound( + py, + [ + ("model", self.model.into_py(py)), + ( + "pipeline_parallel_size", + self.pipeline_parallel_size.into_py(py), + ), + ( + "tensor_parallel_size", + self.tensor_parallel_size.into_py(py), + ), + ], + ) + .as_any(), + ) + .expect("Failed to create Python Dict from EngineArgs") + } +} + +// impl IntoPy for EngineArgs { +// fn into_py(self, py: Python<'_>) -> PyObject { +// PyDict::from_sequence_bound( +// PyList::new_bound( +// py, +// [ +// ("model", self.model.into_py(py)), +// ( +// "pipeline_parallel_size", +// self.pipeline_parallel_size.into_py(py), +// ), +// ( +// "tensor_parallel_size", +// self.tensor_parallel_size.into_py(py), +// ), +// ], +// ) +// .as_any(), +// ) +// .expect("Failed to create Python Dict from EngineArgs") +// } +// } + +pub struct LlmEngine { + engine: PyObject, +} + +impl LlmEngine { + fn py_from_engine_args(args: EngineArgs) -> PyResult { + Python::with_gil(|py| { + // Create the EngineArgs from Rust + // from vllm.engine.arg_util import EngineArgs + // engine_args = EngineArgs(**args) + let py_engine_args_mod = PyModule::import_bound(py, "vllm.engine.arg_utils")?; + let py_engine_args_class = py_engine_args_mod.getattr("EngineArgs")?; + let py_engine_args = + py_engine_args_class.call((), Some(&args.into_py_dict_bound(py)))?; + + // Next create the LLMEngine from the EngineArgs + // from vllm.engine.llm_engine import LLMEngine + // engine = LLMEngine.from_engine_args(engine_args) + let py_engine_llm_mod = PyModule::import_bound(py, "vllm.engine.llm_engine")?; + let py_engine_llm_class = py_engine_llm_mod.getattr("LLMEngine")?; + py_engine_llm_class + .call_method("from_engine_args", (py_engine_args,), None)? 
+ .extract() + }) + } + + pub fn from_engine_args(args: EngineArgs) -> Result { + let engine = Self::py_from_engine_args(args)?; + + Ok(Self { engine }) + } + + pub fn step(&mut self) {} +} diff --git a/backends/vllm/src/errors.rs b/backends/vllm/src/errors.rs new file mode 100644 index 00000000..fa7d4414 --- /dev/null +++ b/backends/vllm/src/errors.rs @@ -0,0 +1,14 @@ +use pyo3::PyErr; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum VllmBackendError { + #[error("{0}")] + Python(PyErr), +} + +impl From for VllmBackendError { + fn from(value: PyErr) -> Self { + Self::Python(value) + } +} diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs new file mode 100644 index 00000000..5ab448fd --- /dev/null +++ b/backends/vllm/src/lib.rs @@ -0,0 +1,6 @@ +mod backend; +mod engine; +mod errors; + +pub use backend::VllmBackend; +pub use engine::{EngineArgs, LlmEngine}; diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs index fd54d8b1..66597247 100644 --- a/backends/vllm/src/main.rs +++ b/backends/vllm/src/main.rs @@ -1,6 +1,17 @@ -use pyo3::prelude::*; +use text_generation_backends_vllm::{EngineArgs, LlmEngine}; -#[pyo3_asyncio::tokio::main(flavor = "multi_thread")] -async fn main() { - println!("Hello, world!"); +#[tokio::main] +async fn main() -> Result<(), ()> { + let args = EngineArgs { + model: String::from("meta-llama/Llama-3.2-1B-Instruct"), + pipeline_parallel_size: 1, + tensor_parallel_size: 1, + }; + + match LlmEngine::from_engine_args(args) { + Ok(_) => println!("Engine successfully allocated"), + Err(err) => println!("Got an error: {}", err), + } + + Ok(()) } diff --git a/router/Cargo.toml b/router/Cargo.toml index 2e621dfc..81c58616 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -31,11 +31,11 @@ serde_json = "1.0.107" thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = [ - "rt", - "rt-multi-thread", - "parking_lot", - "signal", - "sync", + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", ] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } @@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ - "opentelemetry-otlp", + "opentelemetry-otlp", ] } minijinja = { workspace = true } minijinja-contrib = { workspace = true } @@ -57,9 +57,9 @@ image = "0.25.1" base64 = { workspace = true } sysinfo = "0.30.13" uuid = { version = "1.9.1", default-features = false, features = [ - "v4", - "fast-rng", - "macro-diagnostics", + "v4", + "fast-rng", + "macro-diagnostics", ] } csv = "1.3.0" ureq = "=2.9" From 02e4b9ab32dbe08992447c23537b9fbd530abb36 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Fri, 24 Jan 2025 10:41:07 +0100 Subject: [PATCH 3/9] backend(vllm): plug in the tokio server and CLI --- Cargo.lock | 1 + backends/vllm/Cargo.toml | 3 +- backends/vllm/src/errors.rs | 12 +++- backends/vllm/src/lib.rs | 1 + backends/vllm/src/main.rs | 113 ++++++++++++++++++++++++++++++++---- 5 files changed, 116 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 65cb0561..80ea70bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4448,6 +4448,7 @@ name = "text-generation-backends-vllm" version = "3.0.2-dev0" dependencies = [ "async-trait", + "clap 4.5.21", "pyo3", "text-generation-router", "thiserror 2.0.11", diff --git 
a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml index bd1a21a1..0ab22b47 100644 --- a/backends/vllm/Cargo.toml +++ b/backends/vllm/Cargo.toml @@ -6,9 +6,10 @@ authors.workspace = true homepage.workspace = true [dependencies] +async-trait = "0.1.83" +clap = { version = "4.5.21", features = ["derive"] } pyo3 = { workspace = true } text-generation-router = { path = "../../router" } thiserror = "2.0" tokio = { version = "1.43", features = ["full"] } tokio-stream = "0.1" -async-trait = "0.1.83" diff --git a/backends/vllm/src/errors.rs b/backends/vllm/src/errors.rs index fa7d4414..aa008190 100644 --- a/backends/vllm/src/errors.rs +++ b/backends/vllm/src/errors.rs @@ -1,10 +1,14 @@ use pyo3::PyErr; +use text_generation_router::server::WebServerError; use thiserror::Error; #[derive(Debug, Error)] pub enum VllmBackendError { - #[error("{0}")] + #[error("[Python] {0}")] Python(PyErr), + + #[error("[WebServer] {0}")] + WebServer(WebServerError), } impl From for VllmBackendError { @@ -12,3 +16,9 @@ impl From for VllmBackendError { Self::Python(value) } } + +impl From for VllmBackendError { + fn from(value: WebServerError) -> Self { + Self::WebServer(value) + } +} diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs index 5ab448fd..37d5eb25 100644 --- a/backends/vllm/src/lib.rs +++ b/backends/vllm/src/lib.rs @@ -4,3 +4,4 @@ mod errors; pub use backend::VllmBackend; pub use engine::{EngineArgs, LlmEngine}; +pub use errors::VllmBackendError; diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs index 66597247..20b7efc9 100644 --- a/backends/vllm/src/main.rs +++ b/backends/vllm/src/main.rs @@ -1,17 +1,106 @@ -use text_generation_backends_vllm::{EngineArgs, LlmEngine}; +use clap::Parser; +use text_generation_backends_vllm::{EngineArgs, VllmBackend, VllmBackendError}; +use text_generation_router::{server, usage_stats}; + +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(long, env)] + max_input_tokens: Option, + #[clap(long, env)] + max_total_tokens: Option, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env, value_enum)] + trust_remote_code: bool, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + 
#[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, + #[clap(default_value = "2000000", long, env)] + payload_limit: usize, +} + +impl Into for &Args { + fn into(self) -> EngineArgs { + EngineArgs { + model: self.tokenizer_name.clone(), + pipeline_parallel_size: 1, // TODO + tensor_parallel_size: 1, // TODO + } + } +} #[tokio::main] -async fn main() -> Result<(), ()> { - let args = EngineArgs { - model: String::from("meta-llama/Llama-3.2-1B-Instruct"), - pipeline_parallel_size: 1, - tensor_parallel_size: 1, - }; - - match LlmEngine::from_engine_args(args) { - Ok(_) => println!("Engine successfully allocated"), - Err(err) => println!("Got an error: {}", err), - } +async fn main() -> Result<(), VllmBackendError> { + let args = Args::parse(); + let backend = VllmBackend::from_engine_args((&args).into())?; + server::run( + backend, + args.max_concurrent_requests, + args.max_best_of, + args.max_stop_sequences, + args.max_top_n_tokens, + args.max_input_tokens.unwrap_or(1024), // TODO + args.max_total_tokens.unwrap_or(2048), // TODO + args.validation_workers, + args.api_key, + args.tokenizer_name, + args.tokenizer_config_path, + args.revision, + args.trust_remote_code, + args.hostname, + args.port, + args.cors_allow_origin, + false, + None, + None, + args.disable_grammar_support, + args.max_batch_size.unwrap_or(16), + args.usage_stats, + args.payload_limit, + ) + .await?; Ok(()) } From a7c2a470d67694acfc86c40fd913f5a4df966740 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Mon, 27 Jan 2025 22:39:35 +0100 Subject: [PATCH 4/9] backend(vllm): submit new request to vLLM engine --- Cargo.lock | 4 ++ backends/vllm/Cargo.toml | 4 ++ backends/vllm/src/backend.rs | 80 ++++++++++++++++++++-- backends/vllm/src/engine.rs | 125 ++++++++++++++++++++++++++++------- backends/vllm/src/errors.rs | 7 ++ backends/vllm/src/lib.rs | 33 +++++++++ backends/vllm/src/main.rs | 2 + 7 files changed, 227 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 80ea70bd..eb922162 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4449,11 +4449,15 @@ version = "3.0.2-dev0" dependencies = [ "async-trait", "clap 4.5.21", + "log", "pyo3", "text-generation-router", "thiserror 2.0.11", "tokio", "tokio-stream", + "tracing", + "tracing-subscriber", + "uuid", ] [[package]] diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml index 0ab22b47..2308a655 100644 --- a/backends/vllm/Cargo.toml +++ b/backends/vllm/Cargo.toml @@ -13,3 +13,7 @@ text-generation-router = { path = "../../router" } thiserror = "2.0" tokio = { version = "1.43", features = ["full"] } tokio-stream = "0.1" +uuid = { version = "1.11.0", features = ["v4"] } +log = "0.4.22" +tracing = "0.1.40" +tracing-subscriber = "0.3.18" diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs index 6d49c268..0ccf8063 100644 --- a/backends/vllm/src/backend.rs +++ b/backends/vllm/src/backend.rs @@ -1,18 +1,39 @@ use crate::errors::VllmBackendError; use crate::{EngineArgs, LlmEngine}; use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use std::thread::{spawn, JoinHandle}; use text_generation_router::infer::{Backend, InferError, InferStreamResponse}; -use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::validation::{ + ValidGenerateRequest, ValidParameters, ValidStoppingParameters, +}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{debug, 
info, warn}; + +type InferResult = Result; + +struct Request { + tokens: Arc>, + params: ValidParameters, + stopping_params: ValidStoppingParameters, + streamer: UnboundedSender, +} pub struct VllmBackend { - engine: LlmEngine, + looper: JoinHandle<()>, + waiting_requests: UnboundedSender, } impl VllmBackend { pub fn from_engine_args(args: EngineArgs) -> Result { + let engine = LlmEngine::from_engine_args(args)?; + let (sender, receiver) = unbounded_channel(); + let looper = spawn(|| engine_background_loop(engine, receiver)); Ok(Self { - engine: LlmEngine::from_engine_args(args)?, + looper, + waiting_requests: sender, }) } } @@ -21,12 +42,61 @@ impl VllmBackend { impl Backend for VllmBackend { fn schedule( &self, - _request: ValidGenerateRequest, + request: ValidGenerateRequest, ) -> Result>, InferError> { - todo!() + let (sender, receiver) = unbounded_channel(); + + // Send the query to the vLLM Engine + if let Some(input_ids) = request.input_ids { + debug!("Attempt to queue new request"); + if let Err(err) = self.waiting_requests.send(Request { + tokens: Arc::clone(&input_ids), + params: request.parameters, + stopping_params: request.stopping_parameters, + streamer: sender, + }) { + warn!("Waiting Requests queue has been closed: {err}") + } + }; + + Ok(UnboundedReceiverStream::new(receiver)) } async fn health(&self, _current_health: bool) -> bool { true } } + +fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver) { + info!("Starting vLLM engine background loop"); + + let mut in_flight_requests = HashMap::with_capacity(256); + loop { + if !waiting_requests.is_empty() { + let num_waiting_requests = waiting_requests.len(); + debug!( + "Adding {} requests to the vLLM engine", + num_waiting_requests + ); + + let mut requests = Vec::with_capacity(num_waiting_requests); + waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests); + + for request in requests { + match engine.add_request(&request.tokens, &request.params, &request.stopping_params) + { + Ok(request_id) => { + debug!("Successfully scheduled request {request_id}"); + in_flight_requests.insert(request_id.to_string(), request); + } + Err(err) => { + warn!("Failed to schedule new request: {err}"); + } + } + } + } + engine.step(); + } + + info!("Shutting down vLLM engine background loop"); +} diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs index 1debe4c5..d4f4f5dc 100644 --- a/backends/vllm/src/engine.rs +++ b/backends/vllm/src/engine.rs @@ -1,6 +1,10 @@ use crate::errors::VllmBackendError; +use crate::{sampling_params, tokens_prompt, TryToPyObject}; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyDict, PyList}; +use pyo3::types::{IntoPyDict, PyDict, PyList, PyString}; +use text_generation_router::validation::{ValidParameters, ValidStoppingParameters}; +use tracing::info; +use uuid::Uuid; pub struct EngineArgs { pub model: String, @@ -31,28 +35,51 @@ impl IntoPyDict for EngineArgs { } } -// impl IntoPy for EngineArgs { -// fn into_py(self, py: Python<'_>) -> PyObject { -// PyDict::from_sequence_bound( -// PyList::new_bound( -// py, -// [ -// ("model", self.model.into_py(py)), -// ( -// "pipeline_parallel_size", -// self.pipeline_parallel_size.into_py(py), -// ), -// ( -// "tensor_parallel_size", -// self.tensor_parallel_size.into_py(py), -// ), -// ], -// ) -// .as_any(), -// ) -// .expect("Failed to create Python Dict from EngineArgs") -// } -// } +pub struct SamplingParams<'a> { + sampling_params: &'a ValidParameters, + stopping_params: &'a 
ValidStoppingParameters, +} + +impl TryToPyObject for SamplingParams<'_> { + fn try_to_object(&self, py: Python<'_>) -> Result { + let py_sampling_params_class = sampling_params(py); + + let kwargs = PyDict::from_sequence_bound(&PyList::new_bound( + py, + [ + ("seed", self.sampling_params.seed.into_py(py)), + ("n", 1.into_py(py)), + ("top_k", self.sampling_params.top_k.into_py(py)), + ("top_p", self.sampling_params.top_p.into_py(py)), + ("temperature", self.sampling_params.temperature.into_py(py)), + ( + "frequency_penalty", + self.sampling_params.frequency_penalty.into_py(py), + ), + ( + "repetition_penalty", + self.sampling_params.repetition_penalty.into_py(py), + ), + ( + "ignore_eos", + self.stopping_params.ignore_eos_token.into_py(py), + ), + ( + "max_tokens", + self.stopping_params.max_new_tokens.into_py(py), + ), + ( + "stop", + PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(), + ), + ], + )); + + Ok(py_sampling_params_class + .call_method_bound(py, "from_optional", (), Some(&kwargs?))? + .to_object(py)) + } +} pub struct LlmEngine { engine: PyObject, @@ -80,11 +107,63 @@ impl LlmEngine { }) } + fn py_add_request( + &self, + request_id: &str, + prompt: &[u32], + sampling_params: SamplingParams, + ) -> Result<(), VllmBackendError> { + Python::with_gil(|py| { + // Create vllm.Tokens + let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py); + let py_tokens_prompt_class = tokens_prompt(py); + let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?; + let py_sampling_params = sampling_params.try_to_object(py)?; + + let _ = py.eval_bound( + "print(type(params), params)", + Some(&[("params", &py_sampling_params)].into_py_dict_bound(py)), + None, + ); + + self.engine.call_method1( + py, + "add_request", + ( + PyString::new_bound(py, request_id), + py_tokens_prompt, + py_sampling_params, + ), + )?; + + self.engine.call_method0(py, "step") + })?; + + Ok(()) + } + pub fn from_engine_args(args: EngineArgs) -> Result { let engine = Self::py_from_engine_args(args)?; Ok(Self { engine }) } + pub fn add_request( + &self, + prompt: &[u32], + sampling_params: &ValidParameters, + stopping_params: &ValidStoppingParameters, + ) -> Result { + let request_id = Uuid::new_v4(); + let sampling_params = SamplingParams { + sampling_params, + stopping_params, + }; + self.py_add_request(&request_id.to_string(), prompt, sampling_params)?; + + info!("Submitted new request: {request_id}"); + Ok(request_id) + } + pub fn step(&mut self) {} } diff --git a/backends/vllm/src/errors.rs b/backends/vllm/src/errors.rs index aa008190..1b03f5a4 100644 --- a/backends/vllm/src/errors.rs +++ b/backends/vllm/src/errors.rs @@ -1,4 +1,5 @@ use pyo3::PyErr; +use text_generation_router::infer::InferError; use text_generation_router::server::WebServerError; use thiserror::Error; @@ -22,3 +23,9 @@ impl From for VllmBackendError { Self::WebServer(value) } } + +impl From for InferError { + fn from(value: VllmBackendError) -> Self { + InferError::GenerationError(value.to_string()) + } +} diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs index 37d5eb25..12c910df 100644 --- a/backends/vllm/src/lib.rs +++ b/backends/vllm/src/lib.rs @@ -5,3 +5,36 @@ mod errors; pub use backend::VllmBackend; pub use engine::{EngineArgs, LlmEngine}; pub use errors::VllmBackendError; +use pyo3::prelude::PyAnyMethods; +use pyo3::sync::GILOnceCell; +use pyo3::types::PyModule; +use pyo3::{Py, PyAny, PyErr, PyObject, Python}; + +static PY_TOKENS_PROMPT_CLASS: GILOnceCell> = 
GILOnceCell::new(); +static PY_SAMPLING_PARAMS_CLASS: GILOnceCell> = GILOnceCell::new(); + +#[inline] +pub(crate) fn tokens_prompt(py: Python) -> &Py { + PY_TOKENS_PROMPT_CLASS.get_or_init(py, || { + PyModule::import_bound(py, "vllm.inputs") + .expect("Failed to import vllm.inputs") + .getattr("TokensPrompt") + .expect("Failed to import vllm.inputs.TokensPrompt") + .unbind() + }) +} + +#[inline] +pub(crate) fn sampling_params(py: Python) -> &Py { + PY_SAMPLING_PARAMS_CLASS.get_or_init(py, || { + PyModule::import_bound(py, "vllm") + .expect("Failed to import vllm") + .getattr("SamplingParams") + .expect("Failed to import vllm.SamplingParams") + .unbind() + }) +} + +pub(crate) trait TryToPyObject { + fn try_to_object(&self, py: Python<'_>) -> Result; +} diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs index 20b7efc9..55f47871 100644 --- a/backends/vllm/src/main.rs +++ b/backends/vllm/src/main.rs @@ -73,6 +73,8 @@ impl Into for &Args { #[tokio::main] async fn main() -> Result<(), VllmBackendError> { + tracing_subscriber::fmt::init(); + let args = Args::parse(); let backend = VllmBackend::from_engine_args((&args).into())?; From dc5addae813570bc694308cf42a37b885bf89d94 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Mon, 27 Jan 2025 22:43:16 +0100 Subject: [PATCH 5/9] backend(vllm): remove python print stmt --- backends/vllm/src/engine.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs index d4f4f5dc..b7234f29 100644 --- a/backends/vllm/src/engine.rs +++ b/backends/vllm/src/engine.rs @@ -120,12 +120,6 @@ impl LlmEngine { let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?; let py_sampling_params = sampling_params.try_to_object(py)?; - let _ = py.eval_bound( - "print(type(params), params)", - Some(&[("params", &py_sampling_params)].into_py_dict_bound(py)), - None, - ); - self.engine.call_method1( py, "add_request", From 7028f5bce2b025b990516c7fee528a6bf798dc24 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 29 Jan 2025 17:01:20 +0100 Subject: [PATCH 6/9] backend(vllm): make v1 the default --- backends/vllm/src/engine.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs index b7234f29..53e36e14 100644 --- a/backends/vllm/src/engine.rs +++ b/backends/vllm/src/engine.rs @@ -29,9 +29,9 @@ impl IntoPyDict for EngineArgs { ), ], ) - .as_any(), + .as_any(), ) - .expect("Failed to create Python Dict from EngineArgs") + .expect("Failed to create Python Dict from EngineArgs") } } @@ -99,7 +99,7 @@ impl LlmEngine { // Next create the LLMEngine from the EngineArgs // from vllm.engine.llm_engine import LLMEngine // engine = LLMEngine.from_engine_args(engine_args) - let py_engine_llm_mod = PyModule::import_bound(py, "vllm.engine.llm_engine")?; + let py_engine_llm_mod = PyModule::import_bound(py, "vllm.v1.engine.llm_engine")?; let py_engine_llm_class = py_engine_llm_mod.getattr("LLMEngine")?; py_engine_llm_class .call_method("from_engine_args", (py_engine_args,), None)? 
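
Patch 6 above switches the backend to vLLM's v1 engine by resolving LLMEngine from `vllm.v1.engine.llm_engine` rather than `vllm.engine.llm_engine`. If the legacy engine ever needs to stay reachable, the module path could be picked at runtime instead of being hard-coded; the Rust sketch below illustrates one way to do that with the same PyO3 calls already used in the patch. The `TGI_VLLM_USE_V1` environment variable is an assumption made for this example only and does not appear anywhere in the patches.

use pyo3::prelude::*;

// Sketch only: choose the vLLM engine module at runtime so the legacy engine
// remains selectable. `TGI_VLLM_USE_V1` is a hypothetical variable name.
fn llm_engine_class(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
    let use_v1 = std::env::var("TGI_VLLM_USE_V1").map_or(true, |v| v != "0");
    let module = if use_v1 {
        "vllm.v1.engine.llm_engine"
    } else {
        "vllm.engine.llm_engine"
    };
    PyModule::import_bound(py, module)?.getattr("LLMEngine")
}
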
From 32dffcff60ad4a5a75cb5ed356a077d0d7a7e2c3 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Thu, 30 Jan 2025 13:35:21 +0100 Subject: [PATCH 7/9] backend(vllm): expose FFI for CompletionOutput and RequestOutput on Rust side --- Cargo.lock | 1 + backends/vllm/Cargo.toml | 1 + backends/vllm/src/backend.rs | 80 +++++++++++++++++++----------- backends/vllm/src/engine.rs | 94 +++++++++++++++++++++++++++++------- 4 files changed, 131 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eb922162..70ecf1f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4449,6 +4449,7 @@ version = "3.0.2-dev0" dependencies = [ "async-trait", "clap 4.5.21", + "crossbeam-channel", "log", "pyo3", "text-generation-router", diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml index 2308a655..c77f4562 100644 --- a/backends/vllm/Cargo.toml +++ b/backends/vllm/Cargo.toml @@ -8,6 +8,7 @@ homepage.workspace = true [dependencies] async-trait = "0.1.83" clap = { version = "4.5.21", features = ["derive"] } +crossbeam-channel = "0.5" pyo3 = { workspace = true } text-generation-router = { path = "../../router" } thiserror = "2.0" diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs index 0ccf8063..46419279 100644 --- a/backends/vllm/src/backend.rs +++ b/backends/vllm/src/backend.rs @@ -1,38 +1,42 @@ use crate::errors::VllmBackendError; use crate::{EngineArgs, LlmEngine}; use async_trait::async_trait; -use std::collections::HashMap; +use crossbeam_channel::internal::SelectHandle; +use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender}; +use std::collections::{HashMap, HashSet}; +use std::hint::spin_loop; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::thread::{spawn, JoinHandle}; +use std::time::Duration; use text_generation_router::infer::{Backend, InferError, InferStreamResponse}; use text_generation_router::validation::{ ValidGenerateRequest, ValidParameters, ValidStoppingParameters, }; +use text_generation_router::Token; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{debug, info, warn}; +use tracing::{debug, error, info, warn}; type InferResult = Result; -struct Request { +struct VllmRequestContext { tokens: Arc>, params: ValidParameters, stopping_params: ValidStoppingParameters, - streamer: UnboundedSender, + stream: UnboundedSender, } pub struct VllmBackend { - looper: JoinHandle<()>, - waiting_requests: UnboundedSender, + waiting_requests: Sender, } impl VllmBackend { pub fn from_engine_args(args: EngineArgs) -> Result { let engine = LlmEngine::from_engine_args(args)?; - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = unbounded(); let looper = spawn(|| engine_background_loop(engine, receiver)); Ok(Self { - looper, waiting_requests: sender, }) } @@ -48,12 +52,12 @@ impl Backend for VllmBackend { // Send the query to the vLLM Engine if let Some(input_ids) = request.input_ids { - debug!("Attempt to queue new request"); - if let Err(err) = self.waiting_requests.send(Request { + debug!("Queuing new request"); + if let Err(err) = self.waiting_requests.send(VllmRequestContext { tokens: Arc::clone(&input_ids), params: request.parameters, stopping_params: request.stopping_parameters, - streamer: sender, + stream: sender, }) { warn!("Waiting Requests queue has been closed: {err}") } @@ -67,35 +71,55 @@ impl Backend for VllmBackend { } } -fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: 
UnboundedReceiver) { +fn engine_background_loop( + mut engine: LlmEngine, + mut waiting_requests: Receiver, +) { info!("Starting vLLM engine background loop"); - + static DURATION_100_MS: Duration = Duration::from_millis(100); let mut in_flight_requests = HashMap::with_capacity(256); - loop { + 'outer: loop { if !waiting_requests.is_empty() { - let num_waiting_requests = waiting_requests.len(); - debug!( - "Adding {} requests to the vLLM engine", - num_waiting_requests - ); - - let mut requests = Vec::with_capacity(num_waiting_requests); - waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests); - - for request in requests { - match engine.add_request(&request.tokens, &request.params, &request.stopping_params) - { + match waiting_requests.recv_timeout(DURATION_100_MS) { + Ok(context) => match engine.add_request( + &context.tokens, + &context.params, + &context.stopping_params, + ) { Ok(request_id) => { debug!("Successfully scheduled request {request_id}"); - in_flight_requests.insert(request_id.to_string(), request); + in_flight_requests.insert(request_id.to_string(), context); } Err(err) => { warn!("Failed to schedule new request: {err}"); } + }, + Err(err) => match err { + RecvTimeoutError::Disconnected => break 'outer, + _ => {} // timeout all fine + }, + } + } + + if !in_flight_requests.is_empty() { + match engine.step() { + Ok(outputs) => outputs.iter().for_each(|output| { + let ctx = &in_flight_requests[&output.request_id]; + + // We only need to check on Err meaning the channel is not open anymore, so abort the request + if let Err(_) = ctx.stream.send(InferResult {}) { + debug!("Request {}'s channel dropped, aborting", &output.request_id); + in_flight_requests.remove(&output.request_id); + engine.abort_request(&output.request_id); + } + }), + Err(err) => { + error!("LLMEngine::step got an error: {err}"); } } } - engine.step(); + + spin_loop(); } info!("Shutting down vLLM engine background loop"); diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs index 53e36e14..f3b6a761 100644 --- a/backends/vllm/src/engine.rs +++ b/backends/vllm/src/engine.rs @@ -1,9 +1,10 @@ use crate::errors::VllmBackendError; use crate::{sampling_params, tokens_prompt, TryToPyObject}; +use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyDict, PyList, PyString}; use text_generation_router::validation::{ValidParameters, ValidStoppingParameters}; -use tracing::info; +use tracing::{info, instrument}; use uuid::Uuid; pub struct EngineArgs { @@ -29,9 +30,9 @@ impl IntoPyDict for EngineArgs { ), ], ) - .as_any(), + .as_any(), ) - .expect("Failed to create Python Dict from EngineArgs") + .expect("Failed to create Python Dict from EngineArgs") } } @@ -47,29 +48,32 @@ impl TryToPyObject for SamplingParams<'_> { let kwargs = PyDict::from_sequence_bound(&PyList::new_bound( py, [ - ("seed", self.sampling_params.seed.into_py(py)), - ("n", 1.into_py(py)), - ("top_k", self.sampling_params.top_k.into_py(py)), - ("top_p", self.sampling_params.top_p.into_py(py)), - ("temperature", self.sampling_params.temperature.into_py(py)), + (intern!(py, "seed"), self.sampling_params.seed.into_py(py)), + (intern!(py, "n"), 1.into_py(py)), + (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)), + (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)), ( - "frequency_penalty", + intern!(py, "temperature"), + self.sampling_params.temperature.into_py(py), + ), + ( + intern!(py, "frequency_penalty"), self.sampling_params.frequency_penalty.into_py(py), ), ( - 
"repetition_penalty", + intern!(py, "repetition_penalty"), self.sampling_params.repetition_penalty.into_py(py), ), ( - "ignore_eos", + intern!(py, "ignore_eos"), self.stopping_params.ignore_eos_token.into_py(py), ), ( - "max_tokens", + intern!(py, "max_tokens"), self.stopping_params.max_new_tokens.into_py(py), ), ( - "stop", + intern!(py, "stop"), PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(), ), ], @@ -81,6 +85,47 @@ impl TryToPyObject for SamplingParams<'_> { } } +#[derive(Debug)] +pub struct CompletionOutput { + pub index: usize, + pub text: String, // TODO: SmallString? + pub token_ids: Vec, // TODO: TinyVec? + pub logprobs: Option>, // TODO: TinyVec? + pub finish_reason: Option, // lora_request: LATER +} + +#[derive(Debug)] +pub struct RequestOutput { + pub request_id: String, + pub outputs: Vec, + pub finished: bool, + // metrics: Vec // TODO +} + +impl<'py> FromPyObject<'py> for CompletionOutput { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + Ok(Self { + index: ob.getattr(intern!(py, "index"))?.extract()?, + text: ob.getattr(intern!(py, "text"))?.extract()?, + token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?, + logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?, + finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?, + }) + } +} + +impl<'py> FromPyObject<'py> for RequestOutput { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + Ok(Self { + request_id: ob.getattr(intern!(py, "request_id"))?.extract()?, + outputs: ob.getattr(intern!(py, "outputs"))?.extract()?, + finished: ob.getattr(intern!(py, "finished"))?.extract()?, + }) + } +} + pub struct LlmEngine { engine: PyObject, } @@ -115,14 +160,14 @@ impl LlmEngine { ) -> Result<(), VllmBackendError> { Python::with_gil(|py| { // Create vllm.Tokens - let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py); + let kwargs = [(intern!(py, "prompt_token_ids"), prompt)].into_py_dict_bound(py); let py_tokens_prompt_class = tokens_prompt(py); let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?; let py_sampling_params = sampling_params.try_to_object(py)?; self.engine.call_method1( py, - "add_request", + intern!(py, "add_request"), ( PyString::new_bound(py, request_id), py_tokens_prompt, @@ -130,18 +175,27 @@ impl LlmEngine { ), )?; - self.engine.call_method0(py, "step") + self.engine.call_method0(py, intern!(py, "step")) })?; Ok(()) } + fn py_step(&self) -> Result, VllmBackendError> { + Ok(Python::with_gil(|py| { + self.engine + .call_method0(py, intern!(py, "step"))? + .extract::>(py) + })?) 
+ } + pub fn from_engine_args(args: EngineArgs) -> Result { let engine = Self::py_from_engine_args(args)?; Ok(Self { engine }) } + #[instrument(skip_all)] pub fn add_request( &self, prompt: &[u32], @@ -159,5 +213,11 @@ impl LlmEngine { Ok(request_id) } - pub fn step(&mut self) {} + #[instrument(skip_all)] + pub fn abort_request(&self, _request_id: &str) {} + + #[instrument(skip_all)] + pub fn step(&mut self) -> Result, VllmBackendError> { + self.py_step() + } } From 003163a2b9a0e7e30d20bcc96a71955be850649b Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Thu, 30 Jan 2025 16:12:52 +0100 Subject: [PATCH 8/9] backend(vllm): map ResultOutput to InferStreamResponse to stream back to the client --- backends/vllm/src/backend.rs | 103 ++++++++++++++++++++++++++++------- backends/vllm/src/engine.rs | 46 ++++++++++++---- backends/vllm/src/lib.rs | 3 + 3 files changed, 123 insertions(+), 29 deletions(-) diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs index 46419279..15092d31 100644 --- a/backends/vllm/src/backend.rs +++ b/backends/vllm/src/backend.rs @@ -1,25 +1,81 @@ +use crate::engine::RequestOutput; use crate::errors::VllmBackendError; -use crate::{EngineArgs, LlmEngine}; +use crate::{EngineArgs, LlmEngine, STARTUP_INSTANT}; use async_trait::async_trait; -use crossbeam_channel::internal::SelectHandle; use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::hint::spin_loop; -use std::sync::atomic::AtomicBool; use std::sync::Arc; -use std::thread::{spawn, JoinHandle}; -use std::time::Duration; -use text_generation_router::infer::{Backend, InferError, InferStreamResponse}; +use std::thread::spawn; +use std::time::{Duration, Instant as StdInstant, UNIX_EPOCH}; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::{ ValidGenerateRequest, ValidParameters, ValidStoppingParameters, }; -use text_generation_router::Token; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use text_generation_router::{FinishReason, Token}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; type InferResult = Result; +impl TryFrom<&RequestOutput> for InferStreamResponse { + type Error = InferError; + + fn try_from(output: &RequestOutput) -> Result { + if let Some(last) = output.outputs.last() { + if let Some(token_id) = last.token_ids.last() { + let token = Token { + id: *token_id, + text: last.text.clone(), + // logprob: last.logprobs[0], + logprob: 0.0f32, + special: false, + }; + + if !output.finished { + Ok(InferStreamResponse::Intermediate { + token, + top_tokens: vec![], + }) + } else { + // TODO: Let's see how to request metrics + // let metrics = output + // .metrics + // .last() + // .expect("metrics should be set if token was unpacked"); + // + // debug!("Request: {} -> {metrics:?}", &output.request_id); + Ok(InferStreamResponse::End { + token, + top_tokens: vec![], + generated_text: GeneratedText { + text: last.text.clone(), + generated_tokens: last.token_ids.len() as u32, + finish_reason: last + .finish_reason + .as_ref() + .map(|reason| match reason.as_str() { + "length" => FinishReason::Length, + _ => FinishReason::StopSequence, + }) + .unwrap(), + seed: None, + }, + start: Instant::now(), + queued: Instant::now(), + }) + } + } else { + 
Err(InferError::GenerationError("No token returned".to_string())) + } + } else { + Err(InferError::GenerationError("No token returned".to_string())) + } + } +} + struct VllmRequestContext { tokens: Arc>, params: ValidParameters, @@ -35,7 +91,7 @@ impl VllmBackend { pub fn from_engine_args(args: EngineArgs) -> Result { let engine = LlmEngine::from_engine_args(args)?; let (sender, receiver) = unbounded(); - let looper = spawn(|| engine_background_loop(engine, receiver)); + let _ = spawn(|| engine_background_loop(engine, receiver)); Ok(Self { waiting_requests: sender, }) @@ -71,10 +127,7 @@ impl Backend for VllmBackend { } } -fn engine_background_loop( - mut engine: LlmEngine, - mut waiting_requests: Receiver, -) { +fn engine_background_loop(mut engine: LlmEngine, waiting_requests: Receiver) { info!("Starting vLLM engine background loop"); static DURATION_100_MS: Duration = Duration::from_millis(100); let mut in_flight_requests = HashMap::with_capacity(256); @@ -101,20 +154,32 @@ fn engine_background_loop( } } + // If there are tracked requests, let's pick the intermediate results if !in_flight_requests.is_empty() { match engine.step() { Ok(outputs) => outputs.iter().for_each(|output| { - let ctx = &in_flight_requests[&output.request_id]; + // Retrieve the context + { + let ctx = &in_flight_requests[&output.request_id]; + let result = InferStreamResponse::try_from(output); - // We only need to check on Err meaning the channel is not open anymore, so abort the request - if let Err(_) = ctx.stream.send(InferResult {}) { - debug!("Request {}'s channel dropped, aborting", &output.request_id); + // We only need to check on Err meaning the channel is not open anymore, so abort the request + if let Err(_) = ctx.stream.send(result) { + debug!("Request {}'s channel dropped, aborting", &output.request_id); + in_flight_requests.remove(&output.request_id); + engine.abort_request(&output.request_id); + } + } + + // Drop the request if done + if output.finished { in_flight_requests.remove(&output.request_id); - engine.abort_request(&output.request_id); } }), Err(err) => { error!("LLMEngine::step got an error: {err}"); + // TODO: Shall we exit from here? 
We can't link this to any particular user, + // it's Rust <> Python FFI which failed } } } diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs index f3b6a761..dcbff82f 100644 --- a/backends/vllm/src/engine.rs +++ b/backends/vllm/src/engine.rs @@ -2,6 +2,7 @@ use crate::errors::VllmBackendError; use crate::{sampling_params, tokens_prompt, TryToPyObject}; use pyo3::intern; use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; use pyo3::types::{IntoPyDict, PyDict, PyList, PyString}; use text_generation_router::validation::{ValidParameters, ValidStoppingParameters}; use tracing::{info, instrument}; @@ -36,6 +37,8 @@ impl IntoPyDict for EngineArgs { } } +static FINAL_OUTPUT_ONLY: GILOnceCell = GILOnceCell::new(); + pub struct SamplingParams<'a> { sampling_params: &'a ValidParameters, stopping_params: &'a ValidStoppingParameters, @@ -48,8 +51,10 @@ impl TryToPyObject for SamplingParams<'_> { let kwargs = PyDict::from_sequence_bound(&PyList::new_bound( py, [ - (intern!(py, "seed"), self.sampling_params.seed.into_py(py)), + (intern!(py, "output_kind"), 2.into_py(py)), + (intern!(py, "logprobs"), 1.into_py(py)), (intern!(py, "n"), 1.into_py(py)), + (intern!(py, "seed"), self.sampling_params.seed.into_py(py)), (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)), (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)), ( @@ -86,20 +91,40 @@ impl TryToPyObject for SamplingParams<'_> { } #[derive(Debug)] -pub struct CompletionOutput { - pub index: usize, - pub text: String, // TODO: SmallString? - pub token_ids: Vec, // TODO: TinyVec? - pub logprobs: Option>, // TODO: TinyVec? +pub(crate) struct CompletionOutput { + pub token_ids: Vec, // TODO: TinyVec? + pub text: String, // TODO: SmallString? + // pub logprobs: Vec, // TODO: TinyVec? 
pub finish_reason: Option, // lora_request: LATER + pub index: usize, +} + +#[derive(Debug, Copy, Clone)] +pub(crate) struct RequestMetrics { + pub arrival_time: f32, + pub first_scheduled_time: f32, + pub first_token_time: f32, + pub time_in_queue: f32, +} + +impl<'py> FromPyObject<'py> for RequestMetrics { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + Ok(Self { + arrival_time: ob.getattr(intern!(py, "arrival_time"))?.extract()?, + first_scheduled_time: ob.getattr(intern!(py, "first_scheduled_time"))?.extract()?, + first_token_time: ob.getattr(intern!(py, "first_token_time"))?.extract()?, + time_in_queue: ob.getattr(intern!(py, "time_in_queue"))?.extract()?, + }) + } } #[derive(Debug)] -pub struct RequestOutput { - pub request_id: String, +pub(crate) struct RequestOutput { pub outputs: Vec, + // pub metrics: Vec, + pub request_id: String, pub finished: bool, - // metrics: Vec // TODO } impl<'py> FromPyObject<'py> for CompletionOutput { @@ -109,7 +134,7 @@ impl<'py> FromPyObject<'py> for CompletionOutput { index: ob.getattr(intern!(py, "index"))?.extract()?, text: ob.getattr(intern!(py, "text"))?.extract()?, token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?, - logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?, + // logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?, finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?, }) } @@ -122,6 +147,7 @@ impl<'py> FromPyObject<'py> for RequestOutput { request_id: ob.getattr(intern!(py, "request_id"))?.extract()?, outputs: ob.getattr(intern!(py, "outputs"))?.extract()?, finished: ob.getattr(intern!(py, "finished"))?.extract()?, + // metrics: ob.getattr(intern!(py, "metrics"))?.extract()?, }) } } diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs index 12c910df..4bd4f434 100644 --- a/backends/vllm/src/lib.rs +++ b/backends/vllm/src/lib.rs @@ -9,6 +9,9 @@ use pyo3::prelude::PyAnyMethods; use pyo3::sync::GILOnceCell; use pyo3::types::PyModule; use pyo3::{Py, PyAny, PyErr, PyObject, Python}; +use tokio::time::Instant; + +pub(crate) const STARTUP_INSTANT: Instant = Instant::now(); static PY_TOKENS_PROMPT_CLASS: GILOnceCell> = GILOnceCell::new(); static PY_SAMPLING_PARAMS_CLASS: GILOnceCell> = GILOnceCell::new(); From 5452c1294c9614d3e60fbe66e5337bbdf1cbaa9c Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Fri, 31 Jan 2025 10:56:54 +0100 Subject: [PATCH 9/9] backend(vllm): disable metrics for now --- backends/vllm/src/backend.rs | 6 ++++++ backends/vllm/src/lib.rs | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs index 15092d31..16d7d4ac 100644 --- a/backends/vllm/src/backend.rs +++ b/backends/vllm/src/backend.rs @@ -63,6 +63,12 @@ impl TryFrom<&RequestOutput> for InferStreamResponse { .unwrap(), seed: None, }, + // start: STARTUP_INSTANT + // .checked_sub(Duration::from_secs_f32(metrics.first_scheduled_time)) + // .unwrap_or_else(Instant::now), + // queued: STARTUP_INSTANT + // .checked_sub(Duration::from_secs_f32(metrics.arrival_time)) + // .unwrap_or_else(Instant::now), start: Instant::now(), queued: Instant::now(), }) diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs index 4bd4f434..d0b44565 100644 --- a/backends/vllm/src/lib.rs +++ b/backends/vllm/src/lib.rs @@ -11,7 +11,7 @@ use pyo3::types::PyModule; use pyo3::{Py, PyAny, PyErr, PyObject, Python}; use tokio::time::Instant; -pub(crate) const STARTUP_INSTANT: Instant = Instant::now(); +pub(crate) static STARTUP_INSTANT: 
Instant = Instant::now(); static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new(); static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
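
A note on the final state of backends/vllm/src/lib.rs: `STARTUP_INSTANT` is a `static` (formerly a `const`) initialized with `Instant::now()`, but `Instant::now()` is not a `const fn`, so neither declaration compiles as written. A lazily initialized static preserves the intent (an anchor instant for converting vLLM's timing metrics once they are re-enabled). The sketch below is illustrative only and is not part of the patch series; it assumes Rust 1.80+ for `std::sync::LazyLock` (`once_cell::sync::Lazy` would be the older equivalent).

use std::sync::LazyLock;
use tokio::time::Instant;

// Sketch only: evaluate the start instant on first access instead of at
// compile time, since `Instant::now()` cannot initialize a plain `static`.
pub(crate) static STARTUP_INSTANT: LazyLock<Instant> = LazyLock::new(Instant::now);

Call sites can then use it through auto-deref (e.g. `STARTUP_INSTANT.checked_sub(...)`) or take a copy with `*STARTUP_INSTANT`.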