diff --git a/Cargo.lock b/Cargo.lock
index 915de0d5..0f3ff20c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -594,7 +594,7 @@ dependencies = [
  "semver",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
 ]
 
 [[package]]
@@ -1629,7 +1629,7 @@ dependencies = [
  "reqwest 0.11.27",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "ureq",
 ]
@@ -2085,7 +2085,7 @@ checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17"
 dependencies = [
  "opentelemetry 0.20.0",
  "opentelemetry-otlp",
- "thiserror",
+ "thiserror 1.0.69",
  "tracing",
  "tracing-opentelemetry 0.21.0",
 ]
@@ -2421,7 +2421,7 @@ dependencies = [
  "metrics",
  "metrics-util",
  "quanta",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tracing",
 ]
@@ -2552,7 +2552,7 @@ dependencies = [
  "futures",
  "pin-project",
  "rand",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tokio-util",
  "tracing",
@@ -2604,7 +2604,7 @@ dependencies = [
  "rustls-pemfile",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tokio-retry",
  "tokio-util",
@@ -2908,7 +2908,7 @@ dependencies = [
  "js-sys",
  "once_cell",
  "pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
  "urlencoding",
 ]
 
@@ -2926,7 +2926,7 @@ dependencies = [
  "opentelemetry_api",
  "opentelemetry_sdk 0.20.0",
  "prost 0.11.9",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tonic 0.9.2",
 ]
@@ -2964,7 +2964,7 @@ dependencies = [
  "js-sys",
  "once_cell",
  "pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
  "urlencoding",
 ]
 
@@ -2986,7 +2986,7 @@ dependencies = [
  "rand",
  "regex",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tokio-stream",
 ]
@@ -3008,7 +3008,7 @@ dependencies = [
  "ordered-float 4.5.0",
  "percent-encoding",
  "rand",
- "thiserror",
+ "thiserror 1.0.69",
 ]
 
 [[package]]
@@ -3545,7 +3545,7 @@ dependencies = [
  "rand_chacha",
  "simd_helpers",
  "system-deps",
- "thiserror",
+ "thiserror 1.0.69",
  "v_frame",
  "wasm-bindgen",
 ]
@@ -3622,7 +3622,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
 dependencies = [
  "getrandom",
  "libredox",
- "thiserror",
+ "thiserror 1.0.69",
 ]
 
 [[package]]
@@ -4487,13 +4487,31 @@ dependencies = [
  "pkg-config",
  "pyo3",
  "text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tokio-stream",
  "tracing",
 ]
 
+[[package]]
+name = "text-generation-backends-vllm"
+version = "3.0.2-dev0"
+dependencies = [
+ "async-trait",
+ "clap 4.5.21",
+ "crossbeam-channel",
+ "log",
+ "pyo3",
+ "text-generation-router",
+ "thiserror 2.0.11",
+ "tokio",
+ "tokio-stream",
+ "tracing",
+ "tracing-subscriber",
+ "uuid",
+]
+
 [[package]]
 name = "text-generation-benchmark"
 version = "3.1.1-dev0"
@@ -4507,7 +4525,7 @@ dependencies = [
  "serde_json",
  "tabled",
  "text-generation-client",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tracing",
@@ -4524,7 +4542,7 @@ dependencies = [
  "grpc-metadata",
  "prost 0.12.6",
  "prost-build",
- "thiserror",
+ "thiserror 1.0.69",
  "tokio",
  "tonic 0.10.2",
  "tonic-build",
@@ -4547,7 +4565,7 @@ dependencies = [
  "reqwest 0.11.27",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.69",
  "tracing",
  "tracing-subscriber",
  "vergen",
@@ -4590,7 +4608,7 @@ dependencies = [
  "serde",
  "serde_json",
  "sysinfo",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tokio-stream",
@@ -4639,7 +4657,7 @@ dependencies = [
  "serde_json",
  "slotmap",
  "text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
  "tokenizers",
  "tokio",
  "tokio-stream",
@@ -4690,7 +4708,7 @@ dependencies = [
  "serde_json",
  "slotmap",
  "text-generation-router",
- "thiserror",
"thiserror 1.0.69", "tokenizers", "tokio", "tokio-stream", @@ -4720,7 +4738,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +dependencies = [ + "thiserror-impl 2.0.11", ] [[package]] @@ -4734,6 +4761,17 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "thiserror-impl" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -4835,7 +4873,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror", + "thiserror 1.0.69", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", diff --git a/Cargo.toml b/Cargo.toml index 6fd4b51d..ff4a5f5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,15 +5,18 @@ members = [ "backends/v3", "backends/grpc-metadata", "backends/trtllm", + "backends/vllm", "launcher", - "router" + "router", ] + default-members = [ "benchmark", "backends/v2", "backends/v3", "backends/grpc-metadata", # "backends/trtllm", + # "backends/vllm", "launcher", "router" ] @@ -33,7 +36,7 @@ metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } minijinja = { version = "2.2.0", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } -pyo3 = { version = "0.22.2", features = ["auto-initialize"] } +pyo3 = { version = "0.22", features = ["auto-initialize"] } [profile.release] incremental = true diff --git a/backends/vllm/Cargo.toml b/backends/vllm/Cargo.toml new file mode 100644 index 00000000..c77f4562 --- /dev/null +++ b/backends/vllm/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "text-generation-backends-vllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-trait = "0.1.83" +clap = { version = "4.5.21", features = ["derive"] } +crossbeam-channel = "0.5" +pyo3 = { workspace = true } +text-generation-router = { path = "../../router" } +thiserror = "2.0" +tokio = { version = "1.43", features = ["full"] } +tokio-stream = "0.1" +uuid = { version = "1.11.0", features = ["v4"] } +log = "0.4.22" +tracing = "0.1.40" +tracing-subscriber = "0.3.18" diff --git a/backends/vllm/src/backend.rs b/backends/vllm/src/backend.rs new file mode 100644 index 00000000..16d7d4ac --- /dev/null +++ b/backends/vllm/src/backend.rs @@ -0,0 +1,197 @@ +use crate::engine::RequestOutput; +use crate::errors::VllmBackendError; +use crate::{EngineArgs, LlmEngine, STARTUP_INSTANT}; +use async_trait::async_trait; +use crossbeam_channel::{unbounded, Receiver, RecvTimeoutError, Sender}; +use std::collections::HashMap; +use std::hint::spin_loop; +use std::sync::Arc; +use std::thread::spawn; +use std::time::{Duration, Instant as StdInstant, UNIX_EPOCH}; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::{ + ValidGenerateRequest, ValidParameters, ValidStoppingParameters, +}; +use text_generation_router::{FinishReason, Token}; +use 
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{debug, error, info, warn};
+
+type InferResult = Result<InferStreamResponse, InferError>;
+
+impl TryFrom<&RequestOutput> for InferStreamResponse {
+    type Error = InferError;
+
+    fn try_from(output: &RequestOutput) -> Result<Self, Self::Error> {
+        if let Some(last) = output.outputs.last() {
+            if let Some(token_id) = last.token_ids.last() {
+                let token = Token {
+                    id: *token_id,
+                    text: last.text.clone(),
+                    // logprob: last.logprobs[0],
+                    logprob: 0.0f32,
+                    special: false,
+                };
+
+                if !output.finished {
+                    Ok(InferStreamResponse::Intermediate {
+                        token,
+                        top_tokens: vec![],
+                    })
+                } else {
+                    // TODO: Let's see how to request metrics
+                    // let metrics = output
+                    //     .metrics
+                    //     .last()
+                    //     .expect("metrics should be set if token was unpacked");
+                    //
+                    // debug!("Request: {} -> {metrics:?}", &output.request_id);
+                    Ok(InferStreamResponse::End {
+                        token,
+                        top_tokens: vec![],
+                        generated_text: GeneratedText {
+                            text: last.text.clone(),
+                            generated_tokens: last.token_ids.len() as u32,
+                            finish_reason: last
+                                .finish_reason
+                                .as_ref()
+                                .map(|reason| match reason.as_str() {
+                                    "length" => FinishReason::Length,
+                                    _ => FinishReason::StopSequence,
+                                })
+                                .unwrap(),
+                            seed: None,
+                        },
+                        // start: STARTUP_INSTANT
+                        //     .checked_sub(Duration::from_secs_f32(metrics.first_scheduled_time))
+                        //     .unwrap_or_else(Instant::now),
+                        // queued: STARTUP_INSTANT
+                        //     .checked_sub(Duration::from_secs_f32(metrics.arrival_time))
+                        //     .unwrap_or_else(Instant::now),
+                        start: Instant::now(),
+                        queued: Instant::now(),
+                    })
+                }
+            } else {
+                Err(InferError::GenerationError("No token returned".to_string()))
+            }
+        } else {
+            Err(InferError::GenerationError("No token returned".to_string()))
+        }
+    }
+}
+
+struct VllmRequestContext {
+    tokens: Arc<Vec<u32>>,
+    params: ValidParameters,
+    stopping_params: ValidStoppingParameters,
+    stream: UnboundedSender<InferResult>,
+}
+
+pub struct VllmBackend {
+    waiting_requests: Sender<VllmRequestContext>,
+}
+
+impl VllmBackend {
+    pub fn from_engine_args(args: EngineArgs) -> Result<VllmBackend, VllmBackendError> {
+        let engine = LlmEngine::from_engine_args(args)?;
+        let (sender, receiver) = unbounded();
+        let _ = spawn(|| engine_background_loop(engine, receiver));
+        Ok(Self {
+            waiting_requests: sender,
+        })
+    }
+}
+
+#[async_trait]
+impl Backend for VllmBackend {
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        let (sender, receiver) = unbounded_channel();
+
+        // Send the query to the vLLM Engine
+        if let Some(input_ids) = request.input_ids {
+            debug!("Queuing new request");
+            if let Err(err) = self.waiting_requests.send(VllmRequestContext {
+                tokens: Arc::clone(&input_ids),
+                params: request.parameters,
+                stopping_params: request.stopping_parameters,
+                stream: sender,
+            }) {
+                warn!("Waiting Requests queue has been closed: {err}")
+            }
+        };
+
+        Ok(UnboundedReceiverStream::new(receiver))
+    }
+
+    async fn health(&self, _current_health: bool) -> bool {
+        true
+    }
+}
+
+fn engine_background_loop(mut engine: LlmEngine, waiting_requests: Receiver<VllmRequestContext>) {
+    info!("Starting vLLM engine background loop");
+    static DURATION_100_MS: Duration = Duration::from_millis(100);
+
+    let mut in_flight_requests = HashMap::with_capacity(256);
+    'outer: loop {
+        if !waiting_requests.is_empty() {
+            match waiting_requests.recv_timeout(DURATION_100_MS) {
+                Ok(context) => match engine.add_request(
+                    &context.tokens,
+                    &context.params,
+                    &context.stopping_params,
+                ) {
+                    Ok(request_id) => {
+                        debug!("Successfully scheduled request {request_id}");
+                        in_flight_requests.insert(request_id.to_string(), context);
+                    }
+                    Err(err) => {
+                        warn!("Failed to schedule new request: {err}");
+                    }
+                },
+                Err(err) => match err {
+                    RecvTimeoutError::Disconnected => break 'outer,
+                    _ => {} // timeout all fine
+                },
+            }
+        }
+
+        // If there are tracked requests, let's pick the intermediate results
+        if !in_flight_requests.is_empty() {
+            match engine.step() {
+                Ok(outputs) => outputs.iter().for_each(|output| {
+                    // Retrieve the context
+                    {
+                        let ctx = &in_flight_requests[&output.request_id];
+                        let result = InferStreamResponse::try_from(output);
+
+                        // We only need to check on Err meaning the channel is not open anymore, so abort the request
+                        if let Err(_) = ctx.stream.send(result) {
+                            debug!("Request {}'s channel dropped, aborting", &output.request_id);
+                            in_flight_requests.remove(&output.request_id);
+                            engine.abort_request(&output.request_id);
+                        }
+                    }
+
+                    // Drop the request if done
+                    if output.finished {
+                        in_flight_requests.remove(&output.request_id);
+                    }
+                }),
+                Err(err) => {
+                    error!("LLMEngine::step got an error: {err}");
+                    // TODO: Shall we exit from here? We can't link this to any particular user,
+                    // it's Rust <> Python FFI which failed
+                }
+            }
+        }
+
+        spin_loop();
+    }
+
+    info!("Shutting down vLLM engine background loop");
+}
diff --git a/backends/vllm/src/engine.rs b/backends/vllm/src/engine.rs
new file mode 100644
index 00000000..dcbff82f
--- /dev/null
+++ b/backends/vllm/src/engine.rs
@@ -0,0 +1,249 @@
+use crate::errors::VllmBackendError;
+use crate::{sampling_params, tokens_prompt, TryToPyObject};
+use pyo3::intern;
+use pyo3::prelude::*;
+use pyo3::sync::GILOnceCell;
+use pyo3::types::{IntoPyDict, PyDict, PyList, PyString};
+use text_generation_router::validation::{ValidParameters, ValidStoppingParameters};
+use tracing::{info, instrument};
+use uuid::Uuid;
+
+pub struct EngineArgs {
+    pub model: String,
+    pub pipeline_parallel_size: u32,
+    pub tensor_parallel_size: u32,
+}
+
+impl IntoPyDict for EngineArgs {
+    fn into_py_dict_bound(self, py: Python<'_>) -> Bound<'_, PyDict> {
+        PyDict::from_sequence_bound(
+            PyList::new_bound(
+                py,
+                [
+                    ("model", self.model.into_py(py)),
+                    (
+                        "pipeline_parallel_size",
+                        self.pipeline_parallel_size.into_py(py),
+                    ),
+                    (
+                        "tensor_parallel_size",
+                        self.tensor_parallel_size.into_py(py),
+                    ),
+                ],
+            )
+            .as_any(),
+        )
+        .expect("Failed to create Python Dict from EngineArgs")
+    }
+}
+
+static FINAL_OUTPUT_ONLY: GILOnceCell<PyObject> = GILOnceCell::new();
+
+pub struct SamplingParams<'a> {
+    sampling_params: &'a ValidParameters,
+    stopping_params: &'a ValidStoppingParameters,
+}
+
+impl TryToPyObject for SamplingParams<'_> {
+    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr> {
+        let py_sampling_params_class = sampling_params(py);
+
+        let kwargs = PyDict::from_sequence_bound(&PyList::new_bound(
+            py,
+            [
+                (intern!(py, "output_kind"), 2.into_py(py)),
+                (intern!(py, "logprobs"), 1.into_py(py)),
+                (intern!(py, "n"), 1.into_py(py)),
+                (intern!(py, "seed"), self.sampling_params.seed.into_py(py)),
+                (intern!(py, "top_k"), self.sampling_params.top_k.into_py(py)),
+                (intern!(py, "top_p"), self.sampling_params.top_p.into_py(py)),
+                (
+                    intern!(py, "temperature"),
+                    self.sampling_params.temperature.into_py(py),
+                ),
+                (
+                    intern!(py, "frequency_penalty"),
+                    self.sampling_params.frequency_penalty.into_py(py),
+                ),
+                (
+                    intern!(py, "repetition_penalty"),
+                    self.sampling_params.repetition_penalty.into_py(py),
+                ),
+                (
+                    intern!(py, "ignore_eos"),
+                    self.stopping_params.ignore_eos_token.into_py(py),
+                ),
+                (
"max_tokens"), + self.stopping_params.max_new_tokens.into_py(py), + ), + ( + intern!(py, "stop"), + PyList::new_bound(py, self.stopping_params.stop_sequences.iter()).into(), + ), + ], + )); + + Ok(py_sampling_params_class + .call_method_bound(py, "from_optional", (), Some(&kwargs?))? + .to_object(py)) + } +} + +#[derive(Debug)] +pub(crate) struct CompletionOutput { + pub token_ids: Vec, // TODO: TinyVec? + pub text: String, // TODO: SmallString? + // pub logprobs: Vec, // TODO: TinyVec? + pub finish_reason: Option, // lora_request: LATER + pub index: usize, +} + +#[derive(Debug, Copy, Clone)] +pub(crate) struct RequestMetrics { + pub arrival_time: f32, + pub first_scheduled_time: f32, + pub first_token_time: f32, + pub time_in_queue: f32, +} + +impl<'py> FromPyObject<'py> for RequestMetrics { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + Ok(Self { + arrival_time: ob.getattr(intern!(py, "arrival_time"))?.extract()?, + first_scheduled_time: ob.getattr(intern!(py, "first_scheduled_time"))?.extract()?, + first_token_time: ob.getattr(intern!(py, "first_token_time"))?.extract()?, + time_in_queue: ob.getattr(intern!(py, "time_in_queue"))?.extract()?, + }) + } +} + +#[derive(Debug)] +pub(crate) struct RequestOutput { + pub outputs: Vec, + // pub metrics: Vec, + pub request_id: String, + pub finished: bool, +} + +impl<'py> FromPyObject<'py> for CompletionOutput { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + Ok(Self { + index: ob.getattr(intern!(py, "index"))?.extract()?, + text: ob.getattr(intern!(py, "text"))?.extract()?, + token_ids: ob.getattr(intern!(py, "token_ids"))?.extract()?, + // logprobs: ob.getattr(intern!(py, "logprobs"))?.extract()?, + finish_reason: ob.getattr(intern!(py, "finish_reason"))?.extract()?, + }) + } +} + +impl<'py> FromPyObject<'py> for RequestOutput { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + Ok(Self { + request_id: ob.getattr(intern!(py, "request_id"))?.extract()?, + outputs: ob.getattr(intern!(py, "outputs"))?.extract()?, + finished: ob.getattr(intern!(py, "finished"))?.extract()?, + // metrics: ob.getattr(intern!(py, "metrics"))?.extract()?, + }) + } +} + +pub struct LlmEngine { + engine: PyObject, +} + +impl LlmEngine { + fn py_from_engine_args(args: EngineArgs) -> PyResult { + Python::with_gil(|py| { + // Create the EngineArgs from Rust + // from vllm.engine.arg_util import EngineArgs + // engine_args = EngineArgs(**args) + let py_engine_args_mod = PyModule::import_bound(py, "vllm.engine.arg_utils")?; + let py_engine_args_class = py_engine_args_mod.getattr("EngineArgs")?; + let py_engine_args = + py_engine_args_class.call((), Some(&args.into_py_dict_bound(py)))?; + + // Next create the LLMEngine from the EngineArgs + // from vllm.engine.llm_engine import LLMEngine + // engine = LLMEngine.from_engine_args(engine_args) + let py_engine_llm_mod = PyModule::import_bound(py, "vllm.v1.engine.llm_engine")?; + let py_engine_llm_class = py_engine_llm_mod.getattr("LLMEngine")?; + py_engine_llm_class + .call_method("from_engine_args", (py_engine_args,), None)? 
+                .extract()
+        })
+    }
+
+    fn py_add_request(
+        &self,
+        request_id: &str,
+        prompt: &[u32],
+        sampling_params: SamplingParams,
+    ) -> Result<(), VllmBackendError> {
+        Python::with_gil(|py| {
+            // Create vllm.Tokens
+            let kwargs = [(intern!(py, "prompt_token_ids"), prompt)].into_py_dict_bound(py);
+            let py_tokens_prompt_class = tokens_prompt(py);
+            let py_tokens_prompt = py_tokens_prompt_class.call_bound(py, (), Some(&kwargs))?;
+            let py_sampling_params = sampling_params.try_to_object(py)?;
+
+            self.engine.call_method1(
+                py,
+                intern!(py, "add_request"),
+                (
+                    PyString::new_bound(py, request_id),
+                    py_tokens_prompt,
+                    py_sampling_params,
+                ),
+            )?;
+
+            self.engine.call_method0(py, intern!(py, "step"))
+        })?;
+
+        Ok(())
+    }
+
+    fn py_step(&self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        Ok(Python::with_gil(|py| {
+            self.engine
+                .call_method0(py, intern!(py, "step"))?
+                .extract::<Vec<RequestOutput>>(py)
+        })?)
+    }
+
+    pub fn from_engine_args(args: EngineArgs) -> Result<Self, VllmBackendError> {
+        let engine = Self::py_from_engine_args(args)?;
+
+        Ok(Self { engine })
+    }
+
+    #[instrument(skip_all)]
+    pub fn add_request(
+        &self,
+        prompt: &[u32],
+        sampling_params: &ValidParameters,
+        stopping_params: &ValidStoppingParameters,
+    ) -> Result<Uuid, VllmBackendError> {
+        let request_id = Uuid::new_v4();
+        let sampling_params = SamplingParams {
+            sampling_params,
+            stopping_params,
+        };
+        self.py_add_request(&request_id.to_string(), prompt, sampling_params)?;
+
+        info!("Submitted new request: {request_id}");
+        Ok(request_id)
+    }
+
+    #[instrument(skip_all)]
+    pub fn abort_request(&self, _request_id: &str) {}
+
+    #[instrument(skip_all)]
+    pub fn step(&mut self) -> Result<Vec<RequestOutput>, VllmBackendError> {
+        self.py_step()
+    }
+}
diff --git a/backends/vllm/src/errors.rs b/backends/vllm/src/errors.rs
new file mode 100644
index 00000000..1b03f5a4
--- /dev/null
+++ b/backends/vllm/src/errors.rs
@@ -0,0 +1,31 @@
+use pyo3::PyErr;
+use text_generation_router::infer::InferError;
+use text_generation_router::server::WebServerError;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum VllmBackendError {
+    #[error("[Python] {0}")]
+    Python(PyErr),
+
+    #[error("[WebServer] {0}")]
+    WebServer(WebServerError),
+}
+
+impl From<PyErr> for VllmBackendError {
+    fn from(value: PyErr) -> Self {
+        Self::Python(value)
+    }
+}
+
+impl From<WebServerError> for VllmBackendError {
+    fn from(value: WebServerError) -> Self {
+        Self::WebServer(value)
+    }
+}
+
+impl From<VllmBackendError> for InferError {
+    fn from(value: VllmBackendError) -> Self {
+        InferError::GenerationError(value.to_string())
+    }
+}
diff --git a/backends/vllm/src/lib.rs b/backends/vllm/src/lib.rs
new file mode 100644
index 00000000..d0b44565
--- /dev/null
+++ b/backends/vllm/src/lib.rs
@@ -0,0 +1,43 @@
+mod backend;
+mod engine;
+mod errors;
+
+pub use backend::VllmBackend;
+pub use engine::{EngineArgs, LlmEngine};
+pub use errors::VllmBackendError;
+use pyo3::prelude::PyAnyMethods;
+use pyo3::sync::GILOnceCell;
+use pyo3::types::PyModule;
+use pyo3::{Py, PyAny, PyErr, PyObject, Python};
+use tokio::time::Instant;
+
+pub(crate) static STARTUP_INSTANT: Instant = Instant::now();
+
+static PY_TOKENS_PROMPT_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+static PY_SAMPLING_PARAMS_CLASS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+
+#[inline]
+pub(crate) fn tokens_prompt(py: Python) -> &Py<PyAny> {
+    PY_TOKENS_PROMPT_CLASS.get_or_init(py, || {
+        PyModule::import_bound(py, "vllm.inputs")
+            .expect("Failed to import vllm.inputs")
+            .getattr("TokensPrompt")
+            .expect("Failed to import vllm.inputs.TokensPrompt")
+            .unbind()
+    })
+}
+
+#[inline]
+pub(crate) fn sampling_params(py: Python) -> &Py<PyAny> {
+    PY_SAMPLING_PARAMS_CLASS.get_or_init(py, || {
+        PyModule::import_bound(py, "vllm")
+            .expect("Failed to import vllm")
+            .getattr("SamplingParams")
+            .expect("Failed to import vllm.SamplingParams")
+            .unbind()
+    })
+}
+
+pub(crate) trait TryToPyObject {
+    fn try_to_object(&self, py: Python<'_>) -> Result<PyObject, PyErr>;
+}
diff --git a/backends/vllm/src/main.rs b/backends/vllm/src/main.rs
new file mode 100644
index 00000000..55f47871
--- /dev/null
+++ b/backends/vllm/src/main.rs
@@ -0,0 +1,108 @@
+use clap::Parser;
+use text_generation_backends_vllm::{EngineArgs, VllmBackend, VllmBackendError};
+use text_generation_router::{server, usage_stats};
+
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(long, env)]
+    max_input_tokens: Option<usize>,
+    #[clap(long, env)]
+    max_total_tokens: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    api_key: Option<String>,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
+}
+
+impl Into<EngineArgs> for &Args {
+    fn into(self) -> EngineArgs {
+        EngineArgs {
+            model: self.tokenizer_name.clone(),
+            pipeline_parallel_size: 1, // TODO
+            tensor_parallel_size: 1,   // TODO
+        }
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), VllmBackendError> {
+    tracing_subscriber::fmt::init();
+
+    let args = Args::parse();
+    let backend = VllmBackend::from_engine_args((&args).into())?;
+
+    server::run(
+        backend,
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens.unwrap_or(1024), // TODO
+        args.max_total_tokens.unwrap_or(2048), // TODO
+        args.validation_workers,
+        args.api_key,
+        args.tokenizer_name,
+        args.tokenizer_config_path,
+        args.revision,
+        args.trust_remote_code,
+        args.hostname,
+        args.port,
+        args.cors_allow_origin,
+        false,
+        None,
+        None,
+        args.disable_grammar_support,
+        args.max_batch_size.unwrap_or(16),
+        args.usage_stats,
+        args.payload_limit,
+    )
+    .await?;
+    Ok(())
+}
diff --git a/router/Cargo.toml b/router/Cargo.toml
index e4d0179a..d108e9dc 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -31,11 +31,11 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
-    "rt",
-    "rt-multi-thread",
-    "parking_lot",
-    "signal",
-    "sync",
+  "rt",
+  "rt-multi-thread",
+  "parking_lot",
+  "signal",
+  "sync",
 ] }
 tokio-stream = "0.1.14"
 tower-http = { version = "0.5.1", features = ["cors"] }
@@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
-    "opentelemetry-otlp",
+  "opentelemetry-otlp",
 ] }
 minijinja = { workspace = true }
 minijinja-contrib = { workspace = true }
@@ -57,9 +57,9 @@ image = "0.25.1"
 base64 = { workspace = true }
 sysinfo = "0.30.13"
 uuid = { version = "1.9.1", default-features = false, features = [
-    "v4",
-    "fast-rng",
-    "macro-diagnostics",
+  "v4",
+  "fast-rng",
+  "macro-diagnostics",
 ] }
 csv = "1.3.0"
 ureq = "=2.9"