feat(trtllm): add stop words handling

# Conflicts:
#	backends/trtllm/lib/backend.cpp
Morgan Funtowicz 2024-10-21 17:00:45 +02:00
parent c1a43a6c3e
commit 421a17544e
2 changed files with 40 additions and 14 deletions
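
For reference, here is a minimal standalone sketch (not part of this commit) of the stop-word discovery the diff below implements: it reads eos_token_ids from a generation_config.json with nlohmann::json and reshapes them into a list of single-token stop sequences, the same std::list-of-token-id-vectors shape the backend later hands to tle::Request::setStopWords. The relative file path, the eos_token_ids key taken from the diff, and the use of std::int32_t in place of tle::TokenIdType are illustrative assumptions, not the committed code.

#include <cstdint>
#include <fstream>
#include <iostream>
#include <list>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Hypothetical location; the backend looks for <enginesFolder>/generation_config.json.
    std::ifstream file("generation_config.json");
    if (!file) {
        std::cout << "generation_config.json doesn't exist, no stop words\n";
        return 0;
    }

    const auto config = json::parse(file);

    // One single-token stop sequence per EOS token id, mirroring the
    // std::list<std::vector<TokenId>> member added by this commit.
    std::list<std::vector<std::int32_t>> stopWords;
    if (config.contains("eos_token_ids") && config["eos_token_ids"].is_array()) {
        for (const auto &tokenId : config["eos_token_ids"]) {
            stopWords.push_back({tokenId.get<std::int32_t>()});
        }
    }

    std::cout << "Found " << stopWords.size() << " EOS stop sequences\n";
    return 0;
}

In the backend itself, the second file's diff shows this list being attached to every tle::Request via setStopWords before the request is enqueued on the executor.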

View File

@@ -5,6 +5,7 @@
 #ifndef TGI_TRTLLM_BACKEND_H
 #define TGI_TRTLLM_BACKEND_H
 
+#include <array>
 #include <cmath>
 #include <filesystem>
 #include <span>
@@ -72,6 +73,7 @@ namespace huggingface::tgi::backends {
         /** Frequently accessed variables cached here **/
         uint32_t maxNumTokens;
+        std::list<std::vector<TokenId>> stopWords;
 
     public:
         explicit TensorRtLlmBackend(
@ -91,20 +93,20 @@ namespace huggingface::tgi::backends {
* @param topK * @param topK
* @param topP * @param topP
* @param temperature * @param temperature
* @param repetition_penalty * @param repetitionPenalty
* @param frequency_penalty * @param frequencyPenalty
* @param seed * @param seed
* @return Request id related to this generation for reference * @return Request id related to this generation for reference
*/ */
[[nodiscard]] RequestId Submit( [[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens, const std::vector<TokenId> &tokens,
const uint32_t maxNewTokens, uint32_t maxNewTokens,
const int32_t topK, int32_t topK,
const float_t topP, float_t topP,
const float_t temperature, float_t temperature,
const float_t repetition_penalty, float_t repetitionPenalty,
const float_t frequency_penalty, float_t frequencyPenalty,
const uint64_t seed uint64_t seed
); );
[[nodiscard]] std::vector<tle::Response> PullNewTokens(); [[nodiscard]] std::vector<tle::Response> PullNewTokens();

View File

@@ -104,6 +104,24 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
 
+    // Attempt to discover stopWords from the generation_config.json
+    if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
+        const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
+        if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
+            SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
+            stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());
+
+            std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(),
+                           [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
+                               const auto tokenId = tokenIdObj.template get<tle::TokenIdType>();
+                               return {tokenId};
+                           });
+        }
+    } else {
+        SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
+        stopWords = {};
+    }
 }
 
 [[nodiscard("Returned number of requests needs to be consumed")]]
@@ -124,8 +142,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
-        const float_t repetition_penalty,
-        const float_t frequency_penalty,
+        const float_t repetitionPenalty,
+        const float_t frequencyPenalty,
         const uint64_t seed
 ) {
     const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
@@ -135,14 +153,20 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const auto &lastIteration = iterations.front();
 
         SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
-        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
         SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
     }
 #endif
 
-    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
     const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
-    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
+
+    // Build the request
+    auto request = tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG};
+    request.setStopWords(stopWords);
+
+    // Submit to the executor for batching
+    return executor.enqueueRequest(request);
 }
 
 std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {