From 003163a2b9a0e7e30d20bcc96a71955be850649b Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 30 Jan 2025 16:12:52 +0100
Subject: [PATCH] backend(vllm): map RequestOutput to InferStreamResponse to
 stream back to the client

---
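Notes:

The core of this change is the `TryFrom<&RequestOutput> for
InferStreamResponse` conversion in backend.rs: while a request is still
running, the last token of the last `CompletionOutput` is emitted as an
`Intermediate` frame; once `finished` is set, an `End` frame is emitted
carrying the `GeneratedText` and the mapped `FinishReason`. The sketch
below is illustrative only and is not part of the diff; it shows how the
mapping could be exercised from a unit test, assuming the crate-private
types stay as declared in this patch:

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::engine::{CompletionOutput, RequestOutput};

        #[test]
        fn unfinished_output_maps_to_intermediate() {
            // One streamed token for a request that is still in flight.
            let output = RequestOutput {
                outputs: vec![CompletionOutput {
                    token_ids: vec![42],
                    text: "Hello".to_string(),
                    finish_reason: None,
                    index: 0,
                }],
                request_id: "req-0".to_string(),
                finished: false,
            };
            // `finished == false` must yield an Intermediate frame, never End.
            assert!(matches!(
                InferStreamResponse::try_from(&output),
                Ok(InferStreamResponse::Intermediate { .. })
            ));
        }
    }
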
 backends/vllm/src/backend.rs | 103 ++++++++++++++++++++++++++++-------
 backends/vllm/src/engine.rs  |  46 ++++++++++++----
 backends/vllm/src/lib.rs     |   3 +
 3 files changed, 123 insertions(+), 29 deletions(-)

diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs
index 46419279..15092d31 100644
--- a/backends/vllm/src/backend.rs
+++ b/backends/vllm/src/backend.rs
@@ -1,25 +1,81 @@
+use crate::engine::RequestOutput;
 use crate::errors::VllmBackendError;
-use crate::{EngineArgs, LlmEngine};
+use crate::{EngineArgs, LlmEngine, STARTUP_INSTANT};
 use async_trait::async_trait;
-use crossbeam_channel::internal::SelectHandle;
 use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::hint::spin_loop;
-use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
-use std::thread::{spawn, JoinHandle};
-use std::time::Duration;
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use std::thread::spawn;
+use std::time::{Duration, Instant as StdInstant, UNIX_EPOCH};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
-use text_generation_router::Token;
-use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use text_generation_router::{FinishReason, Token};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info, warn};
 
 type InferResult = Result<InferStreamResponse, InferError>;
 
+impl TryFrom<&RequestOutput> for InferStreamResponse {
+    type Error = InferError;
+
+    fn try_from(output: &RequestOutput) -> Result<Self, Self::Error> {
+        if let Some(last) = output.outputs.last() {
+            if let Some(token_id) = last.token_ids.last() {
+                let token = Token {
+                    id: *token_id,
+                    text: last.text.clone(),
+                    // logprob: last.logprobs[0],
+                    logprob: 0.0f32,
+                    special: false,
+                };
+
+                if !output.finished {
+                    Ok(InferStreamResponse::Intermediate {
+                        token,
+                        top_tokens: vec![],
+                    })
+                } else {
+                    // TODO: Let's see how to request metrics
+                    // let metrics = output
+                    //     .metrics
+                    //     .last()
+                    //     .expect("metrics should be set if token was unpacked");
+                    //
+                    // debug!("Request: {} -> {metrics:?}", &output.request_id);
+                    Ok(InferStreamResponse::End {
+                        token,
+                        top_tokens: vec![],
+                        generated_text: GeneratedText {
+                            text: last.text.clone(),
+                            generated_tokens: last.token_ids.len() as u32,
+                            finish_reason: last
+                                .finish_reason
+                                .as_ref()
+                                .map(|reason| match reason.as_str() {
+                                    "length" => FinishReason::Length,
+                                    _ => FinishReason::StopSequence,
+                                })
+                                .unwrap(),
+                            seed: None,
+                        },
+                        start: Instant::now(),
+                        queued: Instant::now(),
+                    })
+                }
+            } else {
+                Err(InferError::GenerationError("No token returned".to_string()))
+            }
+        } else {
+            Err(InferError::GenerationError("No token returned".to_string()))
+        }
+    }
+}
+
 struct VllmRequestContext {
     tokens: Arc<Vec<u32>>,
     params: ValidParameters,
@@ -35,7 +91,7 @@ impl VllmBackend {
     pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
         let engine = LlmEngine::from_engine_args(args)?;
         let (sender, receiver) = unbounded();
-        let looper = spawn(|| engine_background_loop(engine, receiver));
+        let _ = spawn(|| engine_background_loop(engine, receiver));
 
         Ok(Self {
             waiting_requests: sender,
@@ -71,10 +127,7 @@ impl Backend for VllmBackend {
     }
 }
 
-fn engine_background_loop(
-    mut engine: LlmEngine,
-    mut waiting_requests: Receiver<VllmRequestContext>,
-) {
+fn engine_background_loop(mut engine: LlmEngine, waiting_requests: Receiver<VllmRequestContext>) {
     info!("Starting vLLM engine background loop");
     static DURATION_100_MS: Duration = Duration::from_millis(100);
     let mut in_flight_requests = HashMap::with_capacity(256);
@@ -101,20 +154,32 @@ fn engine_background_loop(
             }
         }
 
+        // If there are tracked requests, pull their intermediate results
        if !in_flight_requests.is_empty() {
             match engine.step() {
                 Ok(outputs) => outputs.iter().for_each(|output| {
-                    let ctx = &in_flight_requests[&output.request_id];
+                    // Retrieve the context
+                    {
+                        let ctx = &in_flight_requests[&output.request_id];
+                        let result = InferStreamResponse::try_from(output);
 
-                    // We only need to check on Err meaning the channel is not open anymore, so abort the request
-                    if let Err(_) = ctx.stream.send(InferResult {}) {
-                        debug!("Request {}'s channel dropped, aborting", &output.request_id);
+                        // We only need to check on Err, meaning the channel is no longer open, and abort the request
+                        if ctx.stream.send(result).is_err() {
+                            debug!("Request {}'s channel dropped, aborting", &output.request_id);
+                            in_flight_requests.remove(&output.request_id);
+                            engine.abort_request(&output.request_id);
+                        }
+                    }
+
+                    // Drop the request if done
+                    if output.finished {
                         in_flight_requests.remove(&output.request_id);
-                        engine.abort_request(&output.request_id);
                     }
                 }),
                 Err(err) => {
                     error!("LLMEngine::step got an error: {err}");
+                    // TODO: Should we exit from here? We cannot tie this failure to any
+                    // particular user; it is the Rust <-> Python FFI call that failed.
                 }
             }
         }
diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs
index f3b6a761..dcbff82f 100644
--- a/backends/vllm/src/engine.rs
+++ b/backends/vllm/src/engine.rs
@@ -2,6 +2,7 @@ use crate::errors::VllmBackendError;
 use crate::{sampling_params, tokens_prompt, TryToPyObject};
 use pyo3::intern;
 use pyo3::prelude::*;
+use pyo3::sync::GILOnceCell;
 use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
 use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
 use tracing::{info, instrument};
@@ -36,6 +37,8 @@ impl IntoPyDict for EngineArgs {
     }
 }
 
+static FINAL_OUTPUT_ONLY: GILOnceCell<PyObject> = GILOnceCell::new();
+
 pub struct SamplingParams<'a> {
     sampling_params: &'a ValidParameters,
     stopping_params: &'a ValidStoppingParameters,
 }
@@ -48,8 +51,10 @@ impl TryToPyObject for SamplingParams<'_> {
         let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
             py,
             [
-                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
+                (intern!(py, "output_kind"), 2.into_py(py)),
+                (intern!(py, "logprobs"), 1.into_py(py)),
                 (intern!(py, "n"), 1.into_py(py)),
+                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
                 (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
                 (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
                 (
@@ -86,20 +91,40 @@
 }
 
 #[derive(Debug)]
-pub struct CompletionOutput {
-    pub index: usize,
-    pub text: String,               // TODO: SmallString?
-    pub token_ids: Vec<u32>,        // TODO: TinyVec?
-    pub logprobs: Option<Vec<f32>>, // TODO: TinyVec?
+pub(crate) struct CompletionOutput {
+    pub token_ids: Vec<u32>, // TODO: TinyVec?
+    pub text: String,        // TODO: SmallString?
+    // pub logprobs: Vec<f32>, // TODO: TinyVec?
     pub finish_reason: Option<String>,
     // lora_request: LATER
+    pub index: usize,
+}
+
+#[derive(Debug, Copy, Clone)]
+pub(crate) struct RequestMetrics {
+    pub arrival_time: f32,
+    pub first_scheduled_time: f32,
+    pub first_token_time: f32,
+    pub time_in_queue: f32,
+}
+
+impl<'py> FromPyObject<'py> for RequestMetrics {
+    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
+        let py = ob.py();
+        Ok(Self {
+            arrival_time: ob.getattr(intern!(py, "arrival_time"))?.extract()?,
+            first_scheduled_time: ob.getattr(intern!(py, "first_scheduled_time"))?.extract()?,
+            first_token_time: ob.getattr(intern!(py, "first_token_time"))?.extract()?,
+            time_in_queue: ob.getattr(intern!(py, "time_in_queue"))?.extract()?,
+        })
+    }
 }
 
 #[derive(Debug)]
-pub struct RequestOutput {
-    pub request_id: String,
+pub(crate) struct RequestOutput {
     pub outputs: Vec<CompletionOutput>,
+    // pub metrics: Vec<RequestMetrics>,
+    pub request_id: String,
     pub finished: bool,
-    // metrics: Vec // TODO
 }
 
 impl<'py> FromPyObject<'py> for CompletionOutput {
@@ -109,7 +134,7 @@
             index: ob.getattr(intern!(py, "index"))?.extract()?,
             text: ob.getattr(intern!(py, "text"))?.extract()?,
             token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?,
-            logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
+            // logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
             finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?,
         })
     }
@@ -122,6 +147,7 @@ impl<'py> FromPyObject<'py> for RequestOutput {
             request_id: ob.getattr(intern!(py, "request_id"))?.extract()?,
             outputs: ob.getattr(intern!(py, "outputs"))?.extract()?,
             finished: ob.getattr(intern!(py, "finished"))?.extract()?,
+            // metrics: ob.getattr(intern!(py, "metrics"))?.extract()?,
         })
     }
 }
diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs
index 12c910df..4bd4f434 100644
--- a/backends/vllm/src/lib.rs
+++ b/backends/vllm/src/lib.rs
@@ -9,6 +9,9 @@ use pyo3::prelude::PyAnyMethods;
 use pyo3::sync::GILOnceCell;
 use pyo3::types::PyModule;
 use pyo3::{Py, PyAny, PyErr, PyObject, Python};
+use tokio::time::Instant;
+
+pub(crate) static STARTUP_INSTANT: std::sync::LazyLock<Instant> = std::sync::LazyLock::new(Instant::now);
 
 static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
 static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
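
The `start`/`queued` fields on `InferStreamResponse::End` are filled with
`Instant::now()` placeholders for now; the `RequestMetrics` struct, the
`UNIX_EPOCH` import and `STARTUP_INSTANT` are groundwork for replacing
them with real timings. One possible follow-up, sketched here only (the
helper name `instant_from_unix_secs` is hypothetical, and vLLM's metrics
are assumed to be unix-epoch timestamps in seconds):

    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use tokio::time::Instant;

    /// Re-anchor a wall-clock timestamp onto the monotonic clock.
    fn instant_from_unix_secs(ts_secs: f32) -> Instant {
        let now = Instant::now();
        // How long ago the event happened, according to the wall clock.
        let elapsed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .saturating_sub(Duration::from_secs_f32(ts_secs));
        // Fall back to `now` if the timestamp lies before process start-up.
        now.checked_sub(elapsed).unwrap_or(now)
    }

With the commented-out metrics extraction enabled, `start` could then
become `instant_from_unix_secs(metrics.first_scheduled_time)` and
`queued` `instant_from_unix_secs(metrics.arrival_time)`.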