From 32dffcff60ad4a5a75cb5ed356a077d0d7a7e2c3 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 30 Jan 2025 13:35:21 +0100
Subject: [PATCH] backend(vllm): expose FFI for CompletionOutput and
 RequestOutput on Rust side

---
 Cargo.lock                   |  1 +
 backends/vllm/Cargo.toml     |  1 +
 backends/vllm/src/backend.rs | 80 +++++++++++++++++++-----------
 backends/vllm/src/engine.rs  | 94 +++++++++++++++++++++++++++++-------
 4 files changed, 131 insertions(+), 45 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index eb922162..70ecf1f9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4449,6 +4449,7 @@ version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.21",
+ "crossbeam-channel",
  "log",
  "pyo3",
  "text-generation-router",
diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml
index 2308a655..c77f4562 100644
--- a/backends/vllm/Cargo.toml
+++ b/backends/vllm/Cargo.toml
@@ -8,6 +8,7 @@ homepage.workspace = true
 [dependencies]
 async-trait = "0.1.83"
 clap = { version = "4.5.21", features = ["derive"] }
+crossbeam-channel = "0.5"
 pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 thiserror = "2.0"
diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs
index 0ccf8063..46419279 100644
--- a/backends/vllm/src/backend.rs
+++ b/backends/vllm/src/backend.rs
@@ -1,38 +1,42 @@
 use crate::errors::VllmBackendError;
 use crate::{EngineArgs, LlmEngine};
 use async_trait::async_trait;
-use std::collections::HashMap;
+use crossbeam_channel::internal::SelectHandle;
+use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
+use std::collections::{HashMap, HashSet};
+use std::hint::spin_loop;
+use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use std::thread::{spawn, JoinHandle};
+use std::time::Duration;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
+use text_generation_router::Token;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, info, warn};
+use tracing::{debug, error, info, warn};
 
 type InferResult = Result<InferStreamResponse, InferError>;
 
