backend(vllm): map RequestOutput to InferStreamResponse to stream back to the client

Morgan Funtowicz 2025-01-30 16:12:52 +01:00
parent 32dffcff60
commit 003163a2b9
3 changed files with 123 additions and 29 deletions
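In short: the engine background loop now converts every `RequestOutput` coming back from `LlmEngine::step()` into the router's `InferStreamResponse`, sending an `Intermediate` message per new token while a request is running, and a final `End` message (with the assembled `GeneratedText` and mapped `FinishReason`) once vLLM marks the request finished. A hypothetical consumer-side sketch, not part of the commit, assuming the backend hands each request back as an `UnboundedReceiverStream` of `Result<InferStreamResponse, InferError>` as the router's `Backend` trait does:

use text_generation_router::infer::{InferError, InferStreamResponse};
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

// Drain one request's stream: print tokens as they arrive, stop at `End`.
async fn drain(mut stream: UnboundedReceiverStream<Result<InferStreamResponse, InferError>>) {
    while let Some(event) = stream.next().await {
        match event {
            Ok(InferStreamResponse::Intermediate { token, .. }) => print!("{}", token.text),
            Ok(InferStreamResponse::End { generated_text, .. }) => {
                println!("\n[{} tokens generated]", generated_text.generated_tokens);
                break; // the background loop drops the request after `End`
            }
            Ok(_) => {} // other variants (e.g. prefill messages) ignored in this sketch
            Err(err) => {
                eprintln!("generation failed: {err}");
                break;
            }
        }
    }
}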

View File

@@ -1,25 +1,81 @@
+use crate::engine::RequestOutput;
 use crate::errors::VllmBackendError;
-use crate::{EngineArgs, LlmEngine};
+use crate::{EngineArgs, LlmEngine, STARTUP_INSTANT};
 use async_trait::async_trait;
-use crossbeam_channel::internal::SelectHandle;
 use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::hint::spin_loop;
-use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
-use std::thread::{spawn, JoinHandle};
-use std::time::Duration;
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use std::thread::spawn;
+use std::time::{Duration, Instant as StdInstant, UNIX_EPOCH};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
-use text_generation_router::Token;
-use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use text_generation_router::{FinishReason, Token};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info, warn};
 
 type InferResult = Result<InferStreamResponse, InferError>;
 
+impl TryFrom<&RequestOutput> for InferStreamResponse {
+    type Error = InferError;
+
+    fn try_from(output: &RequestOutput) -> Result<Self, Self::Error> {
+        if let Some(last) = output.outputs.last() {
+            if let Some(token_id) = last.token_ids.last() {
+                let token = Token {
+                    id: *token_id,
+                    text: last.text.clone(),
+                    // logprob: last.logprobs[0],
+                    logprob: 0.0f32,
+                    special: false,
+                };
+
+                if !output.finished {
+                    Ok(InferStreamResponse::Intermediate {
+                        token,
+                        top_tokens: vec![],
+                    })
+                } else {
+                    // TODO: Let's see how to request metrics
+                    // let metrics = output
+                    //     .metrics
+                    //     .last()
+                    //     .expect("metrics should be set if token was unpacked");
+                    //
+                    // debug!("Request: {} -> {metrics:?}", &output.request_id);
+                    Ok(InferStreamResponse::End {
+                        token,
+                        top_tokens: vec![],
+                        generated_text: GeneratedText {
+                            text: last.text.clone(),
+                            generated_tokens: last.token_ids.len() as u32,
+                            finish_reason: last
+                                .finish_reason
+                                .as_ref()
+                                .map(|reason| match reason.as_str() {
+                                    "length" => FinishReason::Length,
+                                    _ => FinishReason::StopSequence,
+                                })
+                                .unwrap(),
+                            seed: None,
+                        },
+                        start: Instant::now(),
+                        queued: Instant::now(),
+                    })
+                }
+            } else {
+                Err(InferError::GenerationError("No token returned".to_string()))
+            }
+        } else {
+            Err(InferError::GenerationError("No token returned".to_string()))
+        }
+    }
+}
+
 struct VllmRequestContext {
     tokens: Arc<Vec<u32>>,
     params: ValidParameters,
@@ -35,7 +91,7 @@ impl VllmBackend {
     pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
         let engine = LlmEngine::from_engine_args(args)?;
         let (sender, receiver) = unbounded();
-        let looper = spawn(|| engine_background_loop(engine, receiver));
+        let _ = spawn(|| engine_background_loop(engine, receiver));
         Ok(Self {
             waiting_requests: sender,
         })
@@ -71,10 +127,7 @@ impl Backend for VllmBackend {
     }
 }
 
-fn engine_background_loop(
-    mut engine: LlmEngine,
-    mut waiting_requests: Receiver<VllmRequestContext>,
-) {
+fn engine_background_loop(mut engine: LlmEngine, waiting_requests: Receiver<VllmRequestContext>) {
     info!("Starting vLLM engine background loop");
     static DURATION_100_MS: Duration = Duration::from_millis(100);
     let mut in_flight_requests = HashMap::with_capacity(256);
@@ -101,20 +154,32 @@ fn engine_background_loop(
             }
         }
 
+        // If there are tracked requests, let's pick the intermediate results
         if !in_flight_requests.is_empty() {
             match engine.step() {
                 Ok(outputs) => outputs.iter().for_each(|output| {
+                    // Retrieve the context
+                    {
                         let ctx = &in_flight_requests[&output.request_id];
+                        let result = InferStreamResponse::try_from(output);
 
                         // We only need to check on Err meaning the channel is not open anymore, so abort the request
-                        if let Err(_) = ctx.stream.send(InferResult {}) {
+                        if let Err(_) = ctx.stream.send(result) {
                             debug!("Request {}'s channel dropped, aborting", &output.request_id);
                             in_flight_requests.remove(&output.request_id);
                             engine.abort_request(&output.request_id);
                         }
+                    }
+
+                    // Drop the request if done
+                    if output.finished {
+                        in_flight_requests.remove(&output.request_id);
+                    }
                 }),
                 Err(err) => {
                     error!("LLMEngine::step got an error: {err}");
+                    // TODO: Shall we exit from here? We can't link this to any particular user,
+                    // it's Rust <> Python FFI which failed
                 }
             }
         }
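A note on the conversion driving this loop: only the last `CompletionOutput` and its last `token_id` are inspected, so each `step()` is expected to surface at most one new token per request, and logprobs are stubbed to `0.0` until metrics are wired through. A hypothetical test sketch of the two paths, not in the commit, with the field layout taken from the `RequestOutput`/`CompletionOutput` definitions in the next file:

#[cfg(test)]
mod conversion_tests {
    use super::*;

    // Build a minimal single-token RequestOutput for the conversion.
    fn output(finished: bool, finish_reason: Option<String>) -> RequestOutput {
        RequestOutput {
            outputs: vec![CompletionOutput {
                token_ids: vec![42],
                text: "hello".into(),
                finish_reason,
                index: 0,
            }],
            request_id: "req-0".into(),
            finished,
        }
    }

    #[test]
    fn running_request_yields_intermediate() {
        let resp = InferStreamResponse::try_from(&output(false, None)).unwrap();
        assert!(matches!(resp, InferStreamResponse::Intermediate { .. }));
    }

    #[test]
    fn finished_request_yields_end_with_finish_reason() {
        let resp = InferStreamResponse::try_from(&output(true, Some("length".into()))).unwrap();
        match resp {
            InferStreamResponse::End { generated_text, .. } => {
                assert!(matches!(generated_text.finish_reason, FinishReason::Length))
            }
            _ => panic!("expected End"),
        }
    }
}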

View File

@@ -2,6 +2,7 @@ use crate::errors::VllmBackendError;
 use crate::{sampling_params, tokens_prompt, TryToPyObject};
 use pyo3::intern;
 use pyo3::prelude::*;
+use pyo3::sync::GILOnceCell;
 use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
 use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
 use tracing::{info, instrument};
@@ -36,6 +37,8 @@ impl IntoPyDict for EngineArgs {
     }
 }
 
+static FINAL_OUTPUT_ONLY: GILOnceCell<PyObject> = GILOnceCell::new();
+
 pub struct SamplingParams<'a> {
     sampling_params: &'a ValidParameters,
     stopping_params: &'a ValidStoppingParameters,
@@ -48,8 +51,10 @@ impl TryToPyObject for SamplingParams<'_> {
         let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
             py,
             [
-                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
+                (intern!(py, "output_kind"), 2.into_py(py)),
+                (intern!(py, "logprobs"), 1.into_py(py)),
                 (intern!(py, "n"), 1.into_py(py)),
+                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
                 (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
                 (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
                 (
@@ -86,20 +91,40 @@ impl TryToPyObject for SamplingParams<'_> {
 }
 
 #[derive(Debug)]
-pub struct CompletionOutput {
-    pub index: usize,
-    pub text: String,               // TODO: SmallString?
-    pub token_ids: Vec<u32>,        // TODO: TinyVec?
-    pub logprobs: Option<Vec<f32>>, // TODO: TinyVec?
+pub(crate) struct CompletionOutput {
+    pub token_ids: Vec<u32>, // TODO: TinyVec?
+    pub text: String,        // TODO: SmallString?
+    // pub logprobs: Vec<f32>,  // TODO: TinyVec?
     pub finish_reason: Option<String>, // lora_request: LATER
+    pub index: usize,
+}
+
+#[derive(Debug, Copy, Clone)]
+pub(crate) struct RequestMetrics {
+    pub arrival_time: f32,
+    pub first_scheduled_time: f32,
+    pub first_token_time: f32,
+    pub time_in_queue: f32,
+}
+
+impl<'py> FromPyObject<'py> for RequestMetrics {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            arrival_time: ob.getattr(intern!(py, "arrival_time"))?.extract()?,
+            first_scheduled_time: ob.getattr(intern!(py, "first_scheduled_time"))?.extract()?,
+            first_token_time: ob.getattr(intern!(py, "first_token_time"))?.extract()?,
+            time_in_queue: ob.getattr(intern!(py, "time_in_queue"))?.extract()?,
+        })
+    }
 }
 
 #[derive(Debug)]
-pub struct RequestOutput {
-    pub request_id: String,
+pub(crate) struct RequestOutput {
     pub outputs: Vec<CompletionOutput>,
+    // pub metrics: Vec<RequestMetrics>,
+    pub request_id: String,
     pub finished: bool,
-    // metrics: Vec<RequestMetrics> // TODO
 }
 
 impl<'py> FromPyObject<'py> for CompletionOutput {
impl<'py> FromPyObject<'py> for CompletionOutput { impl<'py> FromPyObject<'py> for CompletionOutput {
@@ -109,7 +134,7 @@ impl<'py> FromPyObject<'py> for CompletionOutput {
             index: ob.getattr(intern!(py, "index"))?.extract()?,
             text: ob.getattr(intern!(py, "text"))?.extract()?,
             token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?,
-            logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
+            // logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
             finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?,
         })
     }
@@ -122,6 +147,7 @@ impl<'py> FromPyObject<'py> for RequestOutput {
             request_id: ob.getattr(intern!(py, "request_id"))?.extract()?,
             outputs: ob.getattr(intern!(py, "outputs"))?.extract()?,
             finished: ob.getattr(intern!(py, "finished"))?.extract()?,
+            // metrics: ob.getattr(intern!(py, "metrics"))?.extract()?,
         })
     }
 }
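The new `output_kind` and `logprobs` kwargs shape what vLLM returns per engine step. For reference, a hypothetical Rust-side mirror of the integer being passed (assumption: the values track `vllm.sampling_params.RequestOutputKind`, which would make `2` mean "final output only"; the `FINAL_OUTPUT_ONLY` cell above is declared for that value but not yet referenced in this hunk):

/// Hypothetical mirror of vLLM's Python RequestOutputKind enum
/// (assumption: values follow vllm.sampling_params.RequestOutputKind).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum RequestOutputKind {
    /// Each engine step returns the cumulative output so far.
    Cumulative = 0,
    /// Each engine step returns only the newly generated delta.
    Delta = 1,
    /// A single output is returned once the request finishes.
    FinalOnly = 2,
}

Passing the raw integer keeps the FFI surface small; the enum here is documentation only.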

View File

@@ -9,6 +9,9 @@ use pyo3::prelude::PyAnyMethods;
 use pyo3::sync::GILOnceCell;
 use pyo3::types::PyModule;
 use pyo3::{Py, PyAny, PyErr, PyObject, Python};
+use tokio::time::Instant;
+
+pub(crate) const STARTUP_INSTANT: Instant = Instant::now();
 
 static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
 static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
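One caveat in this last hunk: `tokio::time::Instant::now()` is not a `const fn`, so a `const` item initialized from it will not compile on stable Rust. A minimal compiling alternative is a lazily-initialized static; a sketch, not what the commit ships (`std::sync::LazyLock` is stable since Rust 1.80):

use std::sync::LazyLock;
use tokio::time::Instant;

// Captured once, on first access at runtime, instead of in a (rejected)
// const initializer.
pub(crate) static STARTUP_INSTANT: LazyLock<Instant> = LazyLock::new(Instant::now);

Call sites then rely on auto-deref (e.g. `STARTUP_INSTANT.elapsed()`) or dereference explicitly with `*STARTUP_INSTANT`.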