From a7c2a470d67694acfc86c40fd913f5a4df966740 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 27 Jan 2025 22:39:35 +0100
Subject: [PATCH] backend(vllm): submit new request to vLLM engine

---
 Cargo.lock                   |   4 ++
 backends/vllm/Cargo.toml     |   4 ++
 backends/vllm/src/backend.rs |  80 ++++++++++++++++++++--
 backends/vllm/src/engine.rs  | 125 ++++++++++++++++++++++++++++-------
 backends/vllm/src/errors.rs  |   7 ++
 backends/vllm/src/lib.rs     |  33 +++++++++
 backends/vllm/src/main.rs    |   2 +
 7 files changed, 227 insertions(+), 28 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 80ea70bd..eb922162 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4449,11 +4449,15 @@ version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.21",
+ "log",
  "pyo3",
  "text-generation-router",
  "thiserror 2.0.11",
  "tokio",
  "tokio-stream",
+ "tracing",
+ "tracing-subscriber",
+ "uuid",
 ]

 [[package]]
diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml
index 0ab22b47..2308a655 100644
--- a/backends/vllm/Cargo.toml
+++ b/backends/vllm/Cargo.toml
@@ -13,3 +13,7 @@ text-generation-router = { path = "../../router" }
 thiserror = "2.0"
 tokio = { version = "1.43", features = ["full"] }
 tokio-stream = "0.1"
+uuid = { version = "1.11.0", features = ["v4"] }
+log = "0.4.22"
+tracing = "0.1.40"
+tracing-subscriber = "0.3.18"
diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs
index 6d49c268..0ccf8063 100644
--- a/backends/vllm/src/backend.rs
+++ b/backends/vllm/src/backend.rs
@@ -1,18 +1,39 @@
 use crate::errors::VllmBackendError;
 use crate::{EngineArgs, LlmEngine};
 use async_trait::async_trait;
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::thread::{spawn, JoinHandle};
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{
+    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
+};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{debug, info, warn};
+
+type InferResult = Result<InferStreamResponse, InferError>;
+
+struct Request {
+    tokens: Arc<Vec<u32>>,
+    params: ValidParameters,
+    stopping_params: ValidStoppingParameters,
+    streamer: UnboundedSender<InferResult>,
+}

 pub struct VllmBackend {
-    engine: LlmEngine,
+    looper: JoinHandle<()>,
+    waiting_requests: UnboundedSender<Request>,
 }

 impl VllmBackend {
     pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
+        let engine = LlmEngine::from_engine_args(args)?;
+        let (sender, receiver) = unbounded_channel();
+        let looper = spawn(|| engine_background_loop(engine, receiver));
         Ok(Self {
-            engine: LlmEngine::from_engine_args(args)?,
+            looper,
+            waiting_requests: sender,
         })
     }
 }
@@ -21,12 +42,61 @@
 impl Backend for VllmBackend {
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        todo!()
+        let (sender, receiver) = unbounded_channel();
+
+        // Send the query to the vLLM Engine
+        if let Some(input_ids) = request.input_ids {
+            debug!("Attempt to queue new request");
+            if let Err(err) = self.waiting_requests.send(Request {
+                tokens: Arc::clone(&input_ids),
+                params: request.parameters,
+                stopping_params: request.stopping_parameters,
+                streamer: sender,
+            }) {
+                warn!("Waiting Requests queue has been closed: {err}")
+            }
+        };
+
+        Ok(UnboundedReceiverStream::new(receiver))
     }

     async fn health(&self, _current_health: bool) -> bool {
         true
     }
 }
+
+fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver<Request>) {
+    info!("Starting vLLM engine background loop");
+
+    let mut in_flight_requests = HashMap::with_capacity(256);
+    loop {
+        if !waiting_requests.is_empty() {
+            let num_waiting_requests = waiting_requests.len();
+            debug!(
+                "Adding {} requests to the vLLM engine",
+                num_waiting_requests
+            );
+
+            let mut requests = Vec::with_capacity(num_waiting_requests);
+            waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests);
+
+            for request in requests {
+                match engine.add_request(&request.tokens, &request.params, &request.stopping_params)
+                {
+                    Ok(request_id) => {
+                        debug!("Successfully scheduled request {request_id}");
+                        in_flight_requests.insert(request_id.to_string(), request);
+                    }
+                    Err(err) => {
+                        warn!("Failed to schedule new request: {err}");
+                    }
+                }
+            }
+        }
+        engine.step();
+    }
+
+    info!("Shutting down vLLM engine background loop");
+}
diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs
index 1debe4c5..d4f4f5dc 100644
--- a/backends/vllm/src/engine.rs
+++ b/backends/vllm/src/engine.rs
@@ -1,6 +1,10 @@
 use crate::errors::VllmBackendError;
+use crate::{sampling_params, tokens_prompt, TryToPyObject};
 use pyo3::prelude::*;
-use pyo3::types::{IntoPyDict, PyDict, PyList};
+use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
+use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
+use tracing::info;
+use uuid::Uuid;

 pub struct EngineArgs {
     pub model: String,
@@ -31,28 +35,51 @@ impl IntoPyDict for EngineArgs {
     }
 }

-// impl IntoPy<PyObject> for EngineArgs {
-//     fn into_py(self, py: Python<'_>) -> PyObject {
-//         PyDict::from_sequence_bound(
-//             PyList::new_bound(
-//                 py,
-//                 [
-//                     ("model", self.model.into_py(py)),
-//                     (
-//                         "pipeline_parallel_size",
-//                         self.pipeline_parallel_size.into_py(py),
-//                     ),
-//                     (
-//                         "tensor_parallel_size",
-//                         self.tensor_parallel_size.into_py(py),
-//                     ),
-//                 ],
-//             )
-//             .as_any(),
-//         )
-//         .expect("Failed to create Python Dict from EngineArgs")
-//     }
-// }
+pub struct SamplingParams<'a> {
+    sampling_params: &'a ValidParameters,
+    stopping_params: &'a ValidStoppingParameters,
+}
+
+impl TryToPyObject for SamplingParams<'_> {
+    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr> {
+        let py_sampling_params_class = sampling_params(py);
+
+        let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
+            py,
+            [
+                ("seed", self.sampling_params.seed.into_py(py)),
+                ("n", 1.into_py(py)),
+                ("top_k", self.sampling_params.top_k.into_py(py)),
+                ("top_p", self.sampling_params.top_p.into_py(py)),
+                ("temperature", self.sampling_params.temperature.into_py(py)),
+                (
+                    "frequency_penalty",
+                    self.sampling_params.frequency_penalty.into_py(py),
+                ),
+                (
+                    "repetition_penalty",
+                    self.sampling_params.repetition_penalty.into_py(py),
+                ),
+                (
+                    "ignore_eos",
+                    self.stopping_params.ignore_eos_token.into_py(py),
+                ),
+                (
+                    "max_tokens",
+                    self.stopping_params.max_new_tokens.into_py(py),
+                ),
+                (
+                    "stop",
+                    PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(),
+                ),
+            ],
+        ));
+
+        Ok(py_sampling_params_class
+            .call_method_bound(py, "from_optional", (), Some(&kwargs?))?
+            .to_object(py))
+    }
+}

 pub struct LlmEngine {
     engine: PyObject,
@@ -80,11 +107,63 @@ impl LlmEngine {
         })
     }

+    fn py_add_request(
+        &self,
+        request_id: &str,
+        prompt: &[u32],
+        sampling_params: SamplingParams,
+    ) -> Result<(), VllmBackendError> {
+        Python::with_gil(|py| {
+            // Create vllm.Tokens
+            let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py);
+            let py_tokens_prompt_class = tokens_prompt(py);
+            let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;
+            let py_sampling_params = sampling_params.try_to_object(py)?;
+
+            let _ = py.eval_bound(
+                "print(type(params), params)",
+                Some(&[("params", &py_sampling_params)].into_py_dict_bound(py)),
+                None,
+            );
+
+            self.engine.call_method1(
+                py,
+                "add_request",
+                (
+                    PyString::new_bound(py, request_id),
+                    py_tokens_prompt,
+                    py_sampling_params,
+                ),
+            )?;
+
+            self.engine.call_method0(py, "step")
+        })?;
+
+        Ok(())
+    }
+
     pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
         let engine = Self::py_from_engine_args(args)?;
         Ok(Self { engine })
     }

+    pub fn add_request(
+        &self,
+        prompt: &[u32],
+        sampling_params: &ValidParameters,
+        stopping_params: &ValidStoppingParameters,
+    ) -> Result<Uuid, VllmBackendError> {
+        let request_id = Uuid::new_v4();
+        let sampling_params = SamplingParams {
+            sampling_params,
+            stopping_params,
+        };
+        self.py_add_request(&request_id.to_string(), prompt, sampling_params)?;
+
+        info!("Submitted new request: {request_id}");
+        Ok(request_id)
+    }
+
     pub fn step(&mut self) {}
 }
diff --git a/backends/vllm/src/errors.rs b/backends/vllm/src/errors.rs
index aa008190..1b03f5a4 100644
--- a/backends/vllm/src/errors.rs
+++ b/backends/vllm/src/errors.rs
@@ -1,4 +1,5 @@
 use pyo3::PyErr;
+use text_generation_router::infer::InferError;
 use text_generation_router::server::WebServerError;
 use thiserror::Error;

@@ -22,3 +23,9 @@ impl From<WebServerError> for VllmBackendError {
         Self::WebServer(value)
     }
 }
+
+impl From<VllmBackendError> for InferError {
+    fn from(value: VllmBackendError) -> Self {
+        InferError::GenerationError(value.to_string())
+    }
+}
diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs
index 37d5eb25..12c910df 100644
--- a/backends/vllm/src/lib.rs
+++ b/backends/vllm/src/lib.rs
@@ -5,3 +5,36 @@ mod errors;
 pub use backend::VllmBackend;
 pub use engine::{EngineArgs, LlmEngine};
 pub use errors::VllmBackendError;
+use pyo3::prelude::PyAnyMethods;
+use pyo3::sync::GILOnceCell;
+use pyo3::types::PyModule;
+use pyo3::{Py, PyAny, PyErr, PyObject, Python};
+
+static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+
+#[inline]
+pub(crate) fn tokens_prompt(py: Python) -> &Py<PyAny> {
+    PY_TOKENS_PROMPT_CLASS.get_or_init(py, || {
+        PyModule::import_bound(py, "vllm.inputs")
+            .expect("Failed to import vllm.inputs")
+            .getattr("TokensPrompt")
+            .expect("Failed to import vllm.inputs.TokensPrompt")
+            .unbind()
+    })
+}
+
+#[inline]
+pub(crate) fn sampling_params(py: Python) -> &Py<PyAny> {
+    PY_SAMPLING_PARAMS_CLASS.get_or_init(py, || {
+        PyModule::import_bound(py, "vllm")
+            .expect("Failed to import vllm")
+            .getattr("SamplingParams")
+            .expect("Failed to import vllm.SamplingParams")
+            .unbind()
+    })
+}
+
+pub(crate) trait TryToPyObject {
+    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr>;
+}
diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs
index 20b7efc9..55f47871 100644
--- a/backends/vllm/src/main.rs
+++ b/backends/vllm/src/main.rs
@@ -73,6 +73,8 @@ impl Into<EngineArgs> for &Args {

 #[tokio::main]
 async fn main() -> Result<(), VllmBackendError> {
+    tracing_subscriber::fmt::init();
+
     let args = Args::parse();
     let backend = VllmBackend::from_engine_args((&args).into())?;