backend(vllm): submit new request to vLLM engine

Morgan Funtowicz 2025-01-27 22:39:35 +01:00
parent 02e4b9ab32
commit a7c2a470d6
7 changed files with 227 additions and 28 deletions
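The change routes new generation requests through an unbounded channel to a dedicated OS thread that owns the vLLM engine and drives it in a loop. Below is a minimal, self-contained sketch of that handoff pattern; FakeEngine, FakeRequest and background_loop are illustrative placeholders, not code from this commit.

use std::thread;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};

struct FakeRequest {
    tokens: Vec<u32>,
}

struct FakeEngine;

impl FakeEngine {
    fn add_request(&mut self, tokens: &[u32]) {
        println!("scheduled {} prompt tokens", tokens.len());
    }
    fn step(&mut self) {
        // advance generation for all in-flight requests
    }
}

fn background_loop(mut engine: FakeEngine, mut rx: UnboundedReceiver<FakeRequest>) {
    // blocking_recv returns None once every sender has been dropped,
    // which lets the loop terminate cleanly.
    while let Some(request) = rx.blocking_recv() {
        engine.add_request(&request.tokens);
        engine.step();
    }
}

fn main() {
    let (tx, rx) = unbounded_channel();
    let looper = thread::spawn(move || background_loop(FakeEngine, rx));
    tx.send(FakeRequest { tokens: vec![1, 2, 3] }).expect("queue closed");
    drop(tx); // close the queue so the background thread exits
    looper.join().expect("background thread panicked");
}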

Cargo.lock generated
View File

@ -4449,11 +4449,15 @@ version = "3.0.2-dev0"
dependencies = [
"async-trait",
"clap 4.5.21",
"log",
"pyo3",
"text-generation-router",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tracing",
"tracing-subscriber",
"uuid",
]
[[package]]

View File

@ -13,3 +13,7 @@ text-generation-router = { path = "../../router" }
thiserror = "2.0"
tokio = { version = "1.43", features = ["full"] }
tokio-stream = "0.1"
uuid = { version = "1.11.0", features = ["v4"] }
log = "0.4.22"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

View File

@ -1,18 +1,39 @@
use crate::errors::VllmBackendError;
use crate::{EngineArgs, LlmEngine};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::thread::{spawn, JoinHandle};
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::{
ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, info, warn};
type InferResult = Result<InferStreamResponse, InferError>;
struct Request {
tokens: Arc<Vec<u32>>,
params: ValidParameters,
stopping_params: ValidStoppingParameters,
streamer: UnboundedSender<InferResult>,
}
pub struct VllmBackend {
// The engine itself is owned by the background thread spawned below; the
// backend only keeps the thread handle and the sender side of the queue.
looper: JoinHandle<()>,
waiting_requests: UnboundedSender<Request>,
}
impl VllmBackend {
pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
let engine = LlmEngine::from_engine_args(args)?;
let (sender, receiver) = unbounded_channel();
let looper = spawn(move || engine_background_loop(engine, receiver));
Ok(Self {
looper,
waiting_requests: sender,
})
}
}
@ -21,12 +42,61 @@ impl VllmBackend {
impl Backend for VllmBackend {
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
let (sender, receiver) = unbounded_channel();
// Queue the request so the engine background loop can submit it to vLLM
if let Some(input_ids) = request.input_ids {
debug!("Attempt to queue new request");
if let Err(err) = self.waiting_requests.send(Request {
tokens: Arc::clone(&input_ids),
params: request.parameters,
stopping_params: request.stopping_parameters,
streamer: sender,
}) {
warn!("Waiting Requests queue has been closed: {err}")
}
};
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
true
}
}
fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver<Request>) {
info!("Starting vLLM engine background loop");
// Requests are keyed by the UUID returned from LlmEngine::add_request so that
// engine outputs can later be routed back to each request's streamer.
let mut in_flight_requests = HashMap::with_capacity(256);
loop {
// Drain every request that accumulated in the queue since the last iteration
// and submit them to vLLM as a batch.
if !waiting_requests.is_empty() {
let num_waiting_requests = waiting_requests.len();
debug!(
"Adding {} requests to the vLLM engine",
num_waiting_requests
);
let mut requests = Vec::with_capacity(num_waiting_requests);
waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests);
for request in requests {
match engine.add_request(&request.tokens, &request.params, &request.stopping_params)
{
Ok(request_id) => {
debug!("Successfully scheduled request {request_id}");
in_flight_requests.insert(request_id.to_string(), request);
}
Err(err) => {
warn!("Failed to schedule new request: {err}");
}
}
}
}
// Advance generation for all in-flight requests (a no-op stub in this commit).
engine.step();
}
info!("Shutting down vLLM engine background loop");
}
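For reference, a hypothetical caller-side sketch of consuming the stream returned by schedule above. consume_stream is illustrative only and assumes the imports and types already in scope in this file; it is not part of the commit.

use tokio_stream::StreamExt;

async fn consume_stream(
    backend: &VllmBackend,
    request: ValidGenerateRequest,
) -> Result<(), InferError> {
    // schedule() hands the request to the background loop and immediately
    // returns the receiving end of the per-request channel as a stream.
    let mut stream = backend.schedule(request)?;
    while let Some(event) = stream.next().await {
        match event {
            // Intermediate or end-of-generation event produced by the engine loop.
            Ok(_response) => { /* forward to the serving layer */ }
            Err(err) => return Err(err),
        }
    }
    // The stream ends once the background loop drops the request's sender.
    Ok(())
}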

View File

@ -1,6 +1,10 @@
use crate::errors::VllmBackendError;
use crate::{sampling_params, tokens_prompt, TryToPyObject};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
use tracing::info;
use uuid::Uuid;
pub struct EngineArgs {
pub model: String,
@ -31,28 +35,51 @@ impl IntoPyDict for EngineArgs {
}
}
// impl IntoPy<PyObject> for EngineArgs {
// fn into_py(self, py: Python<'_>) -> PyObject {
// PyDict::from_sequence_bound(
// PyList::new_bound(
// py,
// [
// ("model", self.model.into_py(py)),
// (
// "pipeline_parallel_size",
// self.pipeline_parallel_size.into_py(py),
// ),
// (
// "tensor_parallel_size",
// self.tensor_parallel_size.into_py(py),
// ),
// ],
// )
// .as_any(),
// )
// .expect("Failed to create Python Dict from EngineArgs")
// }
// }
pub struct SamplingParams<'a> {
sampling_params: &'a ValidParameters,
stopping_params: &'a ValidStoppingParameters,
}
impl TryToPyObject for SamplingParams<'_> {
fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr> {
let py_sampling_params_class = sampling_params(py);
let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
py,
[
("seed", self.sampling_params.seed.into_py(py)),
("n", 1.into_py(py)),
("top_k", self.sampling_params.top_k.into_py(py)),
("top_p", self.sampling_params.top_p.into_py(py)),
("temperature", self.sampling_params.temperature.into_py(py)),
(
"frequency_penalty",
self.sampling_params.frequency_penalty.into_py(py),
),
(
"repetition_penalty",
self.sampling_params.repetition_penalty.into_py(py),
),
(
"ignore_eos",
self.stopping_params.ignore_eos_token.into_py(py),
),
(
"max_tokens",
self.stopping_params.max_new_tokens.into_py(py),
),
(
"stop",
PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(),
),
],
));
Ok(py_sampling_params_class
.call_method_bound(py, "from_optional", (), Some(&kwargs?))?
.to_object(py))
}
}
pub struct LlmEngine {
engine: PyObject,
@ -80,11 +107,63 @@ impl LlmEngine {
})
}
fn py_add_request(
&self,
request_id: &str,
prompt: &[u32],
sampling_params: SamplingParams,
) -> Result<(), VllmBackendError> {
Python::with_gil(|py| {
// Create a vllm.inputs.TokensPrompt from the prompt token ids
let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py);
let py_tokens_prompt_class = tokens_prompt(py);
let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;
let py_sampling_params = sampling_params.try_to_object(py)?;
// Debug aid: print the Python-side SamplingParams object that was just built.
let _ = py.eval_bound(
"print(type(params), params)",
Some(&[("params", &py_sampling_params)].into_py_dict_bound(py)),
None,
);
self.engine.call_method1(
py,
"add_request",
(
PyString::new_bound(py, request_id),
py_tokens_prompt,
py_sampling_params,
),
)?;
self.engine.call_method0(py, "step")
})?;
Ok(())
}
pub fn from_engine_args(args: EngineArgs) -> Result<LlmEngine, VllmBackendError> {
let engine = Self::py_from_engine_args(args)?;
Ok(Self { engine })
}
pub fn add_request(
&self,
prompt: &[u32],
sampling_params: &ValidParameters,
stopping_params: &ValidStoppingParameters,
) -> Result<Uuid, VllmBackendError> {
let request_id = Uuid::new_v4();
let sampling_params = SamplingParams {
sampling_params,
stopping_params,
};
self.py_add_request(&request_id.to_string(), prompt, sampling_params)?;
info!("Submitted new request: {request_id}");
Ok(request_id)
}
// Stub for now; the background loop calls this on every iteration, but
// generation outputs are not surfaced yet in this commit.
pub fn step(&mut self) {}
}
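A hedged usage sketch of the LlmEngine API added above; submit_and_step is an illustrative helper (not part of the commit) and assumes the imports already present in this file.

fn submit_and_step(
    engine: &mut LlmEngine,
    prompt: &[u32],
    params: &ValidParameters,
    stopping_params: &ValidStoppingParameters,
) -> Result<Uuid, VllmBackendError> {
    // Queue the prompt with vLLM; the UUID is what the backend keeps in its
    // in-flight map to route outputs back to the right response stream.
    let request_id = engine.add_request(prompt, params, stopping_params)?;
    // step() is still a stub in this commit; the background loop calls it on
    // every iteration and is where generated tokens are expected to surface.
    engine.step();
    Ok(request_id)
}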

View File

@ -1,4 +1,5 @@
use pyo3::PyErr;
use text_generation_router::infer::InferError;
use text_generation_router::server::WebServerError;
use thiserror::Error;
@ -22,3 +23,9 @@ impl From<WebServerError> for VllmBackendError {
Self::WebServer(value)
}
}
impl From<VllmBackendError> for InferError {
fn from(value: VllmBackendError) -> Self {
InferError::GenerationError(value.to_string())
}
}

View File

@ -5,3 +5,36 @@ mod errors;
pub use backend::VllmBackend;
pub use engine::{EngineArgs, LlmEngine};
pub use errors::VllmBackendError;
use pyo3::prelude::PyAnyMethods;
use pyo3::sync::GILOnceCell;
use pyo3::types::PyModule;
use pyo3::{Py, PyAny, PyErr, PyObject, Python};
static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
#[inline]
pub(crate) fn tokens_prompt(py: Python) -> &Py<PyAny> {
PY_TOKENS_PROMPT_CLASS.get_or_init(py, || {
PyModule::import_bound(py, "vllm.inputs")
.expect("Failed to import vllm.inputs")
.getattr("TokensPrompt")
.expect("Failed to import vllm.inputs.TokensPrompt")
.unbind()
})
}
#[inline]
pub(crate) fn sampling_params(py: Python) -> &Py<PyAny> {
PY_SAMPLING_PARAMS_CLASS.get_or_init(py, || {
PyModule::import_bound(py, "vllm")
.expect("Failed to import vllm")
.getattr("SamplingParams")
.expect("Failed to import vllm.SamplingParams")
.unbind()
})
}
pub(crate) trait TryToPyObject {
fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr>;
}

View File

@ -73,6 +73,8 @@ impl Into<EngineArgs> for &Args {
#[tokio::main]
async fn main() -> Result<(), VllmBackendError> {
tracing_subscriber::fmt::init();
let args = Args::parse();
let backend = VllmBackend::from_engine_args((&args).into())?;