mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00

commit fb759bdd2a (parent 5f7c0b67c3)
(looper) new looper initial implementation
@@ -4,6 +4,8 @@ use text_generation_router::server;
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
    #[error("TensorRT-LLM Runtime error: {0}")]
    Runtime(String),
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),
    #[error("Argument validation error: {0}")]
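The variants above rely on thiserror's #[error(...)] attribute for their Display output. A minimal, self-contained illustration of how such variants render (only two variants re-declared here, and the string values are made up, not taken from the commit):

use thiserror::Error;

#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
    #[error("TensorRT-LLM Runtime error: {0}")]
    Runtime(String),
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),
}

fn main() {
    // Display output comes from the #[error(...)] attribute on each variant.
    let err = TensorRtLlmBackendError::Runtime("engine folder not found".into());
    assert_eq!(err.to_string(), "TensorRT-LLM Runtime error: engine folder not found");
    println!("{err}");
}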
backends/trtllm/src/looper.rs (new file, 182 lines)
@@ -0,0 +1,182 @@
use std::hint;
use std::ops::Deref;
use std::path::Path;
use std::sync::OnceLock;

use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info, Level, span};

use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::infer::InferError::GenerationError;
use text_generation_router::validation::ValidGenerateRequest;

use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};

// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();

// It's safe to send the backend between threads
unsafe impl Send for TensorRtLlmBackendImpl {}

type InferResult<T> = Result<T, InferError>;

fn executor_status_poller(
    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
    mut waiting_requests: UnboundedReceiver<GenerationContext>,
) {
    // Track the tuple (request_id, stream) for each request
    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);

    // TODO: Does it need a spin-loop?
    loop {
        span!(Level::DEBUG, "in-flight submit").in_scope(|| {
            // Is there any request pending to be scheduled?
            let awaiting_requests = waiting_requests.len();
            if awaiting_requests > 0 {
                // Retrieve all the requests
                let mut requests = Vec::with_capacity(awaiting_requests);
                let _ = waiting_requests.recv_many(&mut requests, awaiting_requests);

                // Submit all the requests to the executor and move the context to the in-flight tracker
                for ctx in requests {
                    let request = &ctx.request;
                    let generation_params = &request.parameters;
                    let stopping_params = &request.stopping_parameters;

                    // Submit to the TensorRT-LLM executor for scheduling
                    match backend.pin_mut().submit(
                        &vec![],
                        stopping_params.max_new_tokens,
                        generation_params.top_k as i32,
                        generation_params.top_p,
                        generation_params.temperature,
                        generation_params.repetition_penalty,
                        generation_params.frequency_penalty,
                        generation_params.seed,
                    ) {
                        Ok(request_id) => {
                            // Insert the context linked to the generated request id in the tracker
                            in_flights.insert(request_id, ctx);
                        }
                        Err(e) => {
                            // Return to the caller
                            let what = Err(InferError::SchedulingError(e.to_string()));
                            if let Err(e) = ctx.streamer.send(what) {
                                error!("Failed to send back through the channel: {}", e);
                            }
                        }
                    };
                }
            }
        });

        span!(Level::DEBUG, "in-flight poll").in_scope(|| {
            if backend.num_responses_ready() > 0 {
                match backend.pin_mut().pull_tokens() {
                    Ok(responses) => {
                        for step in responses.deref() {
                            let request_id = step.request_id;
                            match in_flights.get(&request_id) {
                                Some(ctx) => {
                                    info!("New token for {} -> {}", request_id, step.token_id);

                                    if step.is_final {
                                        let _ = in_flights.remove(&step.request_id);
                                    }
                                }
                                None => {
                                    error!("Got step for untracked request {}", request_id);
                                }
                            }
                        }
                    }
                    Err(err) => {
                        error!("Failed to retrieve tokens from the executor: {}", err);
                    }
                }
            }
        });

        // Hint the CPU we are spin-locking
        hint::spin_loop();
    }
}

struct GenerationContext {
    request: ValidGenerateRequest,
    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}

pub struct TensorRtLlmBackendV2 {
    tokenizer: Tokenizer,
    looper: JoinHandle<()>,
    queue: UnboundedSender<GenerationContext>,
}

impl TensorRtLlmBackendV2 {
    pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
        tokenizer: Tokenizer,
        engine_folder: P,
        executor_worker_path: PP,
    ) -> Result<Self, TensorRtLlmBackendError> {
        // Retrieve paths as &str for the backend creation
        let engine_folder = engine_folder.as_ref();
        let executor_worker_path = executor_worker_path.as_ref();

        let engine_folder = String::from(
            engine_folder
                .to_str()
                .expect("Failed to convert engine_folder to valid UTF-8"),
        );

        let executor_worker_path = String::from(
            executor_worker_path
                .to_str()
                .expect("Failed to convert executor_worker_path to valid UTF-8"),
        );

        // Allocate the IPC layer to communicate with the backend
        let (requests_sender, requests_receiver) = unbounded_channel::<GenerationContext>();

        // Create the FFI backend
        let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
            .map_err(|e| TensorRtLlmBackendError::Runtime(e.what().to_string()))?;

        // Looper is responsible for scheduling and pulling requests state at regular interval
        let looper =
            tokio::task::spawn_blocking(move || executor_status_poller(backend, requests_receiver));

        Ok(TensorRtLlmBackendV2 {
            tokenizer,
            looper,
            queue: requests_sender,
        })
    }
}

#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
        match self.queue.send(GenerationContext { request, streamer }) {
            Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
            Err(_) => Err(GenerationError(
                "Failed to submit request to the backend".into(),
            )),
        }
    }

    async fn health(&self, current_health: bool) -> bool {
        current_health & !self.looper.is_finished()
    }
}
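For readers skimming the diff: the file above boils down to a producer/consumer pattern. schedule pushes a GenerationContext onto an unbounded channel, and a spawn_blocking looper drains that channel, feeds the FFI executor, and streams results back through per-request channels while hinting the CPU with spin_loop. Below is a minimal, self-contained sketch of that pattern with toy types standing in for the TensorRT-LLM FFI; everything named Fake* or demo_* is invented for illustration and is not part of the commit.

use std::collections::HashMap;
use std::hint;

use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

// Toy stand-ins for ValidGenerateRequest / InferStreamResponse.
struct FakeRequest {
    prompt: String,
}

struct FakeContext {
    request: FakeRequest,
    streamer: UnboundedSender<Result<String, String>>,
}

// Blocking loop: drain pending requests, then "generate" for everything in flight.
fn demo_looper(mut waiting_requests: UnboundedReceiver<FakeContext>) {
    let mut in_flights: HashMap<u64, FakeContext> = HashMap::new();
    let mut next_id: u64 = 0;

    loop {
        // Drain everything currently queued without blocking (mirrors len() + recv_many).
        let mut disconnected = false;
        loop {
            match waiting_requests.try_recv() {
                Ok(ctx) => {
                    in_flights.insert(next_id, ctx);
                    next_id += 1;
                }
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) => {
                    disconnected = true;
                    break;
                }
            }
        }

        // Poll "the executor": emit one final token per in-flight request.
        for (_, ctx) in in_flights.drain() {
            let _ = ctx.streamer.send(Ok(format!("echo: {}", ctx.request.prompt)));
        }

        if disconnected {
            break; // all senders dropped, nothing left to do
        }

        // Hint the CPU we are spin-looping, as the real looper does.
        hint::spin_loop();
    }
}

#[tokio::main]
async fn main() {
    let (queue, receiver) = unbounded_channel::<FakeContext>();
    let looper = tokio::task::spawn_blocking(move || demo_looper(receiver));

    // "schedule" one request and read its stream back.
    let (streamer, mut stream) = unbounded_channel();
    queue
        .send(FakeContext { request: FakeRequest { prompt: "hello".into() }, streamer })
        .expect("looper is alive");
    println!("{:?}", stream.recv().await); // Some(Ok("echo: hello"))

    drop(queue); // let the looper exit
    looper.await.unwrap();
}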
@@ -1,10 +1,17 @@
+use std::path::{Path, PathBuf};
+
 use clap::Parser;
-use std::collections::HashMap;
-use std::path::PathBuf;
+use hf_hub::{Cache, Repo, RepoType};
+use hf_hub::api::tokio::{Api, ApiBuilder};
+use tokenizers::Tokenizer;
 use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
-use text_generation_backends_trtllm::TensorRtLlmBackend;
-use text_generation_router::server;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
+use text_generation_backends_trtllm::TensorRtLlmBackendV2;
+use text_generation_router::{HubTokenizerConfig, server};
+use text_generation_router::server::{
+    create_post_processor, get_base_tokenizer, get_hub_model_info,
+};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -58,6 +65,147 @@ struct Args {
    executor_worker: PathBuf,
}

async fn get_tokenizer(
    tokenizer_name: &str,
    tokenizer_config_path: Option<&str>,
    revision: Option<&str>,
) -> Option<Tokenizer> {
    // Parse Huggingface hub token
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Tokenizer instance
    let local_path = Path::new(tokenizer_name);

    // Shared API builder initialization
    let api_builder = || {
        let mut builder = ApiBuilder::new()
            .with_progress(false)
            .with_token(authorization_token);

        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }

        builder
    };

    // Decide if we need to use the API based on the revision and local path
    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();

    // Initialize API if needed
    #[derive(Clone)]
    enum Type {
        Api(Api),
        Cache(Cache),
        None,
    }
    let api = if use_api {
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
                .map_err(|_| ())
                .map(|cache_dir| Cache::new(cache_dir.into()))
                .unwrap_or_else(|_| Cache::default());
            tracing::warn!("Offline mode active using cache defaults");
            Type::Cache(cache)
        } else {
            tracing::info!("Using the Hugging Face API");
            match api_builder().build() {
                Ok(api) => Type::Api(api),
                Err(_) => {
                    tracing::warn!("Unable to build the Hugging Face API");
                    Type::None
                }
            }
        }
    } else {
        Type::None
    };

    // Load tokenizer and model info
    let (
        tokenizer_filename,
        config_filename,
        tokenizer_config_filename,
        preprocessor_config_filename,
        processor_config_filename,
    ) = match api {
        Type::None => (
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.unwrap_or_else(|| "main").to_string(),
            ));

            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
                Ok(tokenizer_filename) => Some(tokenizer_filename),
                Err(_) => get_base_tokenizer(&api, &api_repo).await,
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();

            (
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
                preprocessor_config_filename,
                processor_config_filename,
            )
        }
        Type::Cache(cache) => {
            let repo = cache.repo(Repo::with_revision(
                tokenizer_name.to_string(),
                RepoType::Model,
                revision.clone().unwrap_or_else(|| "main").to_string(),
            ));
            (
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
            )
        }
    };

    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
    {
        HubTokenizerConfig::from_file(filename)
    } else {
        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
    };
    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
        tracing::warn!("Could not find tokenizer config locally and no API specified");
        HubTokenizerConfig::default()
    });

    tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
        if let Some(tokenizer) = &mut tokenizer {
            if let Some(class) = &tokenizer_config.tokenizer_class {
                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                        tokenizer.with_post_processor(post_processor);
                    }
                }
            }
        }
        tokenizer
    })
}

#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
    // Get args
@@ -124,18 +272,21 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         )));
     }
 
-    // Run server
-    let tokenizer = Tokenizer::from_pretrained(
-        tokenizer_name.clone(),
-        Some(FromPretrainedParameters {
-            revision: revision.clone().unwrap_or(String::from("main")),
-            user_agent: HashMap::new(),
-            auth_token,
-        }),
+    // Create the backend
+    let tokenizer = get_tokenizer(
+        &tokenizer_name,
+        tokenizer_config_path.as_deref(),
+        revision.as_deref(),
     )
-    .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
+    .await
+    .expect("Failed to retrieve tokenizer implementation");
 
-    let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
+    info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+    let backend = TensorRtLlmBackendV2::new(tokenizer, model_id, executor_worker)?;
+
+    info!("Successfully created backend");
+
+    // Run server
     server::run(
         backend,
         max_concurrent_requests,