From da9456d00ab4fff8bf24a61ae9c7387f33e6fe2e Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 26 Jun 2024 23:21:43 +0200 Subject: [PATCH] Initial TRTLLM backend structure --- Cargo.toml | 2 +- backends/trtllm/Cargo.toml | 22 ++++++++++++ backends/trtllm/build.rs | 0 backends/trtllm/csrc/CMakeLists.txt | 0 backends/trtllm/src/backend.rs | 32 +++++++++++++++++ backends/trtllm/src/errors.rs | 4 +++ backends/trtllm/src/lib.rs | 5 +++ backends/trtllm/src/main.rs | 56 +++++++++++++++++++++++++++++ 8 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 backends/trtllm/Cargo.toml create mode 100644 backends/trtllm/build.rs create mode 100644 backends/trtllm/csrc/CMakeLists.txt create mode 100644 backends/trtllm/src/backend.rs create mode 100644 backends/trtllm/src/errors.rs create mode 100644 backends/trtllm/src/lib.rs create mode 100644 backends/trtllm/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 28ded514f..e91f26096 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ # "backends/client", "backends/grpc-metadata", "launcher" -] +, "backends/trtllm"] resolver = "2" [workspace.package] diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml new file mode 100644 index 000000000..814a0a135 --- /dev/null +++ b/backends/trtllm/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "trtllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +build = "build.rs" + +[dependencies] +async-trait = "0.1" +async-stream = "0.3" +clap = { version = "4.5", features = ["derive", "env"] } +text-generation-router = { path = "../../router" } +tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.14" +thiserror = "1.0" + +[build-dependencies] +anyhow = "1.0" +cmake = "0.1" +git2 = "0.19" diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs new file mode 100644 index 000000000..e69de29bb diff --git a/backends/trtllm/csrc/CMakeLists.txt b/backends/trtllm/csrc/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs new file mode 100644 index 000000000..d7c464478 --- /dev/null +++ b/backends/trtllm/src/backend.rs @@ -0,0 +1,32 @@ +use std::fmt::{Display, Formatter}; + +use async_trait::async_trait; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; + +use text_generation_router::infer::{Backend, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; + +pub struct TensorRTBackend {} + +#[async_trait] +impl Backend for TensorRTBackend { + fn schedule( + &self, + _request: ValidGenerateRequest, + ) -> Result>, InferError> { + let (_sender, receiver) = mpsc::unbounded_channel(); + + Ok(UnboundedReceiverStream::new(receiver)) + } + + async fn health(&self, current_health: bool) -> bool { + todo!() + } +} + +impl Display for TensorRTBackend { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "TensorRT-LLM Backend") + } +} diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs new file mode 100644 index 000000000..9a2fc539e --- /dev/null +++ b/backends/trtllm/src/errors.rs @@ -0,0 +1,4 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum TensorRTError {} diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs new file mode 100644 index 000000000..b0f61113d --- /dev/null +++ b/backends/trtllm/src/lib.rs @@ -0,0 +1,5 @@ +mod backend; +mod errors; + +pub use backend::TensorRTBackend; +pub use errors::*; \ No newline at end of file diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs new file mode 100644 index 000000000..2e6003743 --- /dev/null +++ b/backends/trtllm/src/main.rs @@ -0,0 +1,56 @@ +use clap::Parser; +use thiserror::Error; +use text_generation_router::server; +use trtllm::TensorRTError; + +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] TensorRTError), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} + + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + + Ok(()) +} \ No newline at end of file