mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
[TRTLLM] Expose finish reason (#2841)
* feat(trtllm): expose finish reason to Rust * misc(llamacpp): fix typo * misc(backend): update deps
This commit is contained in:
parent
4e172028aa
commit
0a89902663
@ -30,7 +30,7 @@ option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
|
||||
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path rgo where TensorRT libraries and headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
||||
|
||||
|
@ -7,20 +7,16 @@ homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1"
|
||||
#async-stream = "0.3"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
cxx = "1.0"
|
||||
hashbrown = "0.14"
|
||||
hashbrown = "0.15"
|
||||
hf-hub = { workspace = true }
|
||||
#log = { version = "0.4", features = [] }
|
||||
text-generation-router = { path = "../../router" }
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.15"
|
||||
tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.17"
|
||||
thiserror = "1.0.63"
|
||||
tracing = "0.1"
|
||||
#tracing-opentelemetry = "0.25"
|
||||
#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||
pyo3 = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
|
||||
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
|
||||
};
|
||||
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
use text_generation_router::Token;
|
||||
|
||||
use crate::errors::TensorRtLlmBackendError;
|
||||
use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl};
|
||||
use crate::ffi::{
|
||||
create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
|
||||
};
|
||||
use crate::utils::first_line;
|
||||
|
||||
type InferResult<T> = Result<T, InferError>;
|
||||
@ -40,6 +42,7 @@ struct DecodedToken {
|
||||
id: u32,
|
||||
log_prob: f32,
|
||||
is_final: bool,
|
||||
finish_reason: FinishReason,
|
||||
}
|
||||
|
||||
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
|
||||
@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
|
||||
id: step.token_id,
|
||||
log_prob: step.log_prob,
|
||||
is_final: step.is_final,
|
||||
finish_reason: step.finish_reason,
|
||||
})
|
||||
} else {
|
||||
Err(GenerationError(step.error_msg.clone()))
|
||||
@ -192,7 +196,7 @@ fn post_process_decoded_token(
|
||||
let generated_text = GeneratedText {
|
||||
text: text.unwrap(),
|
||||
generated_tokens: ctx.tokens.len() as u32,
|
||||
finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
|
||||
finish_reason: decoded_token.finish_reason.into(),
|
||||
seed: None,
|
||||
};
|
||||
|
||||
|
@ -67,11 +67,7 @@ struct Args {
|
||||
payload_limit: usize,
|
||||
}
|
||||
|
||||
async fn get_tokenizer(
|
||||
tokenizer_name: &str,
|
||||
_tokenizer_config_path: Option<&str>,
|
||||
revision: Option<&str>,
|
||||
) -> Option<Tokenizer> {
|
||||
async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
|
||||
// Parse Huggingface hub token
|
||||
let authorization_token = std::env::var("HF_TOKEN")
|
||||
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
||||
@ -182,19 +178,6 @@ async fn get_tokenizer(
|
||||
}
|
||||
};
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
// let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||
// {
|
||||
// HubTokenizerConfig::from_file(filename)
|
||||
// } else {
|
||||
// tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||
// };
|
||||
|
||||
// let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||
// tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
// HubTokenizerConfig::default()
|
||||
// });
|
||||
|
||||
let tokenizer: Tokenizer = {
|
||||
use pyo3::prelude::*;
|
||||
pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
@ -292,11 +275,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||
}
|
||||
|
||||
// Create the backend
|
||||
match get_tokenizer(
|
||||
&tokenizer_name,
|
||||
tokenizer_config_path.as_deref(),
|
||||
revision.as_deref(),
|
||||
)
|
||||
match get_tokenizer(&tokenizer_name, revision.as_deref())
|
||||
.await
|
||||
.expect("Failed to retrieve tokenizer implementation")
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user