Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)

backend(vllm): submit new request to vLLM engine

This commit is contained in:
parent 02e4b9ab32
commit a7c2a470d6

Cargo.lock (generated)

@@ -4449,11 +4449,15 @@ version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.21",
+ "log",
  "pyo3",
  "text-generation-router",
  "thiserror 2.0.11",
  "tokio",
  "tokio-stream",
+ "tracing",
+ "tracing-subscriber",
+ "uuid",
 ]

 [[package]]

Backend crate manifest (Cargo.toml)

@@ -13,3 +13,7 @@ text-generation-router = { path = "../../router" }
 thiserror = "2.0"
 tokio = { version = "1.43", features = ["full"] }
 tokio-stream = "0.1"
+uuid = { version = "1.11.0", features = ["v4"] }
+log = "0.4.22"
+tracing = "0.1.40"
+tracing-subscriber = "0.3.18"

Backend module

@@ -1,18 +1,39 @@
 use crate::errors::VllmBackendError;
 use crate::{EngineArgs, LlmEngine};
 use async_trait::async_trait;
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::thread::{spawn, JoinHandle};
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{
+    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
+};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{debug, info, warn};
+
+type InferResult = Result<InferStreamResponse, InferError>;
+
+struct Request {
+    tokens: Arc<Vec<u32>>,
+    params: ValidParameters,
+    stopping_params: ValidStoppingParameters,
+    streamer: UnboundedSender<InferResult>,
+}
 
 pub struct VllmBackend {
-    engine: LlmEngine,
+    looper: JoinHandle<()>,
+    waiting_requests: UnboundedSender<Request>,
 }
 
 impl VllmBackend {
     pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
+        let engine = LlmEngine::from_engine_args(args)?;
+        let (sender, receiver) = unbounded_channel();
+        let looper = spawn(|| engine_background_loop(engine, receiver));
         Ok(Self {
-            engine: LlmEngine::from_engine_args(args)?,
+            looper,
+            waiting_requests: sender,
         })
     }
 }
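
The constructor above wires up a common hand-off pattern: the async router side keeps only the UnboundedSender, while a dedicated OS thread owns the engine and drains the channel. Below is a minimal, self-contained sketch of that pattern; FakeEngine, background_loop, and the String payload are illustrative stand-ins (not part of this backend), and blocking_recv is used instead of the batched receive the real loop performs.

use std::thread;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};

// Stand-in for the vLLM engine; the background thread owns it exclusively.
struct FakeEngine;

impl FakeEngine {
    fn add_request(&mut self, prompt: &str) {
        println!("engine received: {prompt}");
    }
}

// Background loop: owns the receiver and the engine, blocks outside the async runtime.
fn background_loop(mut engine: FakeEngine, mut rx: UnboundedReceiver<String>) {
    while let Some(prompt) = rx.blocking_recv() {
        engine.add_request(&prompt);
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = unbounded_channel();
    // Hand the engine and the receiving half to a dedicated OS thread,
    // mirroring the spawn(|| engine_background_loop(..)) call above.
    let looper = thread::spawn(move || background_loop(FakeEngine, rx));

    // The async side only ever touches the sender.
    tx.send("hello".to_string()).expect("queue closed");

    drop(tx); // closing the channel lets the background loop exit
    looper.join().unwrap();
}

One likely motivation for a plain thread rather than a Tokio task is that the engine work is blocking (it crosses into Python), so keeping it off the async runtime avoids stalling other tasks.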

@@ -21,12 +42,61 @@ impl VllmBackend {
 impl Backend for VllmBackend {
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        todo!()
+        let (sender, receiver) = unbounded_channel();
+
+        // Send the query to the vLLM Engine
+        if let Some(input_ids) = request.input_ids {
+            debug!("Attempt to queue new request");
+            if let Err(err) = self.waiting_requests.send(Request {
+                tokens: Arc::clone(&input_ids),
+                params: request.parameters,
+                stopping_params: request.stopping_parameters,
+                streamer: sender,
+            }) {
+                warn!("Waiting Requests queue has been closed: {err}")
+            }
+        };
+
+        Ok(UnboundedReceiverStream::new(receiver))
     }
 
     async fn health(&self, _current_health: bool) -> bool {
         true
     }
 }
+
+fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver<Request>) {
+    info!("Starting vLLM engine background loop");
+
+    let mut in_flight_requests = HashMap::with_capacity(256);
+    loop {
+        if !waiting_requests.is_empty() {
+            let num_waiting_requests = waiting_requests.len();
+            debug!(
+                "Adding {} requests to the vLLM engine",
+                num_waiting_requests
+            );
+
+            let mut requests = Vec::with_capacity(num_waiting_requests);
+            waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests);
+
+            for request in requests {
+                match engine.add_request(&request.tokens, &request.params, &request.stopping_params)
+                {
+                    Ok(request_id) => {
+                        debug!("Successfully scheduled request {request_id}");
+                        in_flight_requests.insert(request_id.to_string(), request);
+                    }
+                    Err(err) => {
+                        warn!("Failed to schedule new request: {err}");
+                    }
+                }
+            }
+        }
+        engine.step();
+    }
+
+    info!("Shutting down vLLM engine background loop");
+}
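
schedule returns the receiving half of a channel wrapped in an UnboundedReceiverStream, while the queued Request keeps the matching sender as its streamer. The sketch below shows roughly how such a stream is produced and consumed; it uses plain String payloads in place of InferStreamResponse, and a manual send stands in for the token forwarding that this commit does not yet implement.

use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // `streamer` plays the role of the sender stored in a queued Request;
    // the receiver is what a caller of `schedule` gets back, wrapped as a stream.
    let (streamer, receiver) = unbounded_channel::<Result<String, String>>();
    let mut stream = UnboundedReceiverStream::new(receiver);

    // Elsewhere (the background loop, once token forwarding is wired up),
    // intermediate results would be pushed through the sender.
    streamer.send(Ok("first chunk".to_string())).unwrap();
    streamer.send(Ok("second chunk".to_string())).unwrap();
    drop(streamer); // closing the sender ends the stream

    // The consumer simply awaits items until the stream is exhausted.
    while let Some(event) = stream.next().await {
        match event {
            Ok(chunk) => println!("got: {chunk}"),
            Err(err) => eprintln!("generation failed: {err}"),
        }
    }
}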

Engine module

@@ -1,6 +1,10 @@
 use crate::errors::VllmBackendError;
+use crate::{sampling_params, tokens_prompt, TryToPyObject};
 use pyo3::prelude::*;
-use pyo3::types::{IntoPyDict, PyDict, PyList};
+use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
+use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
+use tracing::info;
+use uuid::Uuid;
 
 pub struct EngineArgs {
     pub model: String,

@@ -31,28 +35,51 @@ impl IntoPyDict for EngineArgs {
     }
 }
 
-// impl IntoPy<PyObject> for EngineArgs {
-//     fn into_py(self, py: Python<'_>) -> PyObject {
-//         PyDict::from_sequence_bound(
-//             PyList::new_bound(
-//                 py,
-//                 [
-//                     ("model", self.model.into_py(py)),
-//                     (
-//                         "pipeline_parallel_size",
-//                         self.pipeline_parallel_size.into_py(py),
-//                     ),
-//                     (
-//                         "tensor_parallel_size",
-//                         self.tensor_parallel_size.into_py(py),
-//                     ),
-//                 ],
-//             )
-//             .as_any(),
-//         )
-//         .expect("Failed to create Python Dict from EngineArgs")
-//     }
-// }
+pub struct SamplingParams<'a> {
+    sampling_params: &'a ValidParameters,
+    stopping_params: &'a ValidStoppingParameters,
+}
+
+impl TryToPyObject for SamplingParams<'_> {
+    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr> {
+        let py_sampling_params_class = sampling_params(py);
+
+        let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
+            py,
+            [
+                ("seed", self.sampling_params.seed.into_py(py)),
+                ("n", 1.into_py(py)),
+                ("top_k", self.sampling_params.top_k.into_py(py)),
+                ("top_p", self.sampling_params.top_p.into_py(py)),
+                ("temperature", self.sampling_params.temperature.into_py(py)),
+                (
+                    "frequency_penalty",
+                    self.sampling_params.frequency_penalty.into_py(py),
+                ),
+                (
+                    "repetition_penalty",
+                    self.sampling_params.repetition_penalty.into_py(py),
+                ),
+                (
+                    "ignore_eos",
+                    self.stopping_params.ignore_eos_token.into_py(py),
+                ),
+                (
+                    "max_tokens",
+                    self.stopping_params.max_new_tokens.into_py(py),
+                ),
+                (
+                    "stop",
+                    PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(),
+                ),
+            ],
+        ));
+
+        Ok(py_sampling_params_class
+            .call_method_bound(py, "from_optional", (), Some(&kwargs?))?
+            .to_object(py))
+    }
+}
 
 pub struct LlmEngine {
     engine: PyObject,

@@ -80,11 +107,63 @@ impl LlmEngine {
         })
     }
 
+    fn py_add_request(
+        &self,
+        request_id: &str,
+        prompt: &[u32],
+        sampling_params: SamplingParams,
+    ) -> Result<(), VllmBackendError> {
+        Python::with_gil(|py| {
+            // Create vllm.Tokens
+            let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py);
+            let py_tokens_prompt_class = tokens_prompt(py);
+            let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;
+            let py_sampling_params = sampling_params.try_to_object(py)?;
+
+            let _ = py.eval_bound(
+                "print(type(params), params)",
+                Some(&[("params", &py_sampling_params)].into_py_dict_bound(py)),
+                None,
+            );
+
+            self.engine.call_method1(
+                py,
+                "add_request",
+                (
+                    PyString::new_bound(py, request_id),
+                    py_tokens_prompt,
+                    py_sampling_params,
+                ),
+            )?;
+
+            self.engine.call_method0(py, "step")
+        })?;
+
+        Ok(())
+    }
+
     pub fn from_engine_args(args: EngineArgs) -> Result<LlmEngine, VllmBackendError> {
         let engine = Self::py_from_engine_args(args)?;
+
         Ok(Self { engine })
     }
 
+    pub fn add_request(
+        &self,
+        prompt: &[u32],
+        sampling_params: &ValidParameters,
+        stopping_params: &ValidStoppingParameters,
+    ) -> Result<Uuid, VllmBackendError> {
+        let request_id = Uuid::new_v4();
+        let sampling_params = SamplingParams {
+            sampling_params,
+            stopping_params,
+        };
+        self.py_add_request(&request_id.to_string(), prompt, sampling_params)?;
+
+        info!("Submitted new request: {request_id}");
+        Ok(request_id)
+    }
+
     pub fn step(&mut self) {}
 }
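
add_request tags each submission with a random v4 UUID before handing it to the Python engine, and the background loop keys its in-flight map on the string form of that id. A tiny sketch of that id scheme; the map name and payload type here are illustrative only.

use std::collections::HashMap;
use uuid::Uuid;

fn main() {
    // Each queued prompt gets a fresh, collision-resistant identifier...
    let mut in_flight: HashMap<String, Vec<u32>> = HashMap::new();

    let prompt: Vec<u32> = vec![1, 2, 3];
    let request_id = Uuid::new_v4();

    // ...and its string form is what the Python side sees as `request_id`.
    in_flight.insert(request_id.to_string(), prompt);
    println!("submitted request {request_id}");
}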

Errors module

@@ -1,4 +1,5 @@
 use pyo3::PyErr;
+use text_generation_router::infer::InferError;
 use text_generation_router::server::WebServerError;
 use thiserror::Error;
 
@@ -22,3 +23,9 @@ impl From<WebServerError> for VllmBackendError {
         Self::WebServer(value)
     }
 }
+
+impl From<VllmBackendError> for InferError {
+    fn from(value: VllmBackendError) -> Self {
+        InferError::GenerationError(value.to_string())
+    }
+}

Crate root

@@ -5,3 +5,36 @@ mod errors;
 pub use backend::VllmBackend;
 pub use engine::{EngineArgs, LlmEngine};
 pub use errors::VllmBackendError;
+use pyo3::prelude::PyAnyMethods;
+use pyo3::sync::GILOnceCell;
+use pyo3::types::PyModule;
+use pyo3::{Py, PyAny, PyErr, PyObject, Python};
+
+static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+
+#[inline]
+pub(crate) fn tokens_prompt(py: Python) -> &Py<PyAny> {
+    PY_TOKENS_PROMPT_CLASS.get_or_init(py, || {
+        PyModule::import_bound(py, "vllm.inputs")
+            .expect("Failed to import vllm.inputs")
+            .getattr("TokensPrompt")
+            .expect("Failed to import vllm.inputs.TokensPrompt")
+            .unbind()
+    })
+}
+
+#[inline]
+pub(crate) fn sampling_params(py: Python) -> &Py<PyAny> {
+    PY_SAMPLING_PARAMS_CLASS.get_or_init(py, || {
+        PyModule::import_bound(py, "vllm")
+            .expect("Failed to import vllm")
+            .getattr("SamplingParams")
+            .expect("Failed to import vllm.SamplingParams")
+            .unbind()
+    })
+}
+
+pub(crate) trait TryToPyObject {
+    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr>;
+}
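
The two GILOnceCell statics cache the imported vLLM classes so the Python import runs at most once per process and later lookups are cheap. The same initialize-once idea is sketched below with std's OnceLock, as a plain-Rust analogy rather than the actual GIL-aware cell; the handle and its contents are made up for illustration.

use std::sync::OnceLock;

// Stand-in for an expensive, do-it-once lookup (like importing vllm.SamplingParams).
static EXPENSIVE_HANDLE: OnceLock<String> = OnceLock::new();

fn expensive_handle() -> &'static str {
    EXPENSIVE_HANDLE
        .get_or_init(|| {
            println!("initializing once...");
            "vllm.SamplingParams (pretend handle)".to_string()
        })
        .as_str()
}

fn main() {
    // The first call pays the initialization cost; later calls reuse the cached value.
    println!("{}", expensive_handle());
    println!("{}", expensive_handle());
}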

Binary entrypoint (main)

@@ -73,6 +73,8 @@ impl Into<EngineArgs> for &Args {
 
 #[tokio::main]
 async fn main() -> Result<(), VllmBackendError> {
+    tracing_subscriber::fmt::init();
+
     let args = Args::parse();
     let backend = VllmBackend::from_engine_args((&args).into())?;
 
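
tracing_subscriber::fmt::init() installs a global formatting subscriber, so the debug!/info!/warn! calls added throughout the backend are actually emitted; without a subscriber they would be silently dropped. A minimal sketch of that setup follows; which levels appear depends on the subscriber configuration and environment, so treat the output as indicative.

use tracing::{debug, info, warn};

fn main() {
    // Install a global formatting subscriber, as main() now does for the backend.
    tracing_subscriber::fmt::init();

    info!("starting up");
    debug!("this only shows at a verbose log level");
    warn!("something looks off");
}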