mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
backend(vllm): map ResultOutput to InferStreamResponse to stream back to the client
This commit is contained in:
parent
32dffcff60
commit
003163a2b9
@ -1,25 +1,81 @@
|
|||||||
|
use crate::engine::RequestOutput;
|
||||||
use crate::errors::VllmBackendError;
|
use crate::errors::VllmBackendError;
|
||||||
use crate::{EngineArgs, LlmEngine};
|
use crate::{EngineArgs, LlmEngine, STARTUP_INSTANT};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use crossbeam_channel::internal::SelectHandle;
|
|
||||||
use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
|
use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::HashMap;
|
||||||
use std::hint::spin_loop;
|
use std::hint::spin_loop;
|
||||||
use std::sync::atomic::AtomicBool;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::thread::{spawn, JoinHandle};
|
use std::thread::spawn;
|
||||||
use std::time::Duration;
|
use std::time::{Duration, Instant as StdInstant, UNIX_EPOCH};
|
||||||
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
use text_generation_router::validation::{
|
use text_generation_router::validation::{
|
||||||
ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
|
ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
|
||||||
};
|
};
|
||||||
use text_generation_router::Token;
|
use text_generation_router::{FinishReason, Token};
|
||||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
|
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||||
|
use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
type InferResult = Result<InferStreamResponse, InferError>;
|
type InferResult = Result<InferStreamResponse, InferError>;
|
||||||
|
|
||||||
|
impl TryFrom<&RequestOutput> for InferStreamResponse {
|
||||||
|
type Error = InferError;
|
||||||
|
|
||||||
|
fn try_from(output: &RequestOutput) -> Result<Self, Self::Error> {
|
||||||
|
if let Some(last) = output.outputs.last() {
|
||||||
|
if let Some(token_id) = last.token_ids.last() {
|
||||||
|
let token = Token {
|
||||||
|
id: *token_id,
|
||||||
|
text: last.text.clone(),
|
||||||
|
// logprob: last.logprobs[0],
|
||||||
|
logprob: 0.0f32,
|
||||||
|
special: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
if !output.finished {
|
||||||
|
Ok(InferStreamResponse::Intermediate {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// TODO: Let's see how to request metrics
|
||||||
|
// let metrics = output
|
||||||
|
// .metrics
|
||||||
|
// .last()
|
||||||
|
// .expect("metrics should be set if token was unpacked");
|
||||||
|
//
|
||||||
|
// debug!("Request: {} -> {metrics:?}", &output.request_id);
|
||||||
|
Ok(InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
generated_text: GeneratedText {
|
||||||
|
text: last.text.clone(),
|
||||||
|
generated_tokens: last.token_ids.len() as u32,
|
||||||
|
finish_reason: last
|
||||||
|
.finish_reason
|
||||||
|
.as_ref()
|
||||||
|
.map(|reason| match reason.as_str() {
|
||||||
|
"length" => FinishReason::Length,
|
||||||
|
_ => FinishReason::StopSequence,
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
seed: None,
|
||||||
|
},
|
||||||
|
start: Instant::now(),
|
||||||
|
queued: Instant::now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(InferError::GenerationError("No token returned".to_string()))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(InferError::GenerationError("No token returned".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct VllmRequestContext {
|
struct VllmRequestContext {
|
||||||
tokens: Arc<Vec<u32>>,
|
tokens: Arc<Vec<u32>>,
|
||||||
params: ValidParameters,
|
params: ValidParameters,
|
||||||
@ -35,7 +91,7 @@ impl VllmBackend {
|
|||||||
pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
|
pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
|
||||||
let engine = LlmEngine::from_engine_args(args)?;
|
let engine = LlmEngine::from_engine_args(args)?;
|
||||||
let (sender, receiver) = unbounded();
|
let (sender, receiver) = unbounded();
|
||||||
let looper = spawn(|| engine_background_loop(engine, receiver));
|
let _ = spawn(|| engine_background_loop(engine, receiver));
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
waiting_requests: sender,
|
waiting_requests: sender,
|
||||||
})
|
})
|
||||||
@ -71,10 +127,7 @@ impl Backend for VllmBackend {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn engine_background_loop(
|
fn engine_background_loop(mut engine: LlmEngine, waiting_requests: Receiver<VllmRequestContext>) {
|
||||||
mut engine: LlmEngine,
|
|
||||||
mut waiting_requests: Receiver<VllmRequestContext>,
|
|
||||||
) {
|
|
||||||
info!("Starting vLLM engine background loop");
|
info!("Starting vLLM engine background loop");
|
||||||
static DURATION_100_MS: Duration = Duration::from_millis(100);
|
static DURATION_100_MS: Duration = Duration::from_millis(100);
|
||||||
let mut in_flight_requests = HashMap::with_capacity(256);
|
let mut in_flight_requests = HashMap::with_capacity(256);
|
||||||
@ -101,20 +154,32 @@ fn engine_background_loop(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If there are tracked requests, let's pick the intermediate results
|
||||||
if !in_flight_requests.is_empty() {
|
if !in_flight_requests.is_empty() {
|
||||||
match engine.step() {
|
match engine.step() {
|
||||||
Ok(outputs) => outputs.iter().for_each(|output| {
|
Ok(outputs) => outputs.iter().for_each(|output| {
|
||||||
|
// Retrieve the context
|
||||||
|
{
|
||||||
let ctx = &in_flight_requests[&output.request_id];
|
let ctx = &in_flight_requests[&output.request_id];
|
||||||
|
let result = InferStreamResponse::try_from(output);
|
||||||
|
|
||||||
// We only need to check on Err meaning the channel is not open anymore, so abort the request
|
// We only need to check on Err meaning the channel is not open anymore, so abort the request
|
||||||
if let Err(_) = ctx.stream.send(InferResult {}) {
|
if let Err(_) = ctx.stream.send(result) {
|
||||||
debug!("Request {}'s channel dropped, aborting", &output.request_id);
|
debug!("Request {}'s channel dropped, aborting", &output.request_id);
|
||||||
in_flight_requests.remove(&output.request_id);
|
in_flight_requests.remove(&output.request_id);
|
||||||
engine.abort_request(&output.request_id);
|
engine.abort_request(&output.request_id);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop the request if done
|
||||||
|
if output.finished {
|
||||||
|
in_flight_requests.remove(&output.request_id);
|
||||||
|
}
|
||||||
}),
|
}),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("LLMEngine::step got an error: {err}");
|
error!("LLMEngine::step got an error: {err}");
|
||||||
|
// TODO: Shall we exit from here? We can't link this to any particular user,
|
||||||
|
// it's Rust <> Python FFI which failed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ use crate::errors::VllmBackendError;
|
|||||||
use crate::{sampling_params, tokens_prompt, TryToPyObject};
|
use crate::{sampling_params, tokens_prompt, TryToPyObject};
|
||||||
use pyo3::intern;
|
use pyo3::intern;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::sync::GILOnceCell;
|
||||||
use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
|
use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
|
||||||
use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
|
use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
|
||||||
use tracing::{info, instrument};
|
use tracing::{info, instrument};
|
||||||
@ -36,6 +37,8 @@ impl IntoPyDict for EngineArgs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FINAL_OUTPUT_ONLY: GILOnceCell<PyObject> = GILOnceCell::new();
|
||||||
|
|
||||||
pub struct SamplingParams<'a> {
|
pub struct SamplingParams<'a> {
|
||||||
sampling_params: &'a ValidParameters,
|
sampling_params: &'a ValidParameters,
|
||||||
stopping_params: &'a ValidStoppingParameters,
|
stopping_params: &'a ValidStoppingParameters,
|
||||||
@ -48,8 +51,10 @@ impl TryToPyObject for SamplingParams<'_> {
|
|||||||
let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
|
let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
|
||||||
py,
|
py,
|
||||||
[
|
[
|
||||||
(intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
|
(intern!(py, "output_kind"), 2.into_py(py)),
|
||||||
|
(intern!(py, "logprobs"), 1.into_py(py)),
|
||||||
(intern!(py, "n"), 1.into_py(py)),
|
(intern!(py, "n"), 1.into_py(py)),
|
||||||
|
(intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
|
||||||
(intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
|
(intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
|
||||||
(intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
|
(intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
|
||||||
(
|
(
|
||||||
@ -86,20 +91,40 @@ impl TryToPyObject for SamplingParams<'_> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CompletionOutput {
|
pub(crate) struct CompletionOutput {
|
||||||
pub index: usize,
|
|
||||||
pub text: String, // TODO: SmallString?
|
|
||||||
pub token_ids: Vec<u32>, // TODO: TinyVec?
|
pub token_ids: Vec<u32>, // TODO: TinyVec?
|
||||||
pub logprobs: Option<Vec<f32>>, // TODO: TinyVec?
|
pub text: String, // TODO: SmallString?
|
||||||
|
// pub logprobs: Vec<f32>, // TODO: TinyVec?
|
||||||
pub finish_reason: Option<String>, // lora_request: LATER
|
pub finish_reason: Option<String>, // lora_request: LATER
|
||||||
|
pub index: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub(crate) struct RequestMetrics {
|
||||||
|
pub arrival_time: f32,
|
||||||
|
pub first_scheduled_time: f32,
|
||||||
|
pub first_token_time: f32,
|
||||||
|
pub time_in_queue: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'py> FromPyObject<'py> for RequestMetrics {
|
||||||
|
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
|
||||||
|
let py = ob.py();
|
||||||
|
Ok(Self {
|
||||||
|
arrival_time: ob.getattr(intern!(py, "arrival_time"))?.extract()?,
|
||||||
|
first_scheduled_time: ob.getattr(intern!(py, "first_scheduled_time"))?.extract()?,
|
||||||
|
first_token_time: ob.getattr(intern!(py, "first_token_time"))?.extract()?,
|
||||||
|
time_in_queue: ob.getattr(intern!(py, "time_in_queue"))?.extract()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct RequestOutput {
|
pub(crate) struct RequestOutput {
|
||||||
pub request_id: String,
|
|
||||||
pub outputs: Vec<CompletionOutput>,
|
pub outputs: Vec<CompletionOutput>,
|
||||||
|
// pub metrics: Vec<RequestMetrics>,
|
||||||
|
pub request_id: String,
|
||||||
pub finished: bool,
|
pub finished: bool,
|
||||||
// metrics: Vec<RequestMetrics> // TODO
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'py> FromPyObject<'py> for CompletionOutput {
|
impl<'py> FromPyObject<'py> for CompletionOutput {
|
||||||
@ -109,7 +134,7 @@ impl<'py> FromPyObject<'py> for CompletionOutput {
|
|||||||
index: ob.getattr(intern!(py, "index"))?.extract()?,
|
index: ob.getattr(intern!(py, "index"))?.extract()?,
|
||||||
text: ob.getattr(intern!(py, "text"))?.extract()?,
|
text: ob.getattr(intern!(py, "text"))?.extract()?,
|
||||||
token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?,
|
token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?,
|
||||||
logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
|
// logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?,
|
||||||
finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?,
|
finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -122,6 +147,7 @@ impl<'py> FromPyObject<'py> for RequestOutput {
|
|||||||
request_id: ob.getattr(intern!(py, "request_id"))?.extract()?,
|
request_id: ob.getattr(intern!(py, "request_id"))?.extract()?,
|
||||||
outputs: ob.getattr(intern!(py, "outputs"))?.extract()?,
|
outputs: ob.getattr(intern!(py, "outputs"))?.extract()?,
|
||||||
finished: ob.getattr(intern!(py, "finished"))?.extract()?,
|
finished: ob.getattr(intern!(py, "finished"))?.extract()?,
|
||||||
|
// metrics: ob.getattr(intern!(py, "metrics"))?.extract()?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,9 @@ use pyo3::prelude::PyAnyMethods;
|
|||||||
use pyo3::sync::GILOnceCell;
|
use pyo3::sync::GILOnceCell;
|
||||||
use pyo3::types::PyModule;
|
use pyo3::types::PyModule;
|
||||||
use pyo3::{Py, PyAny, PyErr, PyObject, Python};
|
use pyo3::{Py, PyAny, PyErr, PyObject, Python};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
|
||||||
|
pub(crate) const STARTUP_INSTANT: Instant = Instant::now();
|
||||||
|
|
||||||
static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
|
static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
|
||||||
static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
|
static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
|
||||||
|
Loading…
Reference in New Issue
Block a user