mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 21:12:07 +00:00

backend(vllm): expose FFI for CompletionOutput and RequestOutput on Rust side

parent 7028f5bce2
commit 32dffcff60

Changed files: Cargo.lock (generated) and the vLLM backend crate files shown below.
Cargo.lock (generated)

@@ -4449,6 +4449,7 @@ version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.21",
+ "crossbeam-channel",
  "log",
  "pyo3",
  "text-generation-router",
backends/vllm/Cargo.toml

@@ -8,6 +8,7 @@ homepage.workspace = true
 [dependencies]
 async-trait = "0.1.83"
 clap = { version = "4.5.21", features = ["derive"] }
+crossbeam-channel = "0.5"
 pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 thiserror = "2.0"
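crossbeam-channel is pulled in because its `Receiver` works from a plain OS thread, with timed receives and an explicit disconnection signal, where tokio's `mpsc` receiver needs the async runtime. A standalone sketch (illustrative, not TGI code) of exactly the two behaviours the new background loop relies on:

use crossbeam_channel::{unbounded, RecvTimeoutError};
use std::time::Duration;

fn main() {
    let (tx, rx) = unbounded::<u32>();

    // `send` on an unbounded channel never blocks, so it is safe to call
    // from an async task without .await.
    tx.send(42).unwrap();
    assert_eq!(rx.recv_timeout(Duration::from_millis(100)), Ok(42));

    // Dropping every sender disconnects the channel; a waiting
    // `recv_timeout` then returns `Disconnected` instead of `Timeout`.
    drop(tx);
    assert_eq!(
        rx.recv_timeout(Duration::from_millis(100)),
        Err(RecvTimeoutError::Disconnected)
    );
}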
backends/vllm/src/backend.rs

@@ -1,38 +1,42 @@
 use crate::errors::VllmBackendError;
 use crate::{EngineArgs, LlmEngine};
 use async_trait::async_trait;
-use std::collections::HashMap;
+use crossbeam_channel::internal::SelectHandle;
+use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
+use std::collections::{HashMap, HashSet};
+use std::hint::spin_loop;
+use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use std::thread::{spawn, JoinHandle};
+use std::time::Duration;
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
+use text_generation_router::Token;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, info, warn};
+use tracing::{debug, error, info, warn};
 
 type InferResult = Result<InferStreamResponse, InferError>;
 
-struct Request {
+struct VllmRequestContext {
     tokens: Arc<Vec<u32>>,
     params: ValidParameters,
     stopping_params: ValidStoppingParameters,
-    streamer: UnboundedSender<InferResult>,
+    stream: UnboundedSender<InferResult>,
 }
 
 pub struct VllmBackend {
-    looper: JoinHandle<()>,
-    waiting_requests: UnboundedSender<Request>,
+    waiting_requests: Sender<VllmRequestContext>,
 }
 
 impl VllmBackend {
     pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
         let engine = LlmEngine::from_engine_args(args)?;
-        let (sender, receiver) = unbounded_channel();
+        let (sender, receiver) = unbounded();
         let looper = spawn(|| engine_background_loop(engine, receiver));
         Ok(Self {
-            looper,
             waiting_requests: sender,
         })
     }
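The constructor change above is the architectural core of the commit: the backend now keeps only a crossbeam `Sender` and detaches the engine thread, relying on channel disconnection rather than a stored `JoinHandle` for shutdown. A minimal sketch of that ownership shape (stand-in types, not the TGI structs themselves):

use crossbeam_channel::{unbounded, Receiver, Sender};
use std::thread::spawn;

struct Engine; // stand-in for the pyo3-backed LlmEngine
struct RequestContext; // stand-in for VllmRequestContext

struct Backend {
    // Only the sender is kept: dropping the Backend drops the last
    // Sender, which disconnects the channel.
    waiting_requests: Sender<RequestContext>,
}

impl Backend {
    fn new(engine: Engine) -> Self {
        let (sender, receiver) = unbounded();
        // The JoinHandle is intentionally discarded: the thread is
        // detached and shuts down via channel disconnection instead.
        let _detached = spawn(move || background_loop(engine, receiver));
        Self { waiting_requests: sender }
    }
}

fn background_loop(_engine: Engine, _rx: Receiver<RequestContext>) {
    // ... recv_timeout / step loop, as in the diff ...
}

When `Backend` is dropped, the last `Sender` goes with it; the loop's next `recv_timeout` returns `Disconnected` and the thread exits cleanly.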
@@ -48,12 +52,12 @@ impl Backend for VllmBackend {
 
         // Send the query to the vLLM Engine
         if let Some(input_ids) = request.input_ids {
-            debug!("Attempt to queue new request");
-            if let Err(err) = self.waiting_requests.send(Request {
+            debug!("Queuing new request");
+            if let Err(err) = self.waiting_requests.send(VllmRequestContext {
                 tokens: Arc::clone(&input_ids),
                 params: request.parameters,
                 stopping_params: request.stopping_parameters,
-                streamer: sender,
+                stream: sender,
             }) {
                 warn!("Waiting Requests queue has been closed: {err}")
             }
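Each request carries its own tokio `UnboundedSender` (the `stream` field) so the synchronous loop thread can push results back to the async router without blocking; the paired receiver becomes the response stream. A hedged sketch of that sync-to-async bridge, with a toy payload type instead of the real `InferResult`:

// Assumes tokio with "rt-multi-thread" + "macros" features and tokio-stream.
use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    let (sender, receiver) = unbounded_channel::<u32>();

    // The sender travels to the sync engine thread inside the request
    // context; `send` needs no .await, so a plain thread can use it.
    std::thread::spawn(move || {
        for t in 0..3 {
            if sender.send(t).is_err() {
                // Receiver dropped: the client went away, stop producing.
                break;
            }
        }
    });

    // The async side consumes the receiver as a Stream, which is what
    // Backend::schedule hands back to the router.
    let mut stream = UnboundedReceiverStream::new(receiver);
    while let Some(token) = stream.next().await {
        println!("token {token}");
    }
}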
@@ -67,35 +71,55 @@ impl Backend for VllmBackend {
     }
 }
 
-fn engine_background_loop(mut engine: LlmEngine, mut waiting_requests: UnboundedReceiver<Request>) {
+fn engine_background_loop(
+    mut engine: LlmEngine,
+    mut waiting_requests: Receiver<VllmRequestContext>,
+) {
     info!("Starting vLLM engine background loop");
+    static DURATION_100_MS: Duration = Duration::from_millis(100);
     let mut in_flight_requests = HashMap::with_capacity(256);
-    loop {
+    'outer: loop {
         if !waiting_requests.is_empty() {
-            let num_waiting_requests = waiting_requests.len();
-            debug!(
-                "Adding {} requests to the vLLM engine",
-                num_waiting_requests
-            );
-
-            let mut requests = Vec::with_capacity(num_waiting_requests);
-            waiting_requests.blocking_recv_many(&mut requests, num_waiting_requests);
-
-            for request in requests {
-                match engine.add_request(&request.tokens, &request.params, &request.stopping_params)
-                {
+            match waiting_requests.recv_timeout(DURATION_100_MS) {
+                Ok(context) => match engine.add_request(
+                    &context.tokens,
+                    &context.params,
+                    &context.stopping_params,
+                ) {
                     Ok(request_id) => {
                         debug!("Successfully scheduled request {request_id}");
-                        in_flight_requests.insert(request_id.to_string(), request);
+                        in_flight_requests.insert(request_id.to_string(), context);
                     }
                     Err(err) => {
                         warn!("Failed to schedule new request: {err}");
                     }
+                },
+                Err(err) => match err {
+                    RecvTimeoutError::Disconnected => break 'outer,
+                    _ => {} // timeout all fine
+                },
+            }
+        }
+
+        if !in_flight_requests.is_empty() {
+            match engine.step() {
+                Ok(outputs) => outputs.iter().for_each(|output| {
+                    let ctx = &in_flight_requests[&output.request_id];
+
+                    // We only need to check on Err meaning the channel is not open anymore, so abort the request
+                    if let Err(_) = ctx.stream.send(InferResult {}) {
+                        debug!("Request {}'s channel dropped, aborting", &output.request_id);
+                        in_flight_requests.remove(&output.request_id);
+                        engine.abort_request(&output.request_id);
+                    }
+                }),
+                Err(err) => {
+                    error!("LLMEngine::step got an error: {err}");
                 }
             }
         }
-        engine.step();
+
+        spin_loop();
     }
 
     info!("Shutting down vLLM engine background loop");
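Control flow of the new loop, condensed: wait up to 100ms for a new request, schedule it, then step the engine and fan results out, aborting any request whose client has hung up. A simplified, runnable restatement with a mock engine (the real loop guards both phases with `is_empty()` checks and spins between iterations):

use crossbeam_channel::{Receiver, RecvTimeoutError};
use std::collections::HashMap;
use std::time::Duration;

struct Ctx { stream: tokio::sync::mpsc::UnboundedSender<u32> }
struct Output { request_id: String, token: u32 }

struct MockEngine; // stand-in for LlmEngine
impl MockEngine {
    fn add(&mut self, _ctx: &Ctx) -> String { "req-0".into() }
    fn step(&mut self) -> Vec<Output> { vec![] }
    fn abort(&mut self, _id: &str) {}
}

fn background_loop(mut engine: MockEngine, waiting: Receiver<Ctx>) {
    const WAIT: Duration = Duration::from_millis(100);
    let mut in_flight: HashMap<String, Ctx> = HashMap::with_capacity(256);

    'outer: loop {
        // Phase 1: pull at most one new request, waiting up to 100ms.
        match waiting.recv_timeout(WAIT) {
            Ok(ctx) => {
                let id = engine.add(&ctx);
                in_flight.insert(id, ctx);
            }
            Err(RecvTimeoutError::Disconnected) => break 'outer, // backend dropped
            Err(RecvTimeoutError::Timeout) => {} // nothing waiting; fall through
        }

        // Phase 2: advance generation and fan results out to the streams.
        for out in engine.step() {
            let dead = in_flight
                .get(&out.request_id)
                .map_or(true, |ctx| ctx.stream.send(out.token).is_err());
            if dead {
                // Client hung up: forget the request and tell the engine.
                in_flight.remove(&out.request_id);
                engine.abort(&out.request_id);
            }
        }
    }
}

fn main() {
    let (tx, rx) = crossbeam_channel::unbounded();
    let (out_tx, _out_rx) = tokio::sync::mpsc::unbounded_channel();
    tx.send(Ctx { stream: out_tx }).unwrap();
    drop(tx); // disconnect so the loop exits after draining
    background_loop(MockEngine, rx);
}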
backends/vllm/src/engine.rs

@@ -1,9 +1,10 @@
 use crate::errors::VllmBackendError;
 use crate::{sampling_params, tokens_prompt, TryToPyObject};
+use pyo3::intern;
 use pyo3::prelude::*;
 use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
 use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
-use tracing::info;
+use tracing::{info, instrument};
 use uuid::Uuid;
 
 pub struct EngineArgs {
@@ -29,9 +30,9 @@ impl IntoPyDict for EngineArgs {
             ),
         ],
     )
     .as_any(),
 )
 .expect("Failed to create Python Dict from EngineArgs")
     }
 }
 
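The new `pyo3::intern!` import powers most of the remaining changes: it converts a string literal to a Python string object once per process and hands back the cached object on later calls, which matters on per-request hot paths like attribute lookups. An isolated sketch against the `*_bound` API generation of pyo3 (0.21/0.22) that this diff targets; `python_version` is a hypothetical helper:

use pyo3::intern;
use pyo3::prelude::*;

// With `intern!`, "version" is converted to a Python string once and
// reused, instead of being re-created on every lookup.
fn python_version(py: Python<'_>) -> PyResult<String> {
    let sys = py.import_bound("sys")?;
    sys.getattr(intern!(py, "version"))?.extract()
}

// Assumes pyo3's "auto-initialize" feature for a standalone binary.
fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        println!("{}", python_version(py)?);
        Ok(())
    })
}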
@@ -47,29 +48,32 @@ impl TryToPyObject for SamplingParams<'_> {
         let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
             py,
             [
-                ("seed", self.sampling_params.seed.into_py(py)),
-                ("n", 1.into_py(py)),
-                ("top_k", self.sampling_params.top_k.into_py(py)),
-                ("top_p", self.sampling_params.top_p.into_py(py)),
-                ("temperature", self.sampling_params.temperature.into_py(py)),
+                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
+                (intern!(py, "n"), 1.into_py(py)),
+                (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
+                (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
                 (
-                    "frequency_penalty",
+                    intern!(py, "temperature"),
+                    self.sampling_params.temperature.into_py(py),
+                ),
+                (
+                    intern!(py, "frequency_penalty"),
                     self.sampling_params.frequency_penalty.into_py(py),
                 ),
                 (
-                    "repetition_penalty",
+                    intern!(py, "repetition_penalty"),
                     self.sampling_params.repetition_penalty.into_py(py),
                 ),
                 (
-                    "ignore_eos",
+                    intern!(py, "ignore_eos"),
                     self.stopping_params.ignore_eos_token.into_py(py),
                 ),
                 (
-                    "max_tokens",
+                    intern!(py, "max_tokens"),
                     self.stopping_params.max_new_tokens.into_py(py),
                 ),
                 (
-                    "stop",
+                    intern!(py, "stop"),
                     PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(),
                 ),
             ],
@@ -81,6 +85,47 @@ impl TryToPyObject for SamplingParams<'_> {
     }
 }
 
+#[derive(Debug)]
+pub struct CompletionOutput {
+    pub index: usize,
+    pub text: String,                  // TODO: SmallString?
+    pub token_ids: Vec<u32>,           // TODO: TinyVec?
+    pub logprobs: Option<Vec<f32>>,    // TODO: TinyVec?
+    pub finish_reason: Option<String>, // lora_request: LATER
+}
+
+#[derive(Debug)]
+pub struct RequestOutput {
+    pub request_id: String,
+    pub outputs: Vec<CompletionOutput>,
+    pub finished: bool,
+    // metrics: Vec<RequestMetrics> // TODO
+}
+
+impl<'py> FromPyObject<'py> for CompletionOutput {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            index: ob.getattr(intern!(py, "index"))?.extract()?,
+            text: ob.getattr(intern!(py, "text"))?.extract()?,
+            token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?,
+            logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
+            finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?,
+        })
+    }
+}
+
+impl<'py> FromPyObject<'py> for RequestOutput {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            request_id: ob.getattr(intern!(py, "request_id"))?.extract()?,
+            outputs: ob.getattr(intern!(py, "outputs"))?.extract()?,
+            finished: ob.getattr(intern!(py, "finished"))?.extract()?,
+        })
+    }
+}
+
 pub struct LlmEngine {
     engine: PyObject,
 }
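Those two `FromPyObject` impls are what let a single `extract::<Vec<RequestOutput>>` call turn the Python list returned by `LLMEngine.step()` into Rust structs, with pyo3 walking the list, calling `extract_bound` per element, and recursing into the nested `outputs` list. A self-contained illustration of the mechanism, using a `types.SimpleNamespace` stand-in for vLLM's objects and a trimmed two-field mirror of the struct:

use pyo3::prelude::*;
use pyo3::types::PyDict;

// Minimal mirror of the struct in the diff, for a self-contained example.
#[derive(Debug)]
struct RequestOutput {
    request_id: String,
    finished: bool,
}

impl<'py> FromPyObject<'py> for RequestOutput {
    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
        Ok(Self {
            request_id: ob.getattr("request_id")?.extract()?,
            finished: ob.getattr("finished")?.extract()?,
        })
    }
}

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Build a Python object with the expected attributes, then extract.
        let locals = PyDict::new_bound(py);
        py.run_bound(
            "import types\nouts = [types.SimpleNamespace(request_id='r0', finished=True)]",
            None,
            Some(&locals),
        )?;
        let outs: Vec<RequestOutput> = locals
            .get_item("outs")?
            .expect("defined above")
            .extract()?;
        println!("{outs:?}");
        Ok(())
    })
}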
@@ -115,14 +160,14 @@ impl LlmEngine {
     ) -> Result<(), VllmBackendError> {
         Python::with_gil(|py| {
             // Create vllm.Tokens
-            let kwargs = [("prompt_token_ids", prompt)].into_py_dict_bound(py);
+            let kwargs = [(intern!(py, "prompt_token_ids"), prompt)].into_py_dict_bound(py);
             let py_tokens_prompt_class = tokens_prompt(py);
             let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;
             let py_sampling_params = sampling_params.try_to_object(py)?;
 
             self.engine.call_method1(
                 py,
-                "add_request",
+                intern!(py, "add_request"),
                 (
                     PyString::new_bound(py, request_id),
                     py_tokens_prompt,
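`into_py_dict_bound` plus an interned key is the whole kwargs-building convention here: the pair array becomes the Python `dict` passed as `**kwargs` to `TokensPrompt(prompt_token_ids=...)`. The same pattern against a plain Python callable, with the built-in `dict` standing in for the vLLM class (an illustrative substitution):

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Same shape as the diff: interned key, Rust value, kwargs dict.
        let prompt: Vec<u32> = vec![1, 2, 3];
        let kwargs = [(intern!(py, "prompt_token_ids"), prompt)].into_py_dict_bound(py);

        // `dict` stands in for the TokensPrompt class: any Python callable
        // accepts a kwargs dict the same way.
        let callee = py.eval_bound("dict", None, None)?;
        let obj = callee.call((), Some(&kwargs))?;
        println!("{obj}"); // {'prompt_token_ids': [1, 2, 3]}
        Ok(())
    })
}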
@@ -130,18 +175,27 @@ impl LlmEngine {
                 ),
             )?;
 
-            self.engine.call_method0(py, "step")
+            self.engine.call_method0(py, intern!(py, "step"))
         })?;
 
         Ok(())
     }
 
+    fn py_step(&self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        Ok(Python::with_gil(|py| {
+            self.engine
+                .call_method0(py, intern!(py, "step"))?
+                .extract::<Vec<RequestOutput>>(py)
+        })?)
+    }
+
     pub fn from_engine_args(args: EngineArgs) -> Result<LlmEngine, VllmBackendError> {
         let engine = Self::py_from_engine_args(args)?;
 
         Ok(Self { engine })
     }
 
+    #[instrument(skip_all)]
     pub fn add_request(
         &self,
         prompt: &[u32],
@@ -159,5 +213,11 @@ impl LlmEngine {
         Ok(request_id)
     }
 
-    pub fn step(&mut self) {}
+    #[instrument(skip_all)]
+    pub fn abort_request(&self, _request_id: &str) {}
+
+    #[instrument(skip_all)]
+    pub fn step(&mut self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        self.py_step()
+    }
 }
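After this commit the Rust-side engine surface is construct / enqueue / step / abort. A hedged usage sketch, written as if inside the backend crate (parameter values come from the caller; note `abort_request` is still a stub at this point in the series):

use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};

// Drives one generation turn synchronously, the way the background loop
// does, using only the API surface this commit exposes.
fn one_turn(
    engine: &mut LlmEngine,
    prompt: &[u32],
    params: &ValidParameters,
    stopping: &ValidStoppingParameters,
) -> Result<(), VllmBackendError> {
    let request_id = engine.add_request(prompt, params, stopping)?;

    // Each step returns one RequestOutput per in-flight request.
    for output in engine.step()? {
        if output.finished {
            println!("{} finished: {:?}", output.request_id, output.outputs);
        }
    }

    // No-op today; kept to show the intended lifecycle.
    engine.abort_request(&request_id.to_string());
    Ok(())
}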