From bd2ec03d532c90c78c3cd242321cc10e9253d3de Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Wed, 22 Jan 2025 22:15:33 +0100
Subject: [PATCH] backend(vllm): statically allocate LLMEngine

---
 Cargo.lock                   | 84 ++++++++++++++++++++++-----------
 Cargo.toml                   |  2 +-
 backends/vllm/Cargo.toml     |  8 +++-
 backends/vllm/src/backend.rs | 32 +++++++++++++
 backends/vllm/src/engine.rs  | 90 ++++++++++++++++++++++++++++++++++++
 backends/vllm/src/errors.rs  | 14 ++++++
 backends/vllm/src/lib.rs     |  6 +++
 backends/vllm/src/main.rs    | 19 ++++++--
 router/Cargo.toml            | 18 ++++----
 9 files changed, 229 insertions(+), 44 deletions(-)
 create mode 100644 backends/vllm/src/backend.rs
 create mode 100644 backends/vllm/src/engine.rs
 create mode 100644 backends/vllm/src/errors.rs
 create mode 100644 backends/vllm/src/lib.rs

diff --git a/Cargo.lock b/Cargo.lock
index 0059976b..65cb0561 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -579,7 +579,7 @@ dependencies = [
  "semver",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
 ]
 
 [[package]]
@@ -1601,7 +1601,7 @@ dependencies = [
  "reqwest 0.11.27",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "ureq",
 ]
@@ -2034,7 +2034,7 @@ checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17"
 dependencies = [
  "opentelemetry 0.20.0",
  "opentelemetry-otlp",
- "thiserror",
+ "thiserror 1.0.69",
  "tracing",
  "tracing-opentelemetry 0.21.0",
 ]
@@ -2187,9 +2187,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
 
 [[package]]
 name = "libc"
-version = "0.2.164"
+version = "0.2.169"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
+checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
 
 [[package]]
 name = "libfuzzer-sys"
@@ -2370,7 +2370,7 @@ dependencies = [
  "metrics",
  "metrics-util",
  "quanta",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tracing",
 ]
@@ -2501,7 +2501,7 @@ dependencies = [
  "futures",
  "pin-project",
  "rand",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tokio-util",
  "tracing",
@@ -2553,7 +2553,7 @@ dependencies = [
  "rustls-pemfile",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tokio-retry",
  "tokio-util",
@@ -2857,7 +2857,7 @@ dependencies = [
  "js-sys",
  "once_cell",
  "pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
  "urlencoding",
 ]
 
@@ -2875,7 +2875,7 @@ dependencies = [
  "opentelemetry_api",
  "opentelemetry_sdk 0.20.0",
  "prost 0.11.9",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tonic 0.9.2",
 ]
@@ -2913,7 +2913,7 @@ dependencies = [
  "js-sys",
  "once_cell",
  "pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
  "urlencoding",
 ]
 
@@ -2935,7 +2935,7 @@ dependencies = [
  "rand",
  "regex",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tokio-stream",
 ]
@@ -2957,7 +2957,7 @@ dependencies = [
  "ordered-float 4.5.0",
  "percent-encoding",
  "rand",
- "thiserror",
+ "thiserror 1.0.69",
 ]
 
 [[package]]
@@ -3494,7 +3494,7 @@ dependencies = [
  "rand_chacha",
  "simd_helpers",
  "system-deps",
- "thiserror",
+ "thiserror 1.0.69",
  "v_frame",
  "wasm-bindgen",
 ]
@@ -3571,7 +3571,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
 dependencies = [
  "getrandom",
  "libredox",
- "thiserror",
+ "thiserror 1.0.69",
 ]
 
 [[package]]
@@ -4436,7 +4436,7 @@ dependencies = [
  "pkg-config",
  "pyo3",
  "text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tokio-stream",
@@ -4446,6 +4446,14 @@
 [[package]]
 name = "text-generation-backends-vllm"
 version = "3.0.2-dev0"
+dependencies = [
+ "async-trait",
+ "pyo3",
+ "text-generation-router",
+ "thiserror 2.0.11",
+ "tokio",
+ "tokio-stream",
+]
 
 [[package]]
 name = "text-generation-benchmark"
@@ -4460,7 +4468,7 @@ dependencies = [
  "serde_json",
  "tabled",
  "text-generation-client",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tracing",
@@ -4477,7 +4485,7 @@ dependencies = [
  "grpc-metadata",
  "prost 0.12.6",
  "prost-build",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tonic 0.10.2",
  "tonic-build",
@@ -4500,7 +4508,7 @@ dependencies = [
  "reqwest 0.11.27",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tracing",
  "tracing-subscriber",
  "vergen",
@@ -4542,7 +4550,7 @@ dependencies = [
  "serde",
  "serde_json",
  "sysinfo",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tokio-stream",
@@ -4591,7 +4599,7 @@ dependencies = [
  "serde_json",
  "slotmap",
  "text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tokio-stream",
@@ -4642,7 +4650,7 @@ dependencies = [
  "serde_json",
  "slotmap",
  "text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tokio-stream",
@@ -4672,7 +4680,16 @@ version = "1.0.69"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
 dependencies = [
- "thiserror-impl",
+ "thiserror-impl 1.0.69",
+]
+
+[[package]]
+name = "thiserror"
+version = "2.0.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
+dependencies = [
+ "thiserror-impl 2.0.11",
 ]
 
 [[package]]
@@ -4686,6 +4703,17 @@ dependencies = [
  "syn 2.0.89",
 ]
 
+[[package]]
+name = "thiserror-impl"
+version = "2.0.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.89",
+]
+
 [[package]]
 name = "thread_local"
 version = "1.1.8"
@@ -4787,7 +4815,7 @@ dependencies = [
  "serde",
  "serde_json",
  "spm_precompiled",
- "thiserror",
+ "thiserror 1.0.69",
  "unicode-normalization-alignments",
  "unicode-segmentation",
  "unicode_categories",
@@ -4795,9 +4823,9 @@
 
 [[package]]
 name = "tokio"
-version = "1.41.1"
+version = "1.43.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
+checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
 dependencies = [
  "backtrace",
  "bytes",
@@ -4823,9 +4851,9 @@
 
 [[package]]
 name = "tokio-macros"
-version = "2.4.0"
+version = "2.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
+checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/Cargo.toml b/Cargo.toml
index 4183614e..170026a4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,7 +36,7 @@ metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 minijinja = { version = "2.2.0", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
-pyo3 = { version = "0.23", features = ["auto-initialize"] }
+pyo3 = { version = "0.22", features = ["auto-initialize"] }
 
 [profile.release]
 incremental = true
diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml
index b738745e..bd1a21a1 100644
--- a/backends/vllm/Cargo.toml
+++ b/backends/vllm/Cargo.toml
@@ -6,5 +6,9 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
-pyo3 = "0.23"
-pyo3-asyncio = "0.20"
\ No newline at end of file
+pyo3 = { workspace = true }
+text-generation-router = { path = "../../router" }
+thiserror = "2.0"
+tokio = { version = "1.43", features = ["full"] }
+tokio-stream = "0.1"
+async-trait = "0.1.83"
diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs
new file mode 100644
index 00000000..6d49c268
--- /dev/null
+++ b/backends/vllm/src/backend.rs
@@ -0,0 +1,32 @@
+use crate::errors::VllmBackendError;
+use crate::{EngineArgs, LlmEngine};
+use async_trait::async_trait;
+use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::validation::ValidGenerateRequest;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+
+pub struct VllmBackend {
+    engine: LlmEngine,
+}
+
+impl VllmBackend {
+    pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
+        Ok(Self {
+            engine: LlmEngine::from_engine_args(args)?,
+        })
+    }
+}
+
+#[async_trait]
+impl Backend for VllmBackend {
+    fn schedule(
+        &self,
+        _request: ValidGenerateRequest,
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        todo!()
+    }
+
+    async fn health(&self, _current_health: bool) -> bool {
+        true
+    }
+}
diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs
new file mode 100644
index 00000000..1debe4c5
--- /dev/null
+++ b/backends/vllm/src/engine.rs
@@ -0,0 +1,90 @@
+use crate::errors::VllmBackendError;
+use pyo3::prelude::*;
+use pyo3::types::{IntoPyDict, PyDict, PyList};
+
+pub struct EngineArgs {
+    pub model: String,
+    pub pipeline_parallel_size: u32,
+    pub tensor_parallel_size: u32,
+}
+
+impl IntoPyDict for EngineArgs {
+    fn into_py_dict_bound(self, py: Python<'_>) -> Bound<'_, PyDict> {
+        PyDict::from_sequence_bound(
+            PyList::new_bound(
+                py,
+                [
+                    ("model", self.model.into_py(py)),
+                    (
+                        "pipeline_parallel_size",
+                        self.pipeline_parallel_size.into_py(py),
+                    ),
+                    (
+                        "tensor_parallel_size",
+                        self.tensor_parallel_size.into_py(py),
+                    ),
+                ],
+            )
+            .as_any(),
+        )
+        .expect("Failed to create Python Dict from EngineArgs")
+    }
+}
+
+// impl IntoPy<PyObject> for EngineArgs {
+//     fn into_py(self, py: Python<'_>) -> PyObject {
+//         PyDict::from_sequence_bound(
+//             PyList::new_bound(
+//                 py,
+//                 [
+//                     ("model", self.model.into_py(py)),
+//                     (
+//                         "pipeline_parallel_size",
+//                         self.pipeline_parallel_size.into_py(py),
+//                     ),
+//                     (
+//                         "tensor_parallel_size",
+//                         self.tensor_parallel_size.into_py(py),
+//                     ),
+//                 ],
+//             )
+//             .as_any(),
+//         )
+//         .expect("Failed to create Python Dict from EngineArgs")
+//     }
+// }
+
+pub struct LlmEngine {
+    engine: PyObject,
+}
+
+impl LlmEngine {
+    fn py_from_engine_args(args: EngineArgs) -> PyResult<PyObject> {
+        Python::with_gil(|py| {
+            // Create the EngineArgs from Rust
+            // from vllm.engine.arg_utils import EngineArgs
+            // engine_args = EngineArgs(**args)
+            let py_engine_args_mod = PyModule::import_bound(py, "vllm.engine.arg_utils")?;
+            let py_engine_args_class = py_engine_args_mod.getattr("EngineArgs")?;
+            let py_engine_args =
+                py_engine_args_class.call((), Some(&args.into_py_dict_bound(py)))?;
+
+            // Next create the LLMEngine from the EngineArgs
+            // from vllm.engine.llm_engine import LLMEngine
+            // engine = LLMEngine.from_engine_args(engine_args)
+            let py_engine_llm_mod = PyModule::import_bound(py, "vllm.engine.llm_engine")?;
+            let py_engine_llm_class = py_engine_llm_mod.getattr("LLMEngine")?;
+            py_engine_llm_class
+                .call_method("from_engine_args", (py_engine_args,), None)?
+                .extract()
+        })
+    }
+
+    pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
+        let engine = Self::py_from_engine_args(args)?;
+
+        Ok(Self { engine })
+    }
+
+    pub fn step(&mut self) {}
+}
diff --git a/backends/vllm/src/errors.rs b/backends/vllm/src/errors.rs
new file mode 100644
index 00000000..fa7d4414
--- /dev/null
+++ b/backends/vllm/src/errors.rs
@@ -0,0 +1,14 @@
+use pyo3::PyErr;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum VllmBackendError {
+    #[error("{0}")]
+    Python(PyErr),
+}
+
+impl From<PyErr> for VllmBackendError {
+    fn from(value: PyErr) -> Self {
+        Self::Python(value)
+    }
+}
diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs
new file mode 100644
index 00000000..5ab448fd
--- /dev/null
+++ b/backends/vllm/src/lib.rs
@@ -0,0 +1,6 @@
+mod backend;
+mod engine;
+mod errors;
+
+pub use backend::VllmBackend;
+pub use engine::{EngineArgs, LlmEngine};
diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs
index fd54d8b1..66597247 100644
--- a/backends/vllm/src/main.rs
+++ b/backends/vllm/src/main.rs
@@ -1,6 +1,17 @@
-use pyo3::prelude::*;
+use text_generation_backends_vllm::{EngineArgs, LlmEngine};
 
-#[pyo3_asyncio::tokio::main(flavor = "multi_thread")]
-async fn main() {
-    println!("Hello, world!");
+#[tokio::main]
+async fn main() -> Result<(), ()> {
+    let args = EngineArgs {
+        model: String::from("meta-llama/Llama-3.2-1B-Instruct"),
+        pipeline_parallel_size: 1,
+        tensor_parallel_size: 1,
+    };
+
+    match LlmEngine::from_engine_args(args) {
+        Ok(_) => println!("Engine successfully allocated"),
+        Err(err) => println!("Got an error: {}", err),
+    }
+
+    Ok(())
 }
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 2e621dfc..81c58616 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -31,11 +31,11 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
-    "rt",
-    "rt-multi-thread",
-    "parking_lot",
-    "signal",
-    "sync",
+  "rt",
+  "rt-multi-thread",
+  "parking_lot",
+  "signal",
+  "sync",
 ] }
 tokio-stream = "0.1.14"
 tower-http = { version = "0.5.1", features = ["cors"] }
@@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
-    "opentelemetry-otlp",
+  "opentelemetry-otlp",
 ] }
 minijinja = { workspace = true }
 minijinja-contrib = { workspace = true }
@@ -57,9 +57,9 @@ image = "0.25.1"
 base64 = { workspace = true }
 sysinfo = "0.30.13"
 uuid = { version = "1.9.1", default-features = false, features = [
-    "v4",
-    "fast-rng",
-    "macro-diagnostics",
+  "v4",
+  "fast-rng",
+  "macro-diagnostics",
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
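
Note (added for illustration, not part of the patch): LlmEngine::step() in engine.rs is left as an empty stub in this change. The sketch below shows one way the same Python::with_gil pattern used in py_from_engine_args could later drive the engine loop through the stored PyObject. The try_step name, the Vec of request ids it returns, and the RequestOutput/request_id handling are assumptions about vLLM's Python API made only for this sketch.

// Hypothetical follow-up sketch; it reuses the imports already present in
// backends/vllm/src/engine.rs (pyo3::prelude::* and crate::errors::VllmBackendError).
impl LlmEngine {
    pub fn try_step(&mut self) -> Result<Vec<String>, VllmBackendError> {
        Ok(Python::with_gil(|py| -> PyResult<Vec<String>> {
            // outputs = engine.step()
            // (assumes step() returns a list of RequestOutput objects)
            let outputs = self.engine.call_method0(py, "step")?.into_bound(py);

            // Collect the request ids of the sequences touched by this step.
            let mut request_ids = Vec::new();
            for output in outputs.iter()? {
                request_ids.push(output?.getattr("request_id")?.extract::<String>()?);
            }
            Ok(request_ids)
        })?)
    }
}

Any PyErr raised by the call is converted into VllmBackendError through the From<PyErr> impl added in errors.rs, mirroring how py_from_engine_args already surfaces Python failures.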