backend(vllm): statically allocate LLMEngine

This commit is contained in:
Morgan Funtowicz 2025-01-22 22:15:33 +01:00
parent cfd22726c9
commit bd2ec03d53
9 changed files with 229 additions and 44 deletions

84
Cargo.lock generated
View File

@ -579,7 +579,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"thiserror",
"thiserror 1.0.69",
]
[[package]]
@ -1601,7 +1601,7 @@ dependencies = [
"reqwest 0.11.27",
"serde",
"serde_json",
"thiserror",
"thiserror 1.0.69",
"tokio",
"ureq",
]
@ -2034,7 +2034,7 @@ checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17"
dependencies = [
"opentelemetry 0.20.0",
"opentelemetry-otlp",
"thiserror",
"thiserror 1.0.69",
"tracing",
"tracing-opentelemetry 0.21.0",
]
@ -2187,9 +2187,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
version = "0.2.164"
version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libfuzzer-sys"
@ -2370,7 +2370,7 @@ dependencies = [
"metrics",
"metrics-util",
"quanta",
"thiserror",
"thiserror 1.0.69",
"tokio",
"tracing",
]
@ -2501,7 +2501,7 @@ dependencies = [
"futures",
"pin-project",
"rand",
"thiserror",
"thiserror 1.0.69",
"tokio",
"tokio-util",
"tracing",
@ -2553,7 +2553,7 @@ dependencies = [
"rustls-pemfile",
"serde",
"serde_json",
"thiserror",
"thiserror 1.0.69",
"tokio",
"tokio-retry",
"tokio-util",
@ -2857,7 +2857,7 @@ dependencies = [
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
"thiserror 1.0.69",
"urlencoding",
]
@ -2875,7 +2875,7 @@ dependencies = [
"opentelemetry_api",
"opentelemetry_sdk 0.20.0",
"prost 0.11.9",
"thiserror",
"thiserror 1.0.69",
"tokio",
"tonic 0.9.2",
]
@ -2913,7 +2913,7 @@ dependencies = [
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
"thiserror 1.0.69",
"urlencoding",
]
@ -2935,7 +2935,7 @@ dependencies = [
"rand",
"regex",
"serde_json",
"thiserror",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
]
@ -2957,7 +2957,7 @@ dependencies = [
"ordered-float 4.5.0",
"percent-encoding",
"rand",
"thiserror",
"thiserror 1.0.69",
]
[[package]]
@ -3494,7 +3494,7 @@ dependencies = [
"rand_chacha",
"simd_helpers",
"system-deps",
"thiserror",
"thiserror 1.0.69",
"v_frame",
"wasm-bindgen",
]
@ -3571,7 +3571,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
"getrandom",
"libredox",
"thiserror",
"thiserror 1.0.69",
]
[[package]]
@ -4436,7 +4436,7 @@ dependencies = [
"pkg-config",
"pyo3",
"text-generation-router",
"thiserror",
"thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@ -4446,6 +4446,14 @@ dependencies = [
[[package]]
name = "text-generation-backends-vllm"
version = "3.0.2-dev0"
dependencies = [
"async-trait",
"pyo3",
"text-generation-router",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
]
[[package]]
name = "text-generation-benchmark"
@ -4460,7 +4468,7 @@ dependencies = [
"serde_json",
"tabled",
"text-generation-client",
"thiserror",
"thiserror 1.0.69",
"tokenizers",
"tokio",
"tracing",
@ -4477,7 +4485,7 @@ dependencies = [
"grpc-metadata",
"prost 0.12.6",
"prost-build",
"thiserror",
"thiserror 1.0.69",
"tokio",
"tonic 0.10.2",
"tonic-build",
@ -4500,7 +4508,7 @@ dependencies = [
"reqwest 0.11.27",
"serde",
"serde_json",
"thiserror",
"thiserror 1.0.69",
"tracing",
"tracing-subscriber",
"vergen",
@ -4542,7 +4550,7 @@ dependencies = [
"serde",
"serde_json",
"sysinfo",
"thiserror",
"thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@ -4591,7 +4599,7 @@ dependencies = [
"serde_json",
"slotmap",
"text-generation-router",
"thiserror",
"thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@ -4642,7 +4650,7 @@ dependencies = [
"serde_json",
"slotmap",
"text-generation-router",
"thiserror",
"thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@ -4672,7 +4680,16 @@ version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
dependencies = [
"thiserror-impl",
"thiserror-impl 1.0.69",
]
[[package]]
name = "thiserror"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
dependencies = [
"thiserror-impl 2.0.11",
]
[[package]]
@ -4686,6 +4703,17 @@ dependencies = [
"syn 2.0.89",
]
[[package]]
name = "thiserror-impl"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.89",
]
[[package]]
name = "thread_local"
version = "1.1.8"
@ -4787,7 +4815,7 @@ dependencies = [
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"thiserror 1.0.69",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
@ -4795,9 +4823,9 @@ dependencies = [
[[package]]
name = "tokio"
version = "1.41.1"
version = "1.43.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
dependencies = [
"backtrace",
"bytes",
@ -4823,9 +4851,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.4.0"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [
"proc-macro2",
"quote",

View File

@ -36,7 +36,7 @@ metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
pyo3 = { version = "0.23", features = ["auto-initialize"] }
pyo3 = { version = "0.22", features = ["auto-initialize"] }
[profile.release]
incremental = true

View File

@ -6,5 +6,9 @@ authors.workspace = true
homepage.workspace = true
[dependencies]
pyo3 = "0.23"
pyo3-asyncio = "0.20"
pyo3 = { workspace = true }
text-generation-router = { path = "../../router" }
thiserror = "2.0"
tokio = { version = "1.43", features = ["full"] }
tokio-stream = "0.1"
async-trait = "0.1.83"

View File

@ -0,0 +1,32 @@
use crate::errors::VllmBackendError;
use crate::{EngineArgs, LlmEngine};
use async_trait::async_trait;
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use tokio_stream::wrappers::UnboundedReceiverStream;
/// `text-generation-router` backend backed by an in-process vLLM `LLMEngine`.
pub struct VllmBackend {
    // Rust wrapper around the Python-side LLMEngine; allocated once at startup.
    engine: LlmEngine,
}

impl VllmBackend {
    /// Allocates the underlying vLLM engine from the given arguments.
    ///
    /// # Errors
    /// Returns [`VllmBackendError`] when the Python-side engine construction
    /// (module import or `LLMEngine.from_engine_args`) fails.
    pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
        Ok(Self {
            engine: LlmEngine::from_engine_args(args)?,
        })
    }
}
#[async_trait]
impl Backend for VllmBackend {
    /// Enqueues a validated generate request and returns the response stream.
    ///
    /// NOTE(review): not implemented yet — calling this panics via `todo!()`.
    fn schedule(
        &self,
        _request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        todo!()
    }

    /// Health probe; unconditionally reports healthy for now (the `_current_health`
    /// hint is ignored until real engine liveness checks are wired in).
    async fn health(&self, _current_health: bool) -> bool {
        true
    }
}

View File

@ -0,0 +1,90 @@
use crate::errors::VllmBackendError;
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyList};
/// Subset of vLLM's `EngineArgs` exposed to Rust callers.
///
/// Converted into a Python keyword-argument dict (via `IntoPyDict`) and
/// splatted into `vllm.engine.arg_utils.EngineArgs(**kwargs)`.
pub struct EngineArgs {
    // Model identifier forwarded verbatim to vLLM (e.g. a HF Hub repo id).
    pub model: String,
    // Number of pipeline-parallel stages forwarded to vLLM.
    pub pipeline_parallel_size: u32,
    // Number of tensor-parallel ranks forwarded to vLLM.
    pub tensor_parallel_size: u32,
}
impl IntoPyDict for EngineArgs {
    /// Materializes the arguments as a Python `dict` of keyword arguments,
    /// suitable for splatting into `EngineArgs(**kwargs)` on the Python side.
    ///
    /// Built by converting each field to a `(name, value)` pair, collecting
    /// them into a Python list, and turning that sequence into a dict.
    fn into_py_dict_bound(self, py: Python<'_>) -> Bound<'_, PyDict> {
        PyDict::from_sequence_bound(
            PyList::new_bound(
                py,
                [
                    ("model", self.model.into_py(py)),
                    (
                        "pipeline_parallel_size",
                        self.pipeline_parallel_size.into_py(py),
                    ),
                    (
                        "tensor_parallel_size",
                        self.tensor_parallel_size.into_py(py),
                    ),
                ],
            )
            .as_any(),
        )
        // Infallible in practice: the sequence is a well-formed list of
        // 2-tuples; a failure here indicates a pyo3/interpreter-level bug.
        .expect("Failed to create Python Dict from EngineArgs")
    }
}
/// Thin Rust wrapper owning a Python `vllm.engine.llm_engine.LLMEngine`.
pub struct LlmEngine {
    // GIL-independent handle to the Python LLMEngine instance.
    engine: PyObject,
}
impl LlmEngine {
    /// Instantiates the Python-side `LLMEngine` through pyo3, holding the GIL
    /// for the whole construction. Equivalent Python:
    ///
    /// ```python
    /// from vllm.engine.arg_utils import EngineArgs
    /// from vllm.engine.llm_engine import LLMEngine
    /// engine = LLMEngine.from_engine_args(EngineArgs(**args))
    /// ```
    fn py_from_engine_args(args: EngineArgs) -> PyResult<PyObject> {
        Python::with_gil(|py| {
            // Create the EngineArgs from Rust
            // from vllm.engine.arg_utils import EngineArgs
            // engine_args = EngineArgs(**args)
            let py_engine_args_mod = PyModule::import_bound(py, "vllm.engine.arg_utils")?;
            let py_engine_args_class = py_engine_args_mod.getattr("EngineArgs")?;
            // No positional args; our fields travel as keyword arguments.
            let py_engine_args =
                py_engine_args_class.call((), Some(&args.into_py_dict_bound(py)))?;

            // Next create the LLMEngine from the EngineArgs
            // from vllm.engine.llm_engine import LLMEngine
            // engine = LLMEngine.from_engine_args(engine_args)
            let py_engine_llm_mod = PyModule::import_bound(py, "vllm.engine.llm_engine")?;
            let py_engine_llm_class = py_engine_llm_mod.getattr("LLMEngine")?;
            py_engine_llm_class
                .call_method("from_engine_args", (py_engine_args,), None)?
                .extract()
        })
    }

    /// Builds an [`LlmEngine`] by allocating the underlying Python `LLMEngine`.
    ///
    /// # Errors
    /// Any Python-side import or call failure surfaces as
    /// [`VllmBackendError::Python`].
    pub fn from_engine_args(args: EngineArgs) -> Result<LlmEngine, VllmBackendError> {
        let engine = Self::py_from_engine_args(args)?;
        Ok(Self { engine })
    }

    /// Drives one scheduling/execution step of the engine.
    /// NOTE(review): currently a no-op stub — implementation pending.
    pub fn step(&mut self) {}
}

View File

@ -0,0 +1,14 @@
use pyo3::PyErr;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum VllmBackendError {
#[error("{0}")]
Python(PyErr),
}
impl From<PyErr> for VllmBackendError {
fn from(value: PyErr) -> Self {
Self::Python(value)
}
}

6
backends/vllm/src/lib.rs Normal file
View File

@ -0,0 +1,6 @@
// Crate layout: router backend glue, Python engine FFI wrapper, error types.
mod backend;
mod engine;
mod errors;

// Public surface re-exported for consumers (router integration, binary main).
pub use backend::VllmBackend;
pub use engine::{EngineArgs, LlmEngine};

View File

@ -1,6 +1,17 @@
use pyo3::prelude::*;
use text_generation_backends_vllm::{EngineArgs, LlmEngine};
#[pyo3_asyncio::tokio::main(flavor = "multi_thread")]
async fn main() {
println!("Hello, world!");
#[tokio::main]
async fn main() -> Result<(), ()> {
    // Smoke test: statically allocate a vLLM LLMEngine from Rust.
    let args = EngineArgs {
        model: String::from("meta-llama/Llama-3.2-1B-Instruct"),
        pipeline_parallel_size: 1,
        tensor_parallel_size: 1,
    };

    // Report the outcome; the failure path goes to stderr (was stdout) so
    // diagnostics are separable from normal output.
    match LlmEngine::from_engine_args(args) {
        Ok(_) => println!("Engine successfully allocated"),
        Err(err) => eprintln!("Got an error: {}", err),
    }

    Ok(())
}

View File

@ -31,11 +31,11 @@ serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
"opentelemetry-otlp",
] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
@ -57,9 +57,9 @@ image = "0.25.1"
base64 = { workspace = true }
sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = [
"v4",
"fast-rng",
"macro-diagnostics",
"v4",
"fast-rng",
"macro-diagnostics",
] }
csv = "1.3.0"
ureq = "=2.9"