mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 13:02:12 +00:00
backend(vllm): statically allocate LLMEngine
This commit is contained in:
parent
cfd22726c9
commit
bd2ec03d53
84
Cargo.lock
generated
84
Cargo.lock
generated
@ -579,7 +579,7 @@ dependencies = [
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1601,7 +1601,7 @@ dependencies = [
|
||||
"reqwest 0.11.27",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"ureq",
|
||||
]
|
||||
@ -2034,7 +2034,7 @@ checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17"
|
||||
dependencies = [
|
||||
"opentelemetry 0.20.0",
|
||||
"opentelemetry-otlp",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
"tracing-opentelemetry 0.21.0",
|
||||
]
|
||||
@ -2187,9 +2187,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.164"
|
||||
version = "0.2.169"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
|
||||
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
|
||||
|
||||
[[package]]
|
||||
name = "libfuzzer-sys"
|
||||
@ -2370,7 +2370,7 @@ dependencies = [
|
||||
"metrics",
|
||||
"metrics-util",
|
||||
"quanta",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
@ -2501,7 +2501,7 @@ dependencies = [
|
||||
"futures",
|
||||
"pin-project",
|
||||
"rand",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
@ -2553,7 +2553,7 @@ dependencies = [
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-retry",
|
||||
"tokio-util",
|
||||
@ -2857,7 +2857,7 @@ dependencies = [
|
||||
"js-sys",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
@ -2875,7 +2875,7 @@ dependencies = [
|
||||
"opentelemetry_api",
|
||||
"opentelemetry_sdk 0.20.0",
|
||||
"prost 0.11.9",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tonic 0.9.2",
|
||||
]
|
||||
@ -2913,7 +2913,7 @@ dependencies = [
|
||||
"js-sys",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
@ -2935,7 +2935,7 @@ dependencies = [
|
||||
"rand",
|
||||
"regex",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
]
|
||||
@ -2957,7 +2957,7 @@ dependencies = [
|
||||
"ordered-float 4.5.0",
|
||||
"percent-encoding",
|
||||
"rand",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3494,7 +3494,7 @@ dependencies = [
|
||||
"rand_chacha",
|
||||
"simd_helpers",
|
||||
"system-deps",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"v_frame",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
@ -3571,7 +3571,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"libredox",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4436,7 +4436,7 @@ dependencies = [
|
||||
"pkg-config",
|
||||
"pyo3",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@ -4446,6 +4446,14 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "text-generation-backends-vllm"
|
||||
version = "3.0.2-dev0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"pyo3",
|
||||
"text-generation-router",
|
||||
"thiserror 2.0.11",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
@ -4460,7 +4468,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tabled",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@ -4477,7 +4485,7 @@ dependencies = [
|
||||
"grpc-metadata",
|
||||
"prost 0.12.6",
|
||||
"prost-build",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tonic 0.10.2",
|
||||
"tonic-build",
|
||||
@ -4500,7 +4508,7 @@ dependencies = [
|
||||
"reqwest 0.11.27",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"vergen",
|
||||
@ -4542,7 +4550,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sysinfo",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@ -4591,7 +4599,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"slotmap",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@ -4642,7 +4650,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"slotmap",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@ -4672,7 +4680,16 @@ version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
"thiserror-impl 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
|
||||
dependencies = [
|
||||
"thiserror-impl 2.0.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4686,6 +4703,17 @@ dependencies = [
|
||||
"syn 2.0.89",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.89",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thread_local"
|
||||
version = "1.1.8"
|
||||
@ -4787,7 +4815,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spm_precompiled",
|
||||
"thiserror",
|
||||
"thiserror 1.0.69",
|
||||
"unicode-normalization-alignments",
|
||||
"unicode-segmentation",
|
||||
"unicode_categories",
|
||||
@ -4795,9 +4823,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.41.1"
|
||||
version = "1.43.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
|
||||
checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"bytes",
|
||||
@ -4823,9 +4851,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-macros"
|
||||
version = "2.4.0"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
|
||||
checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -36,7 +36,7 @@ metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
minijinja = { version = "2.2.0", features = ["json"] }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
pyo3 = { version = "0.23", features = ["auto-initialize"] }
|
||||
pyo3 = { version = "0.22", features = ["auto-initialize"] }
|
||||
|
||||
[profile.release]
|
||||
incremental = true
|
||||
|
@ -6,5 +6,9 @@ authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
pyo3 = "0.23"
|
||||
pyo3-asyncio = "0.20"
|
||||
pyo3 = { workspace = true }
|
||||
text-generation-router = { path = "../../router" }
|
||||
thiserror = "2.0"
|
||||
tokio = { version = "1.43", features = ["full"] }
|
||||
tokio-stream = "0.1"
|
||||
async-trait = "0.1.83"
|
||||
|
32
backends/vllm/src/backend.rs
Normal file
32
backends/vllm/src/backend.rs
Normal file
@ -0,0 +1,32 @@
|
||||
use crate::errors::VllmBackendError;
|
||||
use crate::{EngineArgs, LlmEngine};
|
||||
use async_trait::async_trait;
|
||||
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
|
||||
pub struct VllmBackend {
|
||||
engine: LlmEngine,
|
||||
}
|
||||
|
||||
impl VllmBackend {
|
||||
pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
|
||||
Ok(Self {
|
||||
engine: LlmEngine::from_engine_args(args)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for VllmBackend {
|
||||
fn schedule(
|
||||
&self,
|
||||
_request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn health(&self, _current_health: bool) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
90
backends/vllm/src/engine.rs
Normal file
90
backends/vllm/src/engine.rs
Normal file
@ -0,0 +1,90 @@
|
||||
use crate::errors::VllmBackendError;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{IntoPyDict, PyDict, PyList};
|
||||
|
||||
pub struct EngineArgs {
|
||||
pub model: String,
|
||||
pub pipeline_parallel_size: u32,
|
||||
pub tensor_parallel_size: u32,
|
||||
}
|
||||
|
||||
impl IntoPyDict for EngineArgs {
|
||||
fn into_py_dict_bound(self, py: Python<'_>) -> Bound<'_, PyDict> {
|
||||
PyDict::from_sequence_bound(
|
||||
PyList::new_bound(
|
||||
py,
|
||||
[
|
||||
("model", self.model.into_py(py)),
|
||||
(
|
||||
"pipeline_parallel_size",
|
||||
self.pipeline_parallel_size.into_py(py),
|
||||
),
|
||||
(
|
||||
"tensor_parallel_size",
|
||||
self.tensor_parallel_size.into_py(py),
|
||||
),
|
||||
],
|
||||
)
|
||||
.as_any(),
|
||||
)
|
||||
.expect("Failed to create Python Dict from EngineArgs")
|
||||
}
|
||||
}
|
||||
|
||||
// impl IntoPy<PyObject> for EngineArgs {
|
||||
// fn into_py(self, py: Python<'_>) -> PyObject {
|
||||
// PyDict::from_sequence_bound(
|
||||
// PyList::new_bound(
|
||||
// py,
|
||||
// [
|
||||
// ("model", self.model.into_py(py)),
|
||||
// (
|
||||
// "pipeline_parallel_size",
|
||||
// self.pipeline_parallel_size.into_py(py),
|
||||
// ),
|
||||
// (
|
||||
// "tensor_parallel_size",
|
||||
// self.tensor_parallel_size.into_py(py),
|
||||
// ),
|
||||
// ],
|
||||
// )
|
||||
// .as_any(),
|
||||
// )
|
||||
// .expect("Failed to create Python Dict from EngineArgs")
|
||||
// }
|
||||
// }
|
||||
|
||||
pub struct LlmEngine {
|
||||
engine: PyObject,
|
||||
}
|
||||
|
||||
impl LlmEngine {
|
||||
fn py_from_engine_args(args: EngineArgs) -> PyResult<PyObject> {
|
||||
Python::with_gil(|py| {
|
||||
// Create the EngineArgs from Rust
|
||||
// from vllm.engine.arg_util import EngineArgs
|
||||
// engine_args = EngineArgs(**args)
|
||||
let py_engine_args_mod = PyModule::import_bound(py, "vllm.engine.arg_utils")?;
|
||||
let py_engine_args_class = py_engine_args_mod.getattr("EngineArgs")?;
|
||||
let py_engine_args =
|
||||
py_engine_args_class.call((), Some(&args.into_py_dict_bound(py)))?;
|
||||
|
||||
// Next create the LLMEngine from the EngineArgs
|
||||
// from vllm.engine.llm_engine import LLMEngine
|
||||
// engine = LLMEngine.from_engine_args(engine_args)
|
||||
let py_engine_llm_mod = PyModule::import_bound(py, "vllm.engine.llm_engine")?;
|
||||
let py_engine_llm_class = py_engine_llm_mod.getattr("LLMEngine")?;
|
||||
py_engine_llm_class
|
||||
.call_method("from_engine_args", (py_engine_args,), None)?
|
||||
.extract()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_engine_args(args: EngineArgs) -> Result<LlmEngine, VllmBackendError> {
|
||||
let engine = Self::py_from_engine_args(args)?;
|
||||
|
||||
Ok(Self { engine })
|
||||
}
|
||||
|
||||
pub fn step(&mut self) {}
|
||||
}
|
14
backends/vllm/src/errors.rs
Normal file
14
backends/vllm/src/errors.rs
Normal file
@ -0,0 +1,14 @@
|
||||
use pyo3::PyErr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum VllmBackendError {
|
||||
#[error("{0}")]
|
||||
Python(PyErr),
|
||||
}
|
||||
|
||||
impl From<PyErr> for VllmBackendError {
|
||||
fn from(value: PyErr) -> Self {
|
||||
Self::Python(value)
|
||||
}
|
||||
}
|
6
backends/vllm/src/lib.rs
Normal file
6
backends/vllm/src/lib.rs
Normal file
@ -0,0 +1,6 @@
|
||||
mod backend;
|
||||
mod engine;
|
||||
mod errors;
|
||||
|
||||
pub use backend::VllmBackend;
|
||||
pub use engine::{EngineArgs, LlmEngine};
|
@ -1,6 +1,17 @@
|
||||
use pyo3::prelude::*;
|
||||
use text_generation_backends_vllm::{EngineArgs, LlmEngine};
|
||||
|
||||
#[pyo3_asyncio::tokio::main(flavor = "multi_thread")]
|
||||
async fn main() {
|
||||
println!("Hello, world!");
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), ()> {
|
||||
let args = EngineArgs {
|
||||
model: String::from("meta-llama/Llama-3.2-1B-Instruct"),
|
||||
pipeline_parallel_size: 1,
|
||||
tensor_parallel_size: 1,
|
||||
};
|
||||
|
||||
match LlmEngine::from_engine_args(args) {
|
||||
Ok(_) => println!("Engine successfully allocated"),
|
||||
Err(err) => println!("Got an error: {}", err),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user