Initial TRTLLM backend structure

This commit is contained in:
Morgan Funtowicz 2024-06-26 23:21:43 +02:00
parent 230f2a415a
commit da9456d00a
8 changed files with 120 additions and 1 deletions

View File

@ -5,7 +5,7 @@ members = [
 # "backends/client",
 "backends/grpc-metadata",
 "launcher"
-]
+, "backends/trtllm"]
 resolver = "2"
 [workspace.package]

View File

@ -0,0 +1,22 @@
[package]
name = "trtllm"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
build = "build.rs"
[dependencies]
async-trait = "0.1"
async-stream = "0.3"
clap = { version = "4.5", features = ["derive", "env"] }
text-generation-router = { path = "../../router" }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
thiserror = "1.0"
[build-dependencies]
anyhow = "1.0"
cmake = "0.1"
git2 = "0.19"

0
backends/trtllm/build.rs Normal file
View File

View File

View File

@ -0,0 +1,32 @@
use std::fmt::{Display, Formatter};
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
// Router backend backed by a TensorRT-LLM executor.
// Currently an empty placeholder: no engine handle or state is stored yet.
pub struct TensorRTBackend {}
#[async_trait]
impl Backend for TensorRTBackend {
    /// Queue a validated generation request.
    ///
    /// Placeholder implementation: the sender half of the channel is dropped
    /// on return, so the stream handed back to the caller terminates without
    /// yielding any token. A real implementation will hand the sender to the
    /// TensorRT-LLM executor loop.
    fn schedule(
        &self,
        _request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        let (_sender, receiver) = mpsc::unbounded_channel();
        Ok(UnboundedReceiverStream::new(receiver))
    }

    /// Report backend health.
    ///
    /// Not implemented yet: calling this panics via `todo!()`.
    /// The parameter is underscore-prefixed because it is intentionally
    /// unused for now (avoids an `unused_variables` warning).
    async fn health(&self, _current_health: bool) -> bool {
        todo!()
    }
}
impl Display for TensorRTBackend {
    /// Human-readable backend name shown in logs and diagnostics.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("TensorRT-LLM Backend")
    }
}

View File

@ -0,0 +1,4 @@
use thiserror::Error;
/// Errors surfaced by the TensorRT-LLM backend.
///
/// Currently empty — the backend is a stub. Marked `#[non_exhaustive]` so
/// that adding variants later is not a semver-breaking change for
/// downstream code matching on this enum.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TensorRTError {}

View File

@ -0,0 +1,5 @@
// Bridge between the shared router and the TensorRT-LLM executor.
mod backend;
// Backend-specific error types.
mod errors;
// Public crate surface: the backend type plus all error types.
pub use backend::TensorRTBackend;
pub use errors::*;

View File

@ -0,0 +1,56 @@
use clap::Parser;
use thiserror::Error;
use text_generation_router::server;
use trtllm::TensorRTError;
// CLI / environment configuration for the TRTLLM router binary.
// NOTE(review): plain `//` comments are used on purpose here — `///` doc
// comments on clap-derive fields become `--help` text and would change the
// program's observable output.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
// Maximum number of requests the router handles concurrently.
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
// Upper bound for the `best_of` sampling parameter per request.
#[clap(default_value = "2", long, env)]
max_best_of: usize,
// Maximum number of stop sequences accepted per request.
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
// Maximum value accepted for `top_n_tokens` per request.
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
// Maximum number of input (prompt) tokens per request.
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
// Maximum total (prompt + generated) tokens per request.
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
// Batching knob: ratio of waiting to running requests that triggers
// scheduling of a new batch.
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
// Token budget for a single prefill batch.
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
// Optional hard cap on total tokens across a batch; unset = no cap here.
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
// Batching knob: how many waiting tokens are tolerated before forcing
// a new batch.
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
// Optional hard cap on the number of requests per batch.
#[clap(long, env)]
max_batch_size: Option<usize>,
// Interface the HTTP server binds to.
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
// Port the HTTP server listens on.
#[clap(default_value = "3000", long, short, env)]
port: u16,
}
// Top-level error type for the TRTLLM router binary. Each variant maps a
// failure domain to a printable message via `thiserror`.
#[derive(Debug, Error)]
enum RouterError {
// Invalid combination of CLI/env arguments, detected before startup.
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
// Error bubbling up from the TensorRT-LLM backend crate.
#[error("Backend failed: {0}")]
Backend(#[from] TensorRTError),
// Error from the shared router's web server.
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
// I/O failure while starting the Tokio runtime.
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
Ok(())
}