mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

backend(vllm): submit new request to vLLM engine

parent 02e4b9ab32
commit a7c2a470d6
Cargo.lock (generated): 4 additions
@@ -4449,11 +4449,15 @@ version = "3.0.2-dev0"
dependencies = [
 "async-trait",
 "clap 4.5.21",
 "log",
 "pyo3",
 "text-generation-router",
 "thiserror 2.0.11",
 "tokio",
 "tokio-stream",
 "tracing",
 "tracing-subscriber",
 "uuid",
]

[[package]]
@@ -13,3 +13,7 @@ text-generation-router = { path = "../../router" }
thiserror = "2.0"
tokio = { version = "1.43", features = ["full"] }
tokio-stream = "0.1"
uuid = { version = "1.11.0", features = ["v4"] }
log = "0.4.22"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
@@ -1,18 +1,39 @@
use crate::errors::VllmBackendError;
use crate::{EngineArgs, LlmEngine};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::thread::{spawn, JoinHandle};
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::{
    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, info, warn};

type InferResult = Result<InferStreamResponse, InferError>;

/// A single generation request waiting to be handed over to the vLLM engine.
struct Request {
    tokens: Arc<Vec<u32>>,
    params: ValidParameters,
    stopping_params: ValidStoppingParameters,
    streamer: UnboundedSender<InferResult>,
}

pub struct VllmBackend {
    looper: JoinHandle<()>,
    waiting_requests: UnboundedSender<Request>,
}

impl VllmBackend {
    pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
        let engine = LlmEngine::from_engine_args(args)?;
        let (sender, receiver) = unbounded_channel();

        // The engine is moved into a dedicated OS thread that owns it and drives
        // generation; the backend keeps only the sender side of the waiting-requests
        // queue and the thread handle.
        let looper = spawn(move || engine_background_loop(engine, receiver));

        Ok(Self {
            looper,
            waiting_requests: sender,
        })
    }
}
@@ -21,12 +42,61 @@ impl VllmBackend {
impl Backend for VllmBackend {
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        let (sender, receiver) = unbounded_channel();

        // Send the query to the vLLM Engine
        if let Some(input_ids) = request.input_ids {
            debug!("Attempt to queue new request");
            if let Err(err) = self.waiting_requests.send(Request {
                tokens: Arc::clone(&input_ids),
                params: request.parameters,
                stopping_params: request.stopping_parameters,
                streamer: sender,
            }) {
                warn!("Waiting Requests queue has been closed: {err}")
            }
        };

        Ok(UnboundedReceiverStream::new(receiver))
    }

    async fn health(&self, _current_health: bool) -> bool {
        true
    }
}

fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver<Request>) {
    info!("Starting vLLM engine background loop");

    let mut in_flight_requests = HashMap::with_capacity(256);
    loop {
        // Drain every currently waiting request and hand it over to the engine.
        if !waiting_requests.is_empty() {
            let num_waiting_requests = waiting_requests.len();
            debug!(
                "Adding {} requests to the vLLM engine",
                num_waiting_requests
            );

            let mut requests = Vec::with_capacity(num_waiting_requests);
            waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests);

            for request in requests {
                match engine.add_request(&request.tokens, &request.params, &request.stopping_params)
                {
                    Ok(request_id) => {
                        debug!("Successfully scheduled request {request_id}");
                        in_flight_requests.insert(request_id.to_string(), request);
                    }
                    Err(err) => {
                        warn!("Failed to schedule new request: {err}");
                    }
                }
            }
        }

        engine.step();
    }

    info!("Shutting down vLLM engine background loop");
}
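Note: `schedule()` never blocks; it only pushes a `Request` onto the unbounded channel and hands the receiving half back as a stream, while the spawned OS thread owns the engine and drains the queue. A minimal, self-contained sketch of that channel-plus-background-thread pattern (illustrative names, not part of this commit):

use std::thread;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

// Stand-in for the backend's `Request`: a payload plus the sender used to
// stream results back to the caller.
struct Job {
    prompt: Vec<u32>,
    streamer: UnboundedSender<String>,
}

// Stand-in for `engine_background_loop`: owns the "engine", blocks on the
// queue, and answers each job through its streamer.
fn background_loop(mut jobs: UnboundedReceiver<Job>) {
    while let Some(job) = jobs.blocking_recv() {
        let _ = job
            .streamer
            .send(format!("generated from {} prompt tokens", job.prompt.len()));
    }
}

// Stand-in for `VllmBackend::schedule`: enqueue the job, hand back a stream.
fn schedule(queue: &UnboundedSender<Job>, prompt: Vec<u32>) -> UnboundedReceiverStream<String> {
    let (sender, receiver) = unbounded_channel();
    let _ = queue.send(Job {
        prompt,
        streamer: sender,
    });
    UnboundedReceiverStream::new(receiver)
}

#[tokio::main]
async fn main() {
    let (queue, jobs) = unbounded_channel();
    let looper = thread::spawn(move || background_loop(jobs));

    let mut stream = schedule(&queue, vec![1, 2, 3]);
    while let Some(chunk) = stream.next().await {
        println!("{chunk}");
    }

    drop(queue); // closing the queue lets the background thread exit
    looper.join().unwrap();
}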
@@ -1,6 +1,10 @@
use crate::errors::VllmBackendError;
use crate::{sampling_params, tokens_prompt, TryToPyObject};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
use tracing::info;
use uuid::Uuid;

pub struct EngineArgs {
    pub model: String,
@@ -31,28 +35,51 @@ impl IntoPyDict for EngineArgs {
    }
}

pub struct SamplingParams<'a> {
    sampling_params: &'a ValidParameters,
    stopping_params: &'a ValidStoppingParameters,
}

impl TryToPyObject for SamplingParams<'_> {
    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr> {
        let py_sampling_params_class = sampling_params(py);

        // Keyword arguments forwarded to vllm.SamplingParams.from_optional(...)
        let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
            py,
            [
                ("seed", self.sampling_params.seed.into_py(py)),
                ("n", 1.into_py(py)),
                ("top_k", self.sampling_params.top_k.into_py(py)),
                ("top_p", self.sampling_params.top_p.into_py(py)),
                ("temperature", self.sampling_params.temperature.into_py(py)),
                (
                    "frequency_penalty",
                    self.sampling_params.frequency_penalty.into_py(py),
                ),
                (
                    "repetition_penalty",
                    self.sampling_params.repetition_penalty.into_py(py),
                ),
                (
                    "ignore_eos",
                    self.stopping_params.ignore_eos_token.into_py(py),
                ),
                (
                    "max_tokens",
                    self.stopping_params.max_new_tokens.into_py(py),
                ),
                (
                    "stop",
                    PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(),
                ),
            ],
        ));

        Ok(py_sampling_params_class
            .call_method_bound(py, "from_optional", (), Some(&kwargs?))?
            .to_object(py))
    }
}

pub struct LlmEngine {
    engine: PyObject,
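Note: the kwargs above map TGI's validation fields onto vLLM's `SamplingParams.from_optional` names: `ignore_eos_token` becomes `ignore_eos`, `max_new_tokens` becomes `max_tokens`, and `stop_sequences` is passed as the `stop` list; `n` is hard-coded to 1, presumably because the router expects a single completion per scheduled request.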
@@ -80,11 +107,63 @@ impl LlmEngine {
        })
    }

    fn py_add_request(
        &self,
        request_id: &str,
        prompt: &[u32],
        sampling_params: SamplingParams,
    ) -> Result<(), VllmBackendError> {
        Python::with_gil(|py| {
            // Create the vllm.TokensPrompt wrapping the already tokenized prompt
            let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py);
            let py_tokens_prompt_class = tokens_prompt(py);
            let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;
            let py_sampling_params = sampling_params.try_to_object(py)?;

            // Debug helper: print the Python-side SamplingParams object
            let _ = py.eval_bound(
                "print(type(params), params)",
                Some(&[("params", &py_sampling_params)].into_py_dict_bound(py)),
                None,
            );

            self.engine.call_method1(
                py,
                "add_request",
                (
                    PyString::new_bound(py, request_id),
                    py_tokens_prompt,
                    py_sampling_params,
                ),
            )?;

            self.engine.call_method0(py, "step")
        })?;

        Ok(())
    }

    pub fn from_engine_args(args: EngineArgs) -> Result<LlmEngine, VllmBackendError> {
        let engine = Self::py_from_engine_args(args)?;

        Ok(Self { engine })
    }

    pub fn add_request(
        &self,
        prompt: &[u32],
        sampling_params: &ValidParameters,
        stopping_params: &ValidStoppingParameters,
    ) -> Result<Uuid, VllmBackendError> {
        let request_id = Uuid::new_v4();
        let sampling_params = SamplingParams {
            sampling_params,
            stopping_params,
        };
        self.py_add_request(&request_id.to_string(), prompt, sampling_params)?;

        info!("Submitted new request: {request_id}");
        Ok(request_id)
    }

    pub fn step(&mut self) {}
}
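Note: `py_add_request` follows the usual pyo3 "bound" calling pattern: resolve a Python callable, build a kwargs dict with `into_py_dict_bound`, and invoke it under the GIL. A standalone sketch of that same pattern against the Python standard library (no vLLM needed; assumes pyo3 with the `auto-initialize` feature or an already-initialized interpreter):

use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyModule};

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Resolve a Python callable, the way tokens_prompt()/sampling_params()
        // resolve the vLLM classes.
        let date_class = PyModule::import_bound(py, "datetime")?.getattr("date")?;

        // Build keyword arguments and call the class, mirroring
        // `py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))`.
        let kwargs = [("year", 2025), ("month", 1), ("day", 17)].into_py_dict_bound(py);
        let date = date_class.call((), Some(&kwargs))?;

        // Invoke a method with positional arguments, mirroring
        // `engine.call_method1(py, "add_request", (...))`.
        let formatted: String = date.call_method1("strftime", ("%Y-%m-%d",))?.extract()?;
        println!("{formatted}");
        Ok(())
    })
}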
@@ -1,4 +1,5 @@
use pyo3::PyErr;
use text_generation_router::infer::InferError;
use text_generation_router::server::WebServerError;
use thiserror::Error;

@@ -22,3 +23,9 @@ impl From<WebServerError> for VllmBackendError {
        Self::WebServer(value)
    }
}

impl From<VllmBackendError> for InferError {
    fn from(value: VllmBackendError) -> Self {
        InferError::GenerationError(value.to_string())
    }
}
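Note: the new `From<VllmBackendError> for InferError` impl is what lets the backend bubble Python-side failures up to the router with a plain `?`. A small self-contained sketch of that conversion chain using `thiserror` (type names here are illustrative stand-ins, not the crate's real definitions):

use thiserror::Error;

// Illustrative stand-ins; only the conversion chain mirrors the diff.
#[derive(Debug, Error)]
#[error("python error: {0}")]
struct PythonError(String);

#[derive(Debug, Error)]
enum BackendError {
    #[error("engine failure: {0}")]
    Engine(#[from] PythonError),
}

#[derive(Debug, Error)]
enum InferenceError {
    #[error("generation error: {0}")]
    Generation(String),
}

// Counterpart of `impl From<VllmBackendError> for InferError` in the diff.
impl From<BackendError> for InferenceError {
    fn from(value: BackendError) -> Self {
        InferenceError::Generation(value.to_string())
    }
}

fn call_python() -> Result<(), PythonError> {
    Err(PythonError("add_request raised".into()))
}

fn add_request() -> Result<(), BackendError> {
    call_python()?; // PythonError -> BackendError via #[from]
    Ok(())
}

fn schedule() -> Result<(), InferenceError> {
    add_request()?; // BackendError -> InferenceError via the manual From impl
    Ok(())
}

fn main() {
    println!("{:?}", schedule());
}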
@@ -5,3 +5,36 @@ mod errors;
pub use backend::VllmBackend;
pub use engine::{EngineArgs, LlmEngine};
pub use errors::VllmBackendError;
use pyo3::prelude::PyAnyMethods;
use pyo3::sync::GILOnceCell;
use pyo3::types::PyModule;
use pyo3::{Py, PyAny, PyErr, PyObject, Python};

// Lazily imported Python classes, resolved once and cached for the process lifetime.
static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

#[inline]
pub(crate) fn tokens_prompt(py: Python) -> &Py<PyAny> {
    PY_TOKENS_PROMPT_CLASS.get_or_init(py, || {
        PyModule::import_bound(py, "vllm.inputs")
            .expect("Failed to import vllm.inputs")
            .getattr("TokensPrompt")
            .expect("Failed to import vllm.inputs.TokensPrompt")
            .unbind()
    })
}

#[inline]
pub(crate) fn sampling_params(py: Python) -> &Py<PyAny> {
    PY_SAMPLING_PARAMS_CLASS.get_or_init(py, || {
        PyModule::import_bound(py, "vllm")
            .expect("Failed to import vllm")
            .getattr("SamplingParams")
            .expect("Failed to import vllm.SamplingParams")
            .unbind()
    })
}

pub(crate) trait TryToPyObject {
    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr>;
}
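Note: `GILOnceCell` resolves `vllm.inputs.TokensPrompt` and `vllm.SamplingParams` once per process and serves the cached objects on every later call, so the Python import cost is not paid per request. A runnable sketch of the same caching pattern against the standard library (assumes pyo3 with `auto-initialize`; names are illustrative):

use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::PyModule;

// Resolved on first use, then served from the cache for the rest of the process.
static JSON_DUMPS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

fn json_dumps(py: Python<'_>) -> &Py<PyAny> {
    JSON_DUMPS.get_or_init(py, || {
        PyModule::import_bound(py, "json")
            .expect("Failed to import json")
            .getattr("dumps")
            .expect("Failed to resolve json.dumps")
            .unbind()
    })
}

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Both calls reuse the same cached json.dumps object; the import runs once.
        let first: String = json_dumps(py).call1(py, ([1, 2, 3],))?.extract(py)?;
        let second: String = json_dumps(py).call1(py, ([4, 5, 6],))?.extract(py)?;
        println!("{first} / {second}");
        Ok(())
    })
}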
@@ -73,6 +73,8 @@ impl Into<EngineArgs> for &Args {

#[tokio::main]
async fn main() -> Result<(), VllmBackendError> {
    tracing_subscriber::fmt::init();

    let args = Args::parse();
    let backend = VllmBackend::from_engine_args((&args).into())?;