backend(vllm): expose FFI for CompletionOutput and RequestOutput on Rust side

Morgan Funtowicz 2025-01-30 13:35:21 +01:00
parent 7028f5bce2
commit 32dffcff60
4 changed files with 131 additions and 45 deletions

Cargo.lock (generated)

@@ -4449,6 +4449,7 @@ version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.21",
+ "crossbeam-channel",
  "log",
  "pyo3",
  "text-generation-router",

Cargo.toml

@@ -8,6 +8,7 @@ homepage.workspace = true
 [dependencies]
 async-trait = "0.1.83"
 clap = { version = "4.5.21", features = ["derive"] }
+crossbeam-channel = "0.5"
 pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 thiserror = "2.0"
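The new crossbeam-channel dependency is what lets the request queue cross from the async router into the dedicated engine thread: unlike tokio's mpsc receiver, a crossbeam Receiver can block, or block with a timeout, on a plain OS thread with no runtime involved, which is exactly the shape of the reworked background loop below. A minimal, self-contained sketch of that pattern (illustrative names, not the backend's own types):

use crossbeam_channel::{unbounded, RecvTimeoutError};
use std::thread::spawn;
use std::time::Duration;

fn main() {
    // Queue between an async producer and a dedicated engine thread.
    let (tx, rx) = unbounded::<u32>();

    let looper = spawn(move || loop {
        match rx.recv_timeout(Duration::from_millis(100)) {
            Ok(request) => println!("scheduling request {request}"),
            // Timed out: nothing waiting, fall through to other work
            // (the real loop steps the engine here).
            Err(RecvTimeoutError::Timeout) => {}
            // Every sender has been dropped: shut the loop down.
            Err(RecvTimeoutError::Disconnected) => break,
        }
    });

    tx.send(1).unwrap();
    drop(tx); // disconnect, letting the background thread exit
    looper.join().unwrap();
}

The Disconnected arm doubles as the shutdown signal: once every Sender is dropped, the loop exits on its next poll. The committed loop also calls std::hint::spin_loop() each pass to hint the CPU while it busy-waits.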


@@ -1,38 +1,42 @@
 use crate::errors::VllmBackendError;
 use crate::{EngineArgs, LlmEngine};
 use async_trait::async_trait;
-use std::collections::HashMap;
+use crossbeam_channel::internal::SelectHandle;
+use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
+use std::collections::{HashMap, HashSet};
+use std::hint::spin_loop;
+use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use std::thread::{spawn, JoinHandle};
+use std::time::Duration;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
+use text_generation_router::Token;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, info, warn};
+use tracing::{debug, error, info, warn};

 type InferResult = Result<InferStreamResponse, InferError>;

-struct Request {
+struct VllmRequestContext {
     tokens: Arc<Vec<u32>>,
     params: ValidParameters,
     stopping_params: ValidStoppingParameters,
-    streamer: UnboundedSender<InferResult>,
+    stream: UnboundedSender<InferResult>,
 }

 pub struct VllmBackend {
-    looper: JoinHandle<()>,
-    waiting_requests: UnboundedSender<Request>,
+    waiting_requests: Sender<VllmRequestContext>,
 }

 impl VllmBackend {
     pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
         let engine = LlmEngine::from_engine_args(args)?;
-        let (sender, receiver) = unbounded_channel();
+        let (sender, receiver) = unbounded();
         let looper = spawn(|| engine_background_loop(engine, receiver));

         Ok(Self {
-            looper,
             waiting_requests: sender,
         })
     }

@@ -48,12 +52,12 @@ impl Backend for VllmBackend {
         // Send the query to the vLLM Engine
         if let Some(input_ids) = request.input_ids {
-            debug!("Attempt to queue new request");
-            if let Err(err) = self.waiting_requests.send(Request {
+            debug!("Queuing new request");
+            if let Err(err) = self.waiting_requests.send(VllmRequestContext {
                 tokens: Arc::clone(&input_ids),
                 params: request.parameters,
                 stopping_params: request.stopping_parameters,
-                streamer: sender,
+                stream: sender,
             }) {
                 warn!("Waiting Requests queue has been closed: {err}")
             }

@@ -67,35 +71,55 @@ impl Backend for VllmBackend {
     }
 }

-fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver<Request>) {
+fn engine_background_loop(
+    mut engine: LlmEngine,
+    mut waiting_requests: Receiver<VllmRequestContext>,
+) {
     info!("Starting vLLM engine background loop");
+    static DURATION_100_MS: Duration = Duration::from_millis(100);

     let mut in_flight_requests = HashMap::with_capacity(256);
-    loop {
+    'outer: loop {
         if !waiting_requests.is_empty() {
-            let num_waiting_requests = waiting_requests.len();
-            debug!(
-                "Adding {} requests to the vLLM engine",
-                num_waiting_requests
-            );
-
-            let mut requests = Vec::with_capacity(num_waiting_requests);
-            waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests);
-
-            for request in requests {
-                match engine.add_request(&request.tokens, &request.params, &request.stopping_params)
-                {
+            match waiting_requests.recv_timeout(DURATION_100_MS) {
+                Ok(context) => match engine.add_request(
+                    &context.tokens,
+                    &context.params,
+                    &context.stopping_params,
+                ) {
                     Ok(request_id) => {
                         debug!("Successfully scheduled request {request_id}");
-                        in_flight_requests.insert(request_id.to_string(), request);
+                        in_flight_requests.insert(request_id.to_string(), context);
                     }
                     Err(err) => {
                         warn!("Failed to schedule new request: {err}");
                     }
+                },
+                Err(err) => match err {
+                    RecvTimeoutError::Disconnected => break 'outer,
+                    _ => {} // timeout all fine
+                },
+            }
+        }
+
+        if !in_flight_requests.is_empty() {
+            match engine.step() {
+                Ok(outputs) => outputs.iter().for_each(|output| {
+                    let ctx = &in_flight_requests[&output.request_id];
+
+                    // We only need to check on Err meaning the channel is not open anymore, so abort the request
+                    if let Err(_) = ctx.stream.send(InferResult {}) {
+                        debug!("Request {}'s channel dropped, aborting", &output.request_id);
+                        in_flight_requests.remove(&output.request_id);
+                        engine.abort_request(&output.request_id);
+                    }
+                }),
+                Err(err) => {
+                    error!("LLMEngine::step got an error: {err}");
                 }
             }
         }

-        engine.step();
+        spin_loop();
     }

     info!("Shutting down vLLM engine background loop");


@@ -1,9 +1,10 @@
 use crate::errors::VllmBackendError;
 use crate::{sampling_params, tokens_prompt, TryToPyObject};
+use pyo3::intern;
 use pyo3::prelude::*;
 use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
 use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
-use tracing::info;
+use tracing::{info, instrument};
 use uuid::Uuid;

 pub struct EngineArgs {

@@ -47,29 +48,32 @@ impl TryToPyObject for SamplingParams<'_> {
         let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
             py,
             [
-                ("seed", self.sampling_params.seed.into_py(py)),
-                ("n", 1.into_py(py)),
-                ("top_k", self.sampling_params.top_k.into_py(py)),
-                ("top_p", self.sampling_params.top_p.into_py(py)),
-                ("temperature", self.sampling_params.temperature.into_py(py)),
+                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
+                (intern!(py, "n"), 1.into_py(py)),
+                (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
+                (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
                 (
-                    "frequency_penalty",
+                    intern!(py, "temperature"),
+                    self.sampling_params.temperature.into_py(py),
+                ),
+                (
+                    intern!(py, "frequency_penalty"),
                     self.sampling_params.frequency_penalty.into_py(py),
                 ),
                 (
-                    "repetition_penalty",
+                    intern!(py, "repetition_penalty"),
                     self.sampling_params.repetition_penalty.into_py(py),
                 ),
                 (
-                    "ignore_eos",
+                    intern!(py, "ignore_eos"),
                     self.stopping_params.ignore_eos_token.into_py(py),
                 ),
                 (
-                    "max_tokens",
+                    intern!(py, "max_tokens"),
                     self.stopping_params.max_new_tokens.into_py(py),
                 ),
                 (
-                    "stop",
+                    intern!(py, "stop"),
                     PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(),
                 ),
             ],

@@ -81,6 +85,47 @@ impl TryToPyObject for SamplingParams<'_> {
     }
 }

+#[derive(Debug)]
+pub struct CompletionOutput {
+    pub index: usize,
+    pub text: String,                  // TODO: SmallString?
+    pub token_ids: Vec<u32>,           // TODO: TinyVec?
+    pub logprobs: Option<Vec<f32>>,    // TODO: TinyVec?
+    pub finish_reason: Option<String>, // lora_request: LATER
+}
+
+#[derive(Debug)]
+pub struct RequestOutput {
+    pub request_id: String,
+    pub outputs: Vec<CompletionOutput>,
+    pub finished: bool,
+    // metrics: Vec<RequestMetrics> // TODO
+}
+
+impl<'py> FromPyObject<'py> for CompletionOutput {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            index: ob.getattr(intern!(py, "index"))?.extract()?,
+            text: ob.getattr(intern!(py, "text"))?.extract()?,
+            token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?,
+            logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
+            finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?,
+        })
+    }
+}
+
+impl<'py> FromPyObject<'py> for RequestOutput {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            request_id: ob.getattr(intern!(py, "request_id"))?.extract()?,
+            outputs: ob.getattr(intern!(py, "outputs"))?.extract()?,
+            finished: ob.getattr(intern!(py, "finished"))?.extract()?,
+        })
+    }
+}
+
 pub struct LlmEngine {
     engine: PyObject,
 }

@@ -115,14 +160,14 @@
     ) -> Result<(), VllmBackendError> {
         Python::with_gil(|py| {
             // Create vllm.Tokens
-            let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py);
+            let kwargs = [(intern!(py, "prompt_token_ids"), prompt)].into_py_dict_bound(py);
             let py_tokens_prompt_class = tokens_prompt(py);
             let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;

             let py_sampling_params = sampling_params.try_to_object(py)?;
             self.engine.call_method1(
                 py,
-                "add_request",
+                intern!(py, "add_request"),
                 (
                     PyString::new_bound(py, request_id),
                     py_tokens_prompt,

@@ -130,18 +175,27 @@
                 ),
             )?;

-            self.engine.call_method0(py, "step")
+            self.engine.call_method0(py, intern!(py, "step"))
         })?;

         Ok(())
     }

+    fn py_step(&self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        Ok(Python::with_gil(|py| {
+            self.engine
+                .call_method0(py, intern!(py, "step"))?
+                .extract::<Vec<RequestOutput>>(py)
+        })?)
+    }
+
     pub fn from_engine_args(args: EngineArgs) -> Result<LlmEngine, VllmBackendError> {
         let engine = Self::py_from_engine_args(args)?;
         Ok(Self { engine })
     }

+    #[instrument(skip_all)]
     pub fn add_request(
         &self,
         prompt: &[u32],

@@ -159,5 +213,11 @@
         Ok(request_id)
     }

-    pub fn step(&mut self) {}
+    #[instrument(skip_all)]
+    pub fn abort_request(&self, _request_id: &str) {}
+
+    #[instrument(skip_all)]
+    pub fn step(&mut self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        self.py_step()
+    }
 }
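On the Python boundary, the new FromPyObject impls are what allow py_step() to turn the result of vLLM's step() into a Vec<RequestOutput> with one extract() call: pyo3 implements FromPyObject for Vec<T> whenever T implements it, and intern!() caches each attribute-name string in the interpreter so repeated getattr lookups stop allocating a fresh PyString per call. A cut-down sketch of the same getattr/extract pattern against a throwaway Python object (hypothetical Point type, embedded interpreter, pyo3 0.22-era bound API assumed):

use pyo3::intern;
use pyo3::prelude::*;

#[derive(Debug)]
struct Point {
    x: f64,
    y: f64,
}

impl<'py> FromPyObject<'py> for Point {
    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
        let py = ob.py();
        Ok(Self {
            // intern!() creates the attribute-name PyString once per process
            // and reuses it on every subsequent call.
            x: ob.getattr(intern!(py, "x"))?.extract()?,
            y: ob.getattr(intern!(py, "y"))?.extract()?,
        })
    }
}

fn main() -> PyResult<()> {
    pyo3::prepare_freethreaded_python();
    Python::with_gil(|py| {
        // Any Python object exposing matching attributes extracts transparently.
        let obj = py.eval_bound("type('P', (), {'x': 1.0, 'y': 2.0})()", None, None)?;
        let point: Point = obj.extract()?;
        println!("{point:?}");
        Ok(())
    })
}

The same shape scales up to nested extraction: RequestOutput pulls its outputs attribute as Vec<CompletionOutput>, which recursively reuses CompletionOutput's own impl.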