Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
[TRTLLM] Expose finish reason (#2841)

* feat(trtllm): expose finish reason to Rust
* misc(llamacpp): fix typo
* misc(backend): update deps

Parent: 4e172028aa
Commit: 0a89902663
@@ -30,7 +30,7 @@ option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
-set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path rgo where TensorRT libraries and headers are located")
+set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")

@@ -7,20 +7,16 @@ homepage.workspace = true
 
 [dependencies]
 async-trait = "0.1"
-#async-stream = "0.3"
 clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
-hashbrown = "0.14"
+hashbrown = "0.15"
 hf-hub = { workspace = true }
-#log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
-tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.15"
+tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.17"
 thiserror = "1.0.63"
 tracing = "0.1"
-#tracing-opentelemetry = "0.25"
-#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
 pyo3 = { workspace = true }
 
 [build-dependencies]
@@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
 use crate::utils::first_line;
 
 type InferResult<T> = Result<T, InferError>;
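Note that `FinishReason` is now imported from the cxx bridge (`crate::ffi`) instead of the router. The bridge module itself is not part of this diff; the following is only a sketch of what the shared declarations could look like, assuming TensorRT-LLM-style variant names and a `GenerationStep` layout matching the fields used in the hunks below.

// Sketch only: the real #[cxx::bridge] module is not shown in this commit;
// variant names, discriminants and struct fields are assumptions.
#[cxx::bridge]
mod ffi {
    // Shared enum: cxx generates matching C++ and Rust definitions.
    enum FinishReason {
        // The request is not finished yet.
        kNOT_FINISHED = 0,
        // Generation stopped because the end-of-sequence token was produced.
        kEND_ID = 1,
        // Generation stopped because a stop word was produced.
        kSTOP_WORDS = 2,
        // Generation stopped because the maximum length was reached.
        kLENGTH = 3,
    }

    // Shared struct mirroring one decoded step returned by TensorRT-LLM.
    struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        finish_reason: FinishReason,
        has_error: bool,
        error_msg: String,
    }
}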
@@ -40,6 +42,7 @@ struct DecodedToken {
     id: u32,
     log_prob: f32,
     is_final: bool,
+    finish_reason: FinishReason,
 }
 
 impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
                 id: step.token_id,
                 log_prob: step.log_prob,
                 is_final: step.is_final,
+                finish_reason: step.finish_reason,
             })
         } else {
             Err(GenerationError(step.error_msg.clone()))
@@ -192,7 +196,7 @@ fn post_process_decoded_token(
         let generated_text = GeneratedText {
             text: text.unwrap(),
             generated_tokens: ctx.tokens.len() as u32,
-            finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
+            finish_reason: decoded_token.finish_reason.into(),
             seed: None,
         };
 
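The `decoded_token.finish_reason.into()` call implies a conversion from the FFI enum into the router's `FinishReason`; that impl is outside this excerpt. A minimal sketch of the mapping, assuming the variant names from the bridge sketch above:

// Sketch only: the actual conversion impl is not part of this diff.
impl From<ffi::FinishReason> for text_generation_router::FinishReason {
    fn from(reason: ffi::FinishReason) -> Self {
        use text_generation_router::FinishReason;
        match reason {
            ffi::FinishReason::kEND_ID => FinishReason::EndOfSequenceToken,
            ffi::FinishReason::kSTOP_WORDS => FinishReason::StopSequence,
            ffi::FinishReason::kLENGTH => FinishReason::Length,
            // kNOT_FINISHED (or any unknown value) should never reach
            // post-processing of a final token; fall back to the previous
            // default rather than panicking in this sketch.
            _ => FinishReason::EndOfSequenceToken,
        }
    }
}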
@@ -67,11 +67,7 @@ struct Args {
     payload_limit: usize,
 }
 
-async fn get_tokenizer(
-    tokenizer_name: &str,
-    _tokenizer_config_path: Option<&str>,
-    revision: Option<&str>,
-) -> Option<Tokenizer> {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
     // Parse Huggingface hub token
     let authorization_token = std::env::var("HF_TOKEN")
         .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
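With the unused tokenizer-config parameter dropped, callers pass only the model id and an optional revision. A hypothetical call sketch follows; the model id and revision literals are placeholders, and the real call-site update appears in the main() hunk further down.

// Illustrative only: arguments are placeholders, not taken from the commit.
let tokenizer = get_tokenizer("meta-llama/Llama-3.1-8B-Instruct", Some("main"))
    .await
    .expect("Failed to retrieve tokenizer implementation");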
@@ -182,19 +178,6 @@ async fn get_tokenizer(
         }
     };
-
-    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    // {
-    //     HubTokenizerConfig::from_file(filename)
-    // } else {
-    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    // };
-
-    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
-    //     HubTokenizerConfig::default()
-    // });
 
     let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
         pyo3::Python::with_gil(|py| -> PyResult<()> {
@@ -292,13 +275,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Create the backend
-    match get_tokenizer(
-        &tokenizer_name,
-        tokenizer_config_path.as_deref(),
-        revision.as_deref(),
-    )
-    .await
-    .expect("Failed to retrieve tokenizer implementation")
+    match get_tokenizer(&tokenizer_name, revision.as_deref())
+        .await
+        .expect("Failed to retrieve tokenizer implementation")
     {
         Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
             "Failed to retrieve Rust based tokenizer".to_string(),