From 38b5263c61898b02949c25e24d0a77fc6948c3e0 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Fri, 2 Aug 2024 22:11:41 +0000
Subject: [PATCH] (ffi) add max_new_tokens parameters

---
 backends/trtllm/include/backend.h |  2 +-
 backends/trtllm/include/ffi.h     |  4 +++-
 backends/trtllm/lib/backend.cpp   | 16 +++++-----------
 backends/trtllm/src/ffi.cpp       |  5 +++--
 backends/trtllm/src/lib.rs        |  1 +
 5 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index 3f89677c..bb31daa9 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -94,7 +94,7 @@ namespace huggingface::tgi::backends {
          * @return Request id related to this generation for reference
          */
         [[nodiscard]] RequestId Submit(
-                const std::vector<TokenId> &tokens,
+                const std::vector<TokenId> &tokens, const uint32_t maxNewTokens,
                 const int32_t topK,
                 const float_t topP,
                 const float_t temperature,
diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
index fe0be9fc..df296918 100644
--- a/backends/trtllm/include/ffi.h
+++ b/backends/trtllm/include/ffi.h
@@ -37,6 +37,7 @@ namespace huggingface::tgi::backends {
         /***
          *
          * @param tokens
+         * @param maxNewTokens
          * @param topK
          * @param topP
          * @param temperature
@@ -47,7 +48,8 @@ namespace huggingface::tgi::backends {
          */
         [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
         uint64_t
-        Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
+        Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
+               int32_t topK, float_t topP, float_t temperature,
                float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);

         /***
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 788b7674..dc9ffdaa 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -103,6 +103,7 @@ size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const
 [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
 tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const std::vector<tle::TokenIdType> &tokens,
+        const uint32_t maxNewTokens,
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
@@ -124,19 +125,12 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     );
 #endif

-    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
-    const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
+    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
+    const auto maxNewTokensChecked = static_cast<tle::SizeType32>(
+            std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size())));

     const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
-    const auto output = tle::OutputConfig(true, false, false, true, false);
-    return executor.enqueueRequest(
-            tle::Request{tokens, maxNewTokens, true, sampling, output});
-}
-
-[[nodiscard("Generated tokens result must be used")]]
-std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
-    SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
-    return executor.awaitResponses(requestId);
+    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
 }
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index d6317a68..beca88ad 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -25,8 +25,9 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
 }

 uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
-        rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
-        float_t frequency_penalty, uint64_t seed) {
+        rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
+        int32_t topK, float_t topP, float_t temperature,
+        float_t repetition_penalty, float_t frequency_penalty, uint64_t seed) {

     // This will copy all the items from the initial slice
     std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index 1a804f88..5253096c 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -56,6 +56,7 @@ mod ffi {
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             tokens: &[u32],
+            max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
             temperature: f32,
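
Review note (not part of the patch): the C++ side no longer derives the
generation budget from the engine config alone; it now clamps the
caller-supplied value, i.e. maxNewTokensChecked =
min(maxNewTokens, max_num_tokens - prompt_length). For example, with
max_num_tokens = 4096 and a 4000-token prompt, a request for 512 new
tokens is clamped to 96.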
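As a usage illustration only, here is a minimal sketch of how a Rust-side
caller might pass the new argument through the cxx bridge. The `backend`
handle, the token buffer, and the sampling locals are assumed context
(hypothetical names); only `Submit` and its parameter list come from this
patch:

    // `backend` is assumed to be a cxx::UniquePtr<TensorRtLlmBackendImpl>
    // returned by the bridge's constructor function; pin_mut() yields the
    // Pin<&mut _> receiver the bridged Submit expects.
    let request_id: u64 = backend.pin_mut().Submit(
        &tokens,            // prompt token ids (&[u32])
        max_new_tokens,     // new parameter: per-request cap, clamped C++-side
        top_k,              // i32
        top_p,              // f32
        temperature,        // f32
        repetition_penalty, // f32
        frequency_penalty,  // f32
        seed,               // u64
    );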