mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 21:12:07 +00:00

feat(trtllm): add stop words handling

# Conflicts:
#	backends/trtllm/lib/backend.cpp

parent c1a43a6c3e
commit 421a17544e
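
Summary of the change: the TensorRT-LLM backend now caches a list of stop-word token sequences, discovered from the model's generation_config.json (one single-token sequence per declared EOS token id), and attaches it to every tle::Request via setStopWords before enqueueing the request on the executor. The by-value sampling parameters repetition_penalty and frequency_penalty are also renamed to repetitionPenalty and frequencyPenalty.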
backends/trtllm/include/backend.h

@@ -5,6 +5,7 @@
 #ifndef TGI_TRTLLM_BACKEND_H
 #define TGI_TRTLLM_BACKEND_H

+#include <array>
 #include <cmath>
 #include <filesystem>
 #include <span>
@@ -72,6 +73,7 @@ namespace huggingface::tgi::backends {

         /** Frequently accessed variables cached here **/
         uint32_t maxNumTokens;
+        std::list<std::vector<TokenId>> stopWords;

     public:
         explicit TensorRtLlmBackend(
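
Each stopWords entry is a full token sequence rather than a single id; this commit populates it with one single-token sequence per EOS token id. A minimal sketch of the resulting shape, with hypothetical token ids:

    // Hypothetical result of reading eos_token_ids = [2, 32000] from generation_config.json
    std::list<std::vector<TokenId>> stopWords = {{2}, {32000}};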
@@ -91,20 +93,20 @@ namespace huggingface::tgi::backends {
          * @param topK
          * @param topP
          * @param temperature
-         * @param repetition_penalty
-         * @param frequency_penalty
+         * @param repetitionPenalty
+         * @param frequencyPenalty
          * @param seed
          * @return Request id related to this generation for reference
          */
         [[nodiscard]] RequestId Submit(
                 const std::vector<TokenId> &tokens,
-                const uint32_t maxNewTokens,
-                const int32_t topK,
-                const float_t topP,
-                const float_t temperature,
-                const float_t repetition_penalty,
-                const float_t frequency_penalty,
-                const uint64_t seed
+                uint32_t maxNewTokens,
+                int32_t topK,
+                float_t topP,
+                float_t temperature,
+                float_t repetitionPenalty,
+                float_t frequencyPenalty,
+                uint64_t seed
         );

         [[nodiscard]] std::vector<tle::Response> PullNewTokens();
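
The declaration hunk above is a pure rename: the by-value sampling parameters move from snake_case to camelCase and drop their top-level const, which does not change the function's signature or behavior.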
backends/trtllm/lib/backend.cpp

@@ -104,6 +104,24 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(

     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
+
+    // Attempt to discover stopWords from the generation_config.json
+    if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
+        const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
+        if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
+            SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
+            stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());
+
+            std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(),
+                           [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
+                               const auto tokenId = tokenIdObj.template get<tle::TokenIdType>();
+                               return {tokenId};
+                           });
+        }
+    } else {
+        SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
+        stopWords = {};
+    }
 }

 [[nodiscard("Returned number of requests needs to be consumed")]]
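
The constructor expects a generation_config.json alongside the compiled engines, with an eos_token_ids array; each id becomes a one-token stop sequence. A minimal standalone sketch of the same discovery logic, assuming nlohmann::json is available and using a hypothetical helper name and token ids:

    #include <cstdint>
    #include <fstream>
    #include <list>
    #include <string>
    #include <vector>
    #include <nlohmann/json.hpp>

    using TokenId = std::int32_t;

    // Read eos_token_ids from a generation_config.json and build one
    // single-token stop sequence per id (empty list if the file is missing).
    std::list<std::vector<TokenId>> ReadStopWords(const std::string &path) {
        std::list<std::vector<TokenId>> stopWords;
        std::ifstream file(path);
        if (!file) return stopWords;

        const auto config = nlohmann::json::parse(file);
        if (const auto it = config.find("eos_token_ids"); it != config.end() && it->is_array()) {
            for (const auto &tokenId : *it) {
                stopWords.push_back({tokenId.get<TokenId>()});
            }
        }
        return stopWords;
    }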
@@ -124,8 +142,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
-        const float_t repetition_penalty,
-        const float_t frequency_penalty,
+        const float_t repetitionPenalty,
+        const float_t frequencyPenalty,
         const uint64_t seed
 ) {
     const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
@@ -135,14 +153,20 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     const auto &lastIteration = iterations.front();

     SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
-    SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+    SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
     SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
 }
 #endif

-    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
     const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
-    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
+
+    // Build the request
+    auto request = tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG};
+    request.setStopWords(stopWords);
+
+    // Submit to the executor for batching
+    return executor.enqueueRequest(request);
 }

 std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
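
Attaching stopWords to every request means the executor can halt generation as soon as any of the cached stop sequences is produced, so models that declare several EOS token ids in generation_config.json stop on whichever appears first rather than relying on a single default EOS token.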