From 421a17544ee6854411ad2950d0bcf311826da2a1 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 21 Oct 2024 17:00:45 +0200
Subject: [PATCH] feat(trtllm): add stop words handling

# Conflicts:
#	backends/trtllm/lib/backend.cpp
---
 backends/trtllm/include/backend.h | 20 ++++++++++--------
 backends/trtllm/lib/backend.cpp   | 34 ++++++++++++++++++++++++++-----
 2 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index 5b2963a8..1793b2dd 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -5,6 +5,7 @@
 #ifndef TGI_TRTLLM_BACKEND_H
 #define TGI_TRTLLM_BACKEND_H
 
+#include
 #include
 #include
 #include
@@ -72,6 +73,7 @@ namespace huggingface::tgi::backends {
 
         /** Frequently accessed variables cached here **/
         uint32_t maxNumTokens;
+        std::list> stopWords;
 
     public:
         explicit TensorRtLlmBackend(
@@ -91,20 +93,20 @@ namespace huggingface::tgi::backends {
          * @param topK
          * @param topP
          * @param temperature
-         * @param repetition_penalty
-         * @param frequency_penalty
+         * @param repetitionPenalty
+         * @param frequencyPenalty
          * @param seed
          * @return Request id related to this generation for reference
          */
         [[nodiscard]] RequestId Submit(
                 const std::vector &tokens,
-                const uint32_t maxNewTokens,
-                const int32_t topK,
-                const float_t topP,
-                const float_t temperature,
-                const float_t repetition_penalty,
-                const float_t frequency_penalty,
-                const uint64_t seed
+                uint32_t maxNewTokens,
+                int32_t topK,
+                float_t topP,
+                float_t temperature,
+                float_t repetitionPenalty,
+                float_t frequencyPenalty,
+                uint64_t seed
         );
 
         [[nodiscard]] std::vector PullNewTokens();
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 72a75e2a..2750b423 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -104,6 +104,24 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
 
     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get();
+
+    // Attempt to discover stopWords from the generation_config.json
+    if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
+        const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
+        if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
+            SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
+            stopWords = std::list(eosTokenIds.size());
+
+            std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(),
+                           [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
+                               const auto tokenId = tokenIdObj.template get();
+                               return {tokenId};
+                           });
+        }
+    } else {
+        SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
+        stopWords = {};
+    }
 }
 
 [[nodiscard("Returned number of requests needs to be consumed")]]
@@ -124,8 +142,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
-        const float_t repetition_penalty,
-        const float_t frequency_penalty,
+        const float_t repetitionPenalty,
+        const float_t frequencyPenalty,
         const uint64_t seed
 ) {
     const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast(maxNumTokens - tokens.size()));
@@ -135,14 +153,20 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const auto &lastIteration = iterations.front();
         SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "),
                      lastIteration.numActiveRequests);
-        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
         SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
     }
 #endif
 
-    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
     const auto maxNewTokensChecked_ = static_cast(maxNewTokensChecked);
-    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
+
+    // Build the request
+    auto request = tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG};
+    request.setStopWords(stopWords);
+
+    // Submit to the executor for batching
+    return executor.enqueueRequest(request);
 }
 
 std::vector huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
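
Note (illustration, not part of the patch): the constructor hunk above discovers per-model EOS token ids from generation_config.json and caches each one as a single-token stop word; Submit() then attaches the cached list to every request via request.setStopWords(stopWords) before executor.enqueueRequest, so the executor itself stops decoding when one of those ids is produced. Below is a minimal standalone sketch of the discovery step using nlohmann::json. The TokenId alias, the StopWords element type, and the flat eos_token_ids layout are assumptions made for the example, since the template arguments are elided in the quoted diff.

// Standalone sketch of the stop-words discovery above.
// Assumptions: TokenId stands in for the backend's token id type and
// generation_config.json contains e.g. {"eos_token_ids": [128001, 128009]}.
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <list>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::json;
using TokenId = int32_t;                              // assumed token id type
using StopWords = std::list<std::vector<TokenId>>;    // assumed element type

StopWords DiscoverStopWords(const std::filesystem::path &enginesFolder) {
    StopWords stopWords;
    const auto generationConfigPath = enginesFolder / "generation_config.json";
    if (!std::filesystem::exists(generationConfigPath))
        return stopWords;                             // empty list: no extra stop words

    std::ifstream stream(generationConfigPath);
    const auto generationConfig = json::parse(stream);

    // Each EOS token becomes its own single-token stop word, mirroring the
    // {tokenId} brace-initialization in the patched constructor.
    if (const auto it = generationConfig.find("eos_token_ids");
        it != generationConfig.end() && it->is_array()) {
        for (const auto &tokenId : *it)
            stopWords.push_back({tokenId.get<TokenId>()});
    }
    return stopWords;
}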