[TRTLLM] Expose finish reason (#2841)

* feat(trtllm): expose finish reason to Rust

* misc(llamacpp): fix typo

* misc(backend): update deps
Author: Funtowicz Morgan · 2025-01-23 16:48:26 +01:00 · committed by GitHub
parent 4e172028aa
commit 0a89902663
4 changed files with 15 additions and 36 deletions

@@ -30,7 +30,7 @@ option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
-set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path rgo where TensorRT libraries and headers are located")
+set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")

@@ -7,20 +7,16 @@ homepage.workspace = true
 [dependencies]
 async-trait = "0.1"
-#async-stream = "0.3"
 clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
-hashbrown = "0.14"
+hashbrown = "0.15"
 hf-hub = { workspace = true }
-#log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
-tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.15"
+tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.17"
 thiserror = "1.0.63"
 tracing = "0.1"
-#tracing-opentelemetry = "0.25"
-#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
 pyo3 = { workspace = true }
 
 [build-dependencies]

@@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
 use crate::utils::first_line;
 
 type InferResult<T> = Result<T, InferError>;
@@ -40,6 +42,7 @@ struct DecodedToken {
     id: u32,
     log_prob: f32,
     is_final: bool,
+    finish_reason: FinishReason,
 }
 
 impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
                 id: step.token_id,
                 log_prob: step.log_prob,
                 is_final: step.is_final,
+                finish_reason: step.finish_reason,
             })
         } else {
             Err(GenerationError(step.error_msg.clone()))
@@ -192,7 +196,7 @@ fn post_process_decoded_token(
         let generated_text = GeneratedText {
             text: text.unwrap(),
             generated_tokens: ctx.tokens.len() as u32,
-            finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
+            finish_reason: decoded_token.finish_reason.into(),
            seed: None,
         };
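
Note: the ffi module itself is not touched by this diff, but the new `finish_reason: step.finish_reason` field and the `decoded_token.finish_reason.into()` call only line up if the cxx bridge exposes a shared FinishReason enum plus a conversion into the router's FinishReason. Below is a minimal sketch of that plumbing; the bridge namespace, the enum variant names (taken from TensorRT-LLM's executor API), and the exact GenerationStep field set are assumptions, not copied from this commit.

#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi {
    // Shared enum mirroring tensorrt_llm::executor::FinishReason so the value
    // crosses the C++/Rust boundary unchanged (variant names are assumed).
    enum FinishReason {
        kNOT_FINISHED = 0,
        kEND_ID = 1,
        kSTOP_WORDS = 2,
        kLENGTH = 3,
    }

    // One streamed decoding step returned by the C++ executor; the field set is
    // reconstructed from the TryFrom impl above, not from the repository.
    struct GenerationStep {
        request_id: u64,
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        finish_reason: FinishReason,
        has_error: bool,
        error_msg: String,
    }

    // extern "C++" items (create_backend_from_engine_folder,
    // TensorRtLlmBackendImpl, ...) elided.
}

// Assumed conversion backing `decoded_token.finish_reason.into()` in
// post_process_decoded_token: map the executor's reason onto the router's enum.
impl From<ffi::FinishReason> for text_generation_router::FinishReason {
    fn from(reason: ffi::FinishReason) -> Self {
        match reason {
            ffi::FinishReason::kEND_ID => Self::EndOfSequenceToken,
            ffi::FinishReason::kSTOP_WORDS => Self::StopSequence,
            ffi::FinishReason::kLENGTH => Self::Length,
            // kNOT_FINISHED should not reach post-processing; fall back to EOS
            // rather than panicking on an unexpected value.
            _ => Self::EndOfSequenceToken,
        }
    }
}

With something like this in place, the finish reason reported to clients can now distinguish stop-word and length terminations instead of the hard-coded EndOfSequenceToken that the removed TODO left behind.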

@@ -67,11 +67,7 @@ struct Args {
     payload_limit: usize,
 }
 
-async fn get_tokenizer(
-    tokenizer_name: &str,
-    _tokenizer_config_path: Option<&str>,
-    revision: Option<&str>,
-) -> Option<Tokenizer> {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
     // Parse Huggingface hub token
     let authorization_token = std::env::var("HF_TOKEN")
         .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@@ -182,19 +178,6 @@ async fn get_tokenizer(
         }
     };
 
-    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    // {
-    //     HubTokenizerConfig::from_file(filename)
-    // } else {
-    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    // };
-    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
-    //     HubTokenizerConfig::default()
-    // });
-
     let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
         pyo3::Python::with_gil(|py| -> PyResult<()> {
@@ -292,13 +275,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Create the backend
-    match get_tokenizer(
-        &tokenizer_name,
-        tokenizer_config_path.as_deref(),
-        revision.as_deref(),
-    )
-    .await
-    .expect("Failed to retrieve tokenizer implementation")
+    match get_tokenizer(&tokenizer_name, revision.as_deref())
+        .await
+        .expect("Failed to retrieve tokenizer implementation")
     {
         Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
             "Failed to retrieve Rust based tokenizer".to_string(),