diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs
index a672d2a4..8ec6e1af 100644
--- a/backends/trtllm/src/errors.rs
+++ b/backends/trtllm/src/errors.rs
@@ -4,6 +4,8 @@ use text_generation_router::server;
 
 #[derive(Debug, Error)]
 pub enum TensorRtLlmBackendError {
+    #[error("TensorRT-LLM Runtime error: {0}")]
+    Runtime(String),
     #[error("Tokenizer error: {0}")]
     Tokenizer(String),
     #[error("Argument validation error: {0}")]
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
new file mode 100644
index 00000000..29866c2f
--- /dev/null
+++ b/backends/trtllm/src/looper.rs
@@ -0,0 +1,182 @@
+use std::hint;
+use std::ops::Deref;
+use std::path::Path;
+use std::sync::OnceLock;
+
+use async_trait::async_trait;
+use cxx::UniquePtr;
+use hashbrown::HashMap;
+use tokenizers::Tokenizer;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::task::JoinHandle;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{error, info, Level, span};
+
+use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::infer::InferError::GenerationError;
+use text_generation_router::validation::ValidGenerateRequest;
+
+use crate::errors::TensorRtLlmBackendError;
+use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+
+// Value used to poll the state of the generation stream
+static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
+
+// It's safe to send the backend between threads
+unsafe impl Send for TensorRtLlmBackendImpl {}
+
+type InferResult<T> = Result<T, InferError>;
+
+fn executor_status_poller(
+    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+    mut waiting_requests: UnboundedReceiver<GenerationContext>,
+) {
+    // Track the tuple (request_id, stream) for each request
+    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
+
+    // TODO: Does it need a spin-loop?
+    loop {
+        span!(Level::DEBUG, "in-flight submit").in_scope(|| {
+            // Is there any request pending to be scheduled?
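+            // The channel length is only a snapshot; everything counted here is drained as one batch below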
+            let awaiting_requests = waiting_requests.len();
+            if awaiting_requests > 0 {
+                // Retrieve all the requests
+                let mut requests = Vec::with_capacity(awaiting_requests);
+                let _ = waiting_requests.recv_many(&mut requests, awaiting_requests);
+
+                // Submit all the request to the executor and move the context to the in-flight tracker
+                for ctx in requests {
+                    let request = &ctx.request;
+                    let generation_params = &request.parameters;
+                    let stopping_params = &request.stopping_parameters;
+
+                    // Submit to the TensorRT-LLM executor for scheduling
+                    match backend.pin_mut().submit(
+                        &vec![],
+                        stopping_params.max_new_tokens,
+                        generation_params.top_k as i32,
+                        generation_params.top_p,
+                        generation_params.temperature,
+                        generation_params.repetition_penalty,
+                        generation_params.frequency_penalty,
+                        generation_params.seed,
+                    ) {
+                        Ok(request_id) => {
+                            // Insert the context linked to the generated request id in the tracker
+                            in_flights.insert(request_id, ctx);
+                        }
+                        Err(e) => {
+                            // Return to the caller
+                            let what = Err(InferError::SchedulingError(e.to_string()));
+                            if let Err(e) = ctx.streamer.send(what) {
+                                error!("Failed to send back through the channel: {}", e);
+                            }
+                        }
+                    };
+                }
+            }
+        });
+
+        span!(Level::DEBUG, "in-flight poll").in_scope(|| {
+            if backend.num_responses_ready() > 0 {
+                match backend.pin_mut().pull_tokens() {
+                    Ok(responses) => {
+                        for step in responses.deref() {
+                            let request_id = step.request_id;
+                            match in_flights.get(&request_id) {
+                                Some(ctx) => {
+                                    info!("New token for {} -> {}", request_id, step.token_id);
+
+                                    if step.is_final {
+                                        let _ = in_flights.remove(&step.request_id);
+                                    }
+                                }
+                                None => {
+                                    error!("Got step for untracked request {}", request_id);
+                                }
+                            }
+                        }
+                    }
+                    Err(err) => {
+                        error!("Failed to retrieve tokens from the executor: {}", err);
+                    }
+                }
+            }
+        });
+
+        // Hint the CPU we are spin-locking
+        hint::spin_loop();
+    }
+}
+
+struct GenerationContext {
+    request: ValidGenerateRequest,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
+pub struct TensorRtLlmBackendV2 {
+    tokenizer: Tokenizer,
+    looper: JoinHandle<()>,
+    queue: UnboundedSender<GenerationContext>,
+}
+
+impl TensorRtLlmBackendV2 {
+    pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
+        tokenizer: Tokenizer,
+        engine_folder: P,
+        executor_worker_path: PP,
+    ) -> Result<Self, TensorRtLlmBackendError> {
+        // Retrieve paths as &str for the backend creation
+        let engine_folder = engine_folder.as_ref();
+        let executor_worker_path = executor_worker_path.as_ref();
+
+        let engine_folder = String::from(
+            engine_folder
+                .to_str()
+                .expect("Failed to convert engine_folder to valid UTF-8"),
+        );
+
+        let executor_worker_path = String::from(
+            executor_worker_path
+                .to_str()
+                .expect("Failed to convert executor_worker_path to valid UTF-8"),
+        );
+
+        // Allocate the IPC layer to communicate with the backend
+        let (requests_sender, requests_receiver) = unbounded_channel::<GenerationContext>();
+
+        // Create the FFI backend
+        let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
+            .map_err(|e| TensorRtLlmBackendError::Runtime(e.what().to_string()))?;
+
+        // Looper is responsible for scheduling and pulling requests state at regular interval
+        let looper =
+            tokio::task::spawn_blocking(move || executor_status_poller(backend, requests_receiver));
+
+        Ok(TensorRtLlmBackendV2 {
+            tokenizer,
+            looper,
+            queue: requests_sender,
+        })
+    }
+}
+
+#[async_trait]
+impl Backend for TensorRtLlmBackendV2 {
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
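+        // Hand the context over to the background looper; the receiver half streams
+        // the generated tokens back to the router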
+        match self.queue.send(GenerationContext { request, streamer }) {
+            Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
+            Err(_) => Err(GenerationError(
+                "Failed to submit request to the backend".into(),
+            )),
+        }
+    }
+
+    async fn health(&self, current_health: bool) -> bool {
+        current_health & !self.looper.is_finished()
+    }
+}
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index e0ba46c7..15f40f5a 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -1,10 +1,17 @@
+use std::path::{Path, PathBuf};
+
 use clap::Parser;
-use std::collections::HashMap;
-use std::path::PathBuf;
+use hf_hub::{Cache, Repo, RepoType};
+use hf_hub::api::tokio::{Api, ApiBuilder};
+use tokenizers::Tokenizer;
+use tracing::info;
+
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
-use text_generation_backends_trtllm::TensorRtLlmBackend;
-use text_generation_router::server;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
+use text_generation_backends_trtllm::TensorRtLlmBackendV2;
+use text_generation_router::{HubTokenizerConfig, server};
+use text_generation_router::server::{
+    create_post_processor, get_base_tokenizer, get_hub_model_info,
+};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -58,6 +65,147 @@ struct Args {
     executor_worker: PathBuf,
 }
 
+async fn get_tokenizer(
+    tokenizer_name: &str,
+    tokenizer_config_path: Option<&str>,
+    revision: Option<&str>,
+) -> Option<Tokenizer> {
+    // Parse Huggingface hub token
+    let authorization_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();
+
+    // Tokenizer instance
+    let local_path = Path::new(tokenizer_name);
+
+    // Shared API builder initialization
+    let api_builder = || {
+        let mut builder = ApiBuilder::new()
+            .with_progress(false)
+            .with_token(authorization_token);
+
+        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
+            builder = builder.with_cache_dir(cache_dir.into());
+        }
+
+        builder
+    };
+
+    // Decide if we need to use the API based on the revision and local path
+    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
+
+    // Initialize API if needed
+    #[derive(Clone)]
+    enum Type {
+        Api(Api),
+        Cache(Cache),
+        None,
+    }
+    let api = if use_api {
+        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
+            let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
+                .map_err(|_| ())
+                .map(|cache_dir| Cache::new(cache_dir.into()))
+                .unwrap_or_else(|_| Cache::default());
+            tracing::warn!("Offline mode active using cache defaults");
+            Type::Cache(cache)
+        } else {
+            tracing::info!("Using the Hugging Face API");
+            match api_builder().build() {
+                Ok(api) => Type::Api(api),
+                Err(_) => {
+                    tracing::warn!("Unable to build the Hugging Face API");
+                    Type::None
+                }
+            }
+        }
+    } else {
+        Type::None
+    };
+
+    // Load tokenizer and model info
+    let (
+        tokenizer_filename,
+        config_filename,
+        tokenizer_config_filename,
+        preprocessor_config_filename,
+        processor_config_filename,
+    ) = match api {
+        Type::None => (
+            Some(local_path.join("tokenizer.json")),
+            Some(local_path.join("config.json")),
+            Some(local_path.join("tokenizer_config.json")),
+            Some(local_path.join("preprocessor_config.json")),
+            Some(local_path.join("processor_config.json")),
+        ),
+        Type::Api(api) => {
+            let api_repo = api.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.unwrap_or_else(|| "main").to_string(),
+            ));
+
+            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
+                Ok(tokenizer_filename) => Some(tokenizer_filename),
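+                // No tokenizer.json in this repo: fall back to the tokenizer of the
+                // `base_model_name_or_path` referenced in config.json (see get_base_tokenizer)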
+                Err(_) => get_base_tokenizer(&api, &api_repo).await,
+            };
+            let config_filename = api_repo.get("config.json").await.ok();
+            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
+            let processor_config_filename = api_repo.get("processor_config.json").await.ok();
+
+            (
+                tokenizer_filename,
+                config_filename,
+                tokenizer_config_filename,
+                preprocessor_config_filename,
+                processor_config_filename,
+            )
+        }
+        Type::Cache(cache) => {
+            let repo = cache.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main").to_string(),
+            ));
+            (
+                repo.get("tokenizer.json"),
+                repo.get("config.json"),
+                repo.get("tokenizer_config.json"),
+                repo.get("preprocessor_config.json"),
+                repo.get("processor_config.json"),
+            )
+        }
+    };
+
+    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    {
+        HubTokenizerConfig::from_file(filename)
+    } else {
+        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });
+
+    tokenizer_filename.and_then(|filename| {
+        let mut tokenizer = Tokenizer::from_file(filename).ok();
+        if let Some(tokenizer) = &mut tokenizer {
+            if let Some(class) = &tokenizer_config.tokenizer_class {
+                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
+                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
+                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
+                        tokenizer.with_post_processor(post_processor);
+                    }
+                }
+            }
+        }
+        tokenizer
+    })
+}
+
 #[tokio::main]
 async fn main() -> Result<(), TensorRtLlmBackendError> {
     // Get args
@@ -124,18 +272,21 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         )));
     }
 
-    // Run server
-    let tokenizer = Tokenizer::from_pretrained(
-        tokenizer_name.clone(),
-        Some(FromPretrainedParameters {
-            revision: revision.clone().unwrap_or(String::from("main")),
-            user_agent: HashMap::new(),
-            auth_token,
-        }),
+    // Create the backend
+    let tokenizer = get_tokenizer(
+        &tokenizer_name,
+        tokenizer_config_path.as_deref(),
+        revision.as_deref(),
     )
-    .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
+    .await
+    .expect("Failed to retrieve tokenizer implementation");
 
-    let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
+    info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+    let backend = TensorRtLlmBackendV2::new(tokenizer, model_id, executor_worker)?;
+
+    info!("Successfully created backend");
+
+    // Run server
     server::run(
         backend,
         max_concurrent_requests,