From 0da255ecbc155e080af43cc202938a04a76f4f44 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Tue, 10 Dec 2024 16:51:22 +0100
Subject: [PATCH] feat(trtllm): expose finish reason to Rust

---
 backends/trtllm/src/looper.rs | 12 ++++++++----
 backends/trtllm/src/main.rs   | 29 ++++-------------------------
 2 files changed, 12 insertions(+), 29 deletions(-)

diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 969046d1..43f23242 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -10,7 +10,7 @@ use tokio::sync::TryAcquireError;
 use tokio::task::spawn_blocking;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, warn};
+use tracing::{debug, error, info, warn};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
 use crate::utils::first_line;
 
 type InferResult<T> = Result<T, InferError>;
@@ -40,6 +42,7 @@ struct DecodedToken {
     id: u32,
     log_prob: f32,
     is_final: bool,
+    finish_reason: FinishReason,
 }
 
 impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
                 id: step.token_id,
                 log_prob: step.log_prob,
                 is_final: step.is_final,
+                finish_reason: step.finish_reason,
             })
         } else {
             Err(GenerationError(step.error_msg.clone()))
@@ -192,7 +196,7 @@ fn post_process_decoded_token(
         let generated_text = GeneratedText {
             text: text.unwrap(),
             generated_tokens: ctx.tokens.len() as u32,
-            finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
+            finish_reason: decoded_token.finish_reason.into(),
             seed: None,
         };
 
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index 5af96ade..cef225be 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -67,11 +67,7 @@ struct Args {
     payload_limit: usize,
 }
 
-async fn get_tokenizer(
-    tokenizer_name: &str,
-    _tokenizer_config_path: Option<&str>,
-    revision: Option<&str>,
-) -> Option<Tokenizer> {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
     // Parse Huggingface hub token
     let authorization_token = std::env::var("HF_TOKEN")
         .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@@ -182,19 +178,6 @@ async fn get_tokenizer(
         }
     };
 
-    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    // {
-    //     HubTokenizerConfig::from_file(filename)
-    // } else {
-    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    // };
-
-    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
-    //     HubTokenizerConfig::default()
-    // });
-
     let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
         pyo3::Python::with_gil(|py| -> PyResult<()> {
@@ -292,13 +275,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Create the backend
-    match get_tokenizer(
-        &tokenizer_name,
-        tokenizer_config_path.as_deref(),
-        revision.as_deref(),
-    )
-    .await
-    .expect("Failed to retrieve tokenizer implementation")
+    match get_tokenizer(&tokenizer_name, revision.as_deref())
+        .await
+        .expect("Failed to retrieve tokenizer implementation")
     {
         Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
             "Failed to retrieve Rust based tokenizer".to_string(),
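---

Note (not part of the upstream diff): `decoded_token.finish_reason.into()` in
`post_process_decoded_token` relies on a conversion from the FFI-level
`crate::ffi::FinishReason` into `text_generation_router::FinishReason` that is
defined outside this excerpt. A minimal sketch of what such a conversion could
look like, assuming the cxx bridge mirrors TensorRT-LLM's executor finish
reasons; the variant names kNOT_FINISHED, kEND_ID, kSTOP_WORDS, and kLENGTH
are assumptions, not taken from this diff:

    use text_generation_router::FinishReason;

    // Hypothetical mapping from the cxx-bridged enum to the router's enum.
    // cxx shared enums are open (any underlying value is representable),
    // so the wildcard arm is mandatory.
    impl From<crate::ffi::FinishReason> for FinishReason {
        fn from(reason: crate::ffi::FinishReason) -> Self {
            match reason {
                crate::ffi::FinishReason::kEND_ID => FinishReason::EndOfSequenceToken,
                crate::ffi::FinishReason::kSTOP_WORDS => FinishReason::StopSequence,
                crate::ffi::FinishReason::kLENGTH => FinishReason::Length,
                // GeneratedText is presumably only built for final tokens, so
                // kNOT_FINISHED (or any unknown value) here would indicate a
                // backend bug.
                other => panic!("unexpected finish reason from executor: {other:?}"),
            }
        }
    }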