Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)
Initial TRTLLM backend structure
Commit da9456d00a (parent 230f2a415a)
Cargo.toml (modified)

```diff
@@ -5,7 +5,7 @@ members = [
 # "backends/client",
 "backends/grpc-metadata",
 "launcher"
-]
+, "backends/trtllm"]
 resolver = "2"
 
 [workspace.package]
```
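The change appends `"backends/trtllm"` onto the array's closing-bracket line rather than adding a separate entry, which keeps the diff to one changed line. Assuming plain TOML semantics, the resulting members array is equivalent to this more conventional layout (illustration only, not part of the commit):

```toml
[workspace]
members = [
    # "backends/client",
    "backends/grpc-metadata",
    "launcher",
    "backends/trtllm",
]
resolver = "2"
```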
backends/trtllm/Cargo.toml (new file, 22 lines)
```diff
@@ -0,0 +1,22 @@
+[package]
+name = "trtllm"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+build = "build.rs"
+
+[dependencies]
+async-trait = "0.1"
+async-stream = "0.3"
+clap = { version = "4.5", features = ["derive", "env"] }
+text-generation-router = { path = "../../router" }
+tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.14"
+thiserror = "1.0"
+
+[build-dependencies]
+anyhow = "1.0"
+cmake = "0.1"
+git2 = "0.19"
```
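The `[build-dependencies]` section (`anyhow`, `cmake`, `git2`) signals that this crate will eventually compile native TensorRT-LLM sources, but both `build.rs` and `csrc/CMakeLists.txt` below are still empty in this commit. A minimal sketch of how the `cmake` crate is typically wired into a build script (the `csrc` path comes from this commit's layout; everything else is an assumption about later revisions):

```rust
// build.rs is empty in this commit. Hypothetical sketch only: configure and
// build the C++ sources under csrc/ via the `cmake` crate, then tell Cargo
// where to find the resulting native artifacts.
fn main() {
    // Runs CMake configure + build for csrc/ and returns the install dir.
    let dst = cmake::Config::new("csrc").build();
    println!("cargo:rustc-link-search=native={}", dst.display());
}
```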
backends/trtllm/build.rs (new file, empty)
backends/trtllm/csrc/CMakeLists.txt (new file, empty)
backends/trtllm/src/backend.rs (new file, 32 lines)
```diff
@@ -0,0 +1,32 @@
+use std::fmt::{Display, Formatter};
+
+use async_trait::async_trait;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+
+use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::validation::ValidGenerateRequest;
+
+pub struct TensorRTBackend {}
+
+#[async_trait]
+impl Backend for TensorRTBackend {
+    fn schedule(
+        &self,
+        _request: ValidGenerateRequest,
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        let (_sender, receiver) = mpsc::unbounded_channel();
+
+        Ok(UnboundedReceiverStream::new(receiver))
+    }
+
+    async fn health(&self, current_health: bool) -> bool {
+        todo!()
+    }
+}
+
+impl Display for TensorRTBackend {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "TensorRT-LLM Backend")
+    }
+}
```
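In this stub, `schedule` creates an unbounded channel and immediately drops the sender, so the returned stream ends without yielding a single response. A minimal, self-contained sketch of the channel pattern the real implementation would follow (simplified `String` stand-ins for `InferStreamResponse`/`InferError`; the token loop is illustrative only):

```rust
use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    // schedule() would create this pair, hand the sender to the executor
    // loop, and return the receiver wrapped as a stream.
    let (sender, receiver) = mpsc::unbounded_channel::<Result<String, String>>();

    // A real backend holds the sender for the lifetime of the request and
    // pushes one item per generated token; here we fake two tokens.
    tokio::spawn(async move {
        for token in ["Hello", " world"] {
            let _ = sender.send(Ok(token.to_string()));
        }
        // Dropping the sender closes the channel, signalling end of generation.
    });

    let mut stream = UnboundedReceiverStream::new(receiver);
    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}
```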
backends/trtllm/src/errors.rs (new file, 4 lines)
```diff
@@ -0,0 +1,4 @@
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum TensorRTError {}
```
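`TensorRTError` is declared but empty for now. For illustration, this is how `thiserror` would derive `Display` and `Error` once the enum is populated (the variant names are assumptions, not from this commit):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
pub enum TensorRTError {
    // Hypothetical: engine initialization failed.
    #[error("failed to load TensorRT-LLM engine: {0}")]
    EngineLoad(String),
    // Hypothetical: request rejected before scheduling.
    #[error("invalid request: {0}")]
    Validation(String),
}
```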
backends/trtllm/src/lib.rs (new file, 5 lines)
```diff
@@ -0,0 +1,5 @@
+mod backend;
+mod errors;
+
+pub use backend::TensorRTBackend;
+pub use errors::*;
```
backends/trtllm/src/main.rs (new file, 56 lines)
```diff
@@ -0,0 +1,56 @@
+use clap::Parser;
+use thiserror::Error;
+use text_generation_router::server;
+use trtllm::TensorRTError;
+
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(default_value = "1024", long, env)]
+    max_input_tokens: usize,
+    #[clap(default_value = "2048", long, env)]
+    max_total_tokens: usize,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+}
+
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Argument validation error: {0}")]
+    ArgumentValidation(String),
+    #[error("Backend failed: {0}")]
+    Backend(#[from] TensorRTError),
+    #[error("WebServer error: {0}")]
+    WebServer(#[from] server::WebServerError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+}
+
+
+#[tokio::main]
+async fn main() -> Result<(), RouterError> {
+    // Get args
+    let args = Args::parse();
+
+    Ok(())
+}
```
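`main` only parses `Args` so far; nothing is validated and no server is started. The `RouterError::ArgumentValidation` variant hints at pre-flight checks like the one sketched below, consistent with the relationship the defaults already imply (`max_input_tokens` 1024 < `max_total_tokens` 2048). A standalone sketch, not part of the commit:

```rust
// Hypothetical pre-flight check: the input prompt must leave room for at
// least one generated token within the total token budget.
fn validate_args(max_input_tokens: usize, max_total_tokens: usize) -> Result<(), String> {
    if max_input_tokens >= max_total_tokens {
        return Err(format!(
            "`max_input_tokens` ({max_input_tokens}) must be < `max_total_tokens` ({max_total_tokens})"
        ));
    }
    Ok(())
}

fn main() {
    assert!(validate_args(1024, 2048).is_ok()); // the defaults above
    assert!(validate_args(2048, 2048).is_err());
}
```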