-struct Request {
+struct VllmRequestContext {
     tokens: Arc<Vec<u32>>,
     params: ValidParameters,
     stopping_params: ValidStoppingParameters,
-    streamer: UnboundedSender<InferResult>,
+    stream: UnboundedSender<InferResult>,
 }
 
 pub struct VllmBackend {
-    looper: JoinHandle<()>,
-    waiting_requests: UnboundedSender<Request>,
+    waiting_requests: Sender<VllmRequestContext>,
 }
 
 impl VllmBackend {
     pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
         let engine = LlmEngine::from_engine_args(args)?;
-        let (sender, receiver) = unbounded_channel();
+        let (sender, receiver) = unbounded();
         let looper = spawn(|| engine_background_loop(engine, receiver));
         Ok(Self {
-            looper,
             waiting_requests: sender,
         })
     }
@@ -48,12 +52,12 @@ impl Backend for VllmBackend {
 
         // Send the query to the vLLM Engine
         if let Some(input_ids) = request.input_ids {
-            debug!("Attempt to queue new request");
-            if let Err(err) = self.waiting_requests.send(Request {
+            debug!("Queuing new request");
+            if let Err(err) = self.waiting_requests.send(VllmRequestContext {
                 tokens: Arc::clone(&input_ids),
                 params: request.parameters,
                 stopping_params: request.stopping_parameters,
-                streamer: sender,
+                stream: sender,
             }) {
                 warn!("Waiting Requests queue has been closed: {err}")
             }
@@ -67,35 +71,55 @@ impl Backend for VllmBackend {
     }
 }
 
-fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver<Request>) {
+fn engine_background_loop(
+    mut engine: LlmEngine,
+    mut waiting_requests: Receiver<VllmRequestContext>,
+) {
     info!("Starting vLLM engine background loop");
-
+    static DURATION_100_MS: Duration = Duration::from_millis(100);
     let mut in_flight_requests = HashMap::with_capacity(256);
-    loop {
+    'outer: loop {
        if !waiting_requests.is_empty() {
-            let num_waiting_requests = waiting_requests.len();
-            debug!(
-                "Adding {} requests to the vLLM engine",
-                num_waiting_requests
-            );
-
-            let mut requests = Vec::with_capacity(num_waiting_requests);
-            waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests);
-
-            for request in requests {
-                match engine.add_request(&request.tokens, &request.params, &request.stopping_params)
-                {
+            match waiting_requests.recv_timeout(DURATION_100_MS) {
+                Ok(context) => match engine.add_request(
+                    &context.tokens,
+                    &context.params,
+                    &context.stopping_params,
+                ) {
                     Ok(request_id) => {
                         debug!("Successfully scheduled request {request_id}");
-                        in_flight_requests.insert(request_id.to_string(), request);
+                        in_flight_requests.insert(request_id.to_string(), context);
                     }
                     Err(err) => {
                         warn!("Failed to schedule new request: {err}");
                     }
+                },
+                Err(err) => match err {
+                    RecvTimeoutError::Disconnected => break 'outer,
+                    _ => {} // timeout all fine
+                },
+            }
+        }
+
+        if !in_flight_requests.is_empty() {
+            match engine.step() {
+                Ok(outputs) => outputs.iter().for_each(|output| {
+                    let ctx = &in_flight_requests[&output.request_id];
+
+                    // We only need to check on Err meaning the channel is not open anymore, so abort the request
+                    if let Err(_) = ctx.stream.send(InferResult {}) {
+                        debug!("Request {}'s channel dropped, aborting", &output.request_id);
+                        in_flight_requests.remove(&output.request_id);
+                        engine.abort_request(&output.request_id);
+                    }
+                }),
+                Err(err) => {
+                    error!("LLMEngine::step got an error: {err}");
                 }
             }
         }
-        engine.step();
+
+        spin_loop();
     }
 
     info!("Shutting down vLLM engine background loop");
diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs
index 53e36e14..f3b6a761 100644
--- a/backends/vllm/src/engine.rs
+++ b/backends/vllm/src/engine.rs
@@ -1,9 +1,10 @@
 use crate::errors::VllmBackendError;
 use crate::{sampling_params, tokens_prompt, TryToPyObject};
+use pyo3::intern;
 use pyo3::prelude::*;
 use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
 use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
-use tracing::info;
+use tracing::{info, instrument};
 use uuid::Uuid;
 
 pub struct EngineArgs {
@@ -29,9 +30,9 @@ impl IntoPyDict for EngineArgs {
                     ),
                 ],
             )
-                .as_any(),
+            .as_any(),
         )
-            .expect("Failed to create Python Dict from EngineArgs")
+        .expect("Failed to create Python Dict from EngineArgs")
     }
 }
 
@@ -47,29 +48,32 @@ impl TryToPyObject for SamplingParams<'_> {
         let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
             py,
             [
-                ("seed", self.sampling_params.seed.into_py(py)),
-                ("n", 1.into_py(py)),
-                ("top_k", self.sampling_params.top_k.into_py(py)),
-                ("top_p", self.sampling_params.top_p.into_py(py)),
-                ("temperature", self.sampling_params.temperature.into_py(py)),
+                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
+                (intern!(py, "n"), 1.into_py(py)),
+                (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
+                (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
                 (
-                    "frequency_penalty",
+                    intern!(py, "temperature"),
+                    self.sampling_params.temperature.into_py(py),
+                ),
+                (
+                    intern!(py, "frequency_penalty"),
                     self.sampling_params.frequency_penalty.into_py(py),
                 ),
                 (
-                    "repetition_penalty",
+                    intern!(py, "repetition_penalty"),
                     self.sampling_params.repetition_penalty.into_py(py),
                 ),
                 (
-                    "ignore_eos",
+                    intern!(py, "ignore_eos"),
                     self.stopping_params.ignore_eos_token.into_py(py),
                 ),
                 (
-                    "max_tokens",
+                    intern!(py, "max_tokens"),
                     self.stopping_params.max_new_tokens.into_py(py),
                 ),
                 (
-                    "stop",
+                    intern!(py, "stop"),
                     PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(),
                 ),
             ],
@@ -81,6 +85,47 @@ impl TryToPyObject for SamplingParams<'_> {
     }
 }
 
+#[derive(Debug)]
+pub struct CompletionOutput {
+    pub index: usize,
+    pub text: String,                  // TODO: SmallString?
+    pub token_ids: Vec<u32>,           // TODO: TinyVec?
+    pub logprobs: Option<Vec<f32>>,    // TODO: TinyVec?
+    pub finish_reason: Option<String>, // lora_request: LATER
+}
+
+#[derive(Debug)]
+pub struct RequestOutput {
+    pub request_id: String,
+    pub outputs: Vec<CompletionOutput>,
+    pub finished: bool,
+    // metrics: Vec // TODO
+}
+
+impl<'py> FromPyObject<'py> for CompletionOutput {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            index: ob.getattr(intern!(py, "index"))?.extract()?,
+            text: ob.getattr(intern!(py, "text"))?.extract()?,
+            token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?,
+            logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
+            finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?,
+        })
+    }
+}
+
+impl<'py> FromPyObject<'py> for RequestOutput {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            request_id: ob.getattr(intern!(py, "request_id"))?.extract()?,
+            outputs: ob.getattr(intern!(py, "outputs"))?.extract()?,
+            finished: ob.getattr(intern!(py, "finished"))?.extract()?,
+        })
+    }
+}
+
 pub struct LlmEngine {
     engine: PyObject,
 }
@@ -115,14 +160,14 @@ impl LlmEngine {
     ) -> Result<(), VllmBackendError> {
         Python::with_gil(|py| {
             // Create vllm.Tokens
-            let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py);
+            let kwargs = [(intern!(py, "prompt_token_ids"), prompt)].into_py_dict_bound(py);
             let py_tokens_prompt_class = tokens_prompt(py);
             let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;
             let py_sampling_params = sampling_params.try_to_object(py)?;
             self.engine.call_method1(
                 py,
-                "add_request",
+                intern!(py, "add_request"),
                 (
                     PyString::new_bound(py, request_id),
                     py_tokens_prompt,
                 ),
             )?;
 
-            self.engine.call_method0(py, "step")
+            self.engine.call_method0(py, intern!(py, "step"))
         })?;
 
         Ok(())
     }
 
+    fn py_step(&self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        Ok(Python::with_gil(|py| {
+            self.engine
+                .call_method0(py, intern!(py, "step"))?
+                .extract::<Vec<RequestOutput>>(py)
+        })?)
+    }
+
     pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
         let engine = Self::py_from_engine_args(args)?;
 
         Ok(Self { engine })
     }
 
+    #[instrument(skip_all)]
     pub fn add_request(
         &self,
         prompt: &[u32],
@@ -159,5 +213,11 @@ impl LlmEngine {
         Ok(request_id)
     }
 
-    pub fn step(&mut self) {}
+    #[instrument(skip_all)]
+    pub fn abort_request(&self, _request_id: &str) {}
+
+    #[instrument(skip_all)]
+    pub fn step(&mut self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        self.py_step()
+    }
 }