feat(trtllm): cache maxNumTokens to avoid calling JSON every time

Morgan Funtowicz 2024-10-21 14:51:58 +02:00
parent 31747163e7
commit e6da212431
2 changed files with 27 additions and 23 deletions
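
In short: Submit() previously re-read /build_config/max_num_tokens out of the engine's JSON config on every request; this commit reads it once in the constructor and stores it in a new maxNumTokens member. A minimal standalone sketch of that pattern, assuming nlohmann::json as in the backend (Backend and ClampMaxNewTokens are illustrative names, not the actual TGI types):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <nlohmann/json.hpp>

using json = nlohmann::json;
using namespace nlohmann::literals;  // enables the ""_json_pointer literal

class Backend {
    json config;
    uint32_t maxNumTokens;  // cached once; no JSON traversal on the hot path

public:
    explicit Backend(json cfg)
            : config(std::move(cfg)),
              maxNumTokens(config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>()) {}

    // Per-request path now reads a plain integer member instead of the JSON document.
    uint32_t ClampMaxNewTokens(uint32_t maxNewTokens, std::size_t promptLength) const {
        return std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - promptLength));
    }
};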


@@ -24,6 +24,10 @@ namespace huggingface::tgi::backends {
     using TokenId = tle::TokenIdType;
 
     const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
+    constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
+            "Submitting inference [{}] to the executor ({:d} already in-flight)");
+    constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
+            "Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
 
     /**
      * Initialize all the components required by TRTLLM.
@@ -50,12 +54,12 @@ namespace huggingface::tgi::backends {
      * @return
      */
     tle::SamplingConfig GetSamplingConfig(
-            const uint32_t topK,
-            const float_t topP,
-            const float_t temperature,
-            const float_t repetition_penalty,
-            const float_t frequency_penalty,
-            const uint64_t seed
+            uint32_t topK,
+            float_t topP,
+            float_t temperature,
+            float_t repetition_penalty,
+            float_t frequency_penalty,
+            uint64_t seed
     ) noexcept;
 
     /**
@@ -66,6 +70,9 @@ namespace huggingface::tgi::backends {
         const json config;
         tle::Executor executor;
 
+        /** Frequently accessed variables cached here **/
+        uint32_t maxNumTokens;
+
    public:
        explicit TensorRtLlmBackend(
                const std::filesystem::path &engineFolder,

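The header-side signature change above is purely cosmetic: top-level const on a by-value parameter is not part of a function's type, so the declaration can drop it while the definition keeps it. A small illustration of that C++ rule:

void f(const int x);  // these two lines declare
void f(int x);        // exactly the same function

// A definition may still qualify the parameter to keep it read-only
// inside the body; callers and overload resolution are unaffected.
void f(const int x) { /* x cannot be reassigned here */ }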

@@ -75,6 +75,7 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
         const float_t repetition_penalty,
         const float_t frequency_penalty,
         const uint64_t seed) noexcept {
+
     return tle::SamplingConfig(
             1, // TGI only use a single beam
             topK,
@@ -100,6 +101,9 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
                  GetExecutorConfig(config, executorWorker.string())) {
     SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+
+    // Cache variables
+    maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
 }
 
 [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
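The Submit() hunk below also switches its logging to the FMT_EXECUTOR_STATS and FMT_SAMPLING_CONFIG constants hoisted into the header above. A self-contained sketch of that fmt idiom, a compile-time checked format string stored in a constexpr constant (FMT_EXAMPLE is an illustrative name):

#include <cstdint>
#include <fmt/format.h>

// FMT_STRING validates the format against the argument types at compile time.
constexpr auto FMT_EXAMPLE = FMT_STRING("Asking for max_new_tokens={:d}");

int main() {
    const uint32_t maxNewTokensChecked = 512;
    fmt::print("{}\n", fmt::format(FMT_EXAMPLE, maxNewTokensChecked));  // Asking for max_new_tokens=512
}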
@@ -113,29 +117,22 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const float_t frequency_penalty,
         const uint64_t seed
 ) {
+    const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
+
 #ifndef NDEBUG
-    SPDLOG_DEBUG(
-            FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
-            fmt::join(tokens, ", "),
-            executor.getLatestIterationStats().front().numActiveRequests
-    );
-#endif
-
-    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint64_t>();
-    const auto maxNewTokensChecked = static_cast<tle::SizeType32>(
-            std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size())));
-
-#ifndef NDEBUG
-    SPDLOG_INFO(
-            FMT_STRING(
-                    "Sampling config: topK={:d}, topP={:d}, temperature={:d}, repetition_penalty={:d}, frequency_penalty={:d}, seed={:d}"),
-            topK, topP, temperature, repetition_penalty, frequency_penalty, seed
-    )
-    SPDLOG_INFO(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
+    {
+        const auto &iterations = executor.getLatestIterationStats();
+        const auto &lastIteration = iterations.front();
+
+        SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
+        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+        SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
+    }
 #endif
 
     const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
-    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
+    const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
+    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
 }
 
 std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
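
For intuition on the clamp Submit() now applies, a sketch with hypothetical numbers (an engine built with max_num_tokens=4096 and a 1000-token prompt; values are illustrative, not from the commit):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const uint32_t maxNumTokens = 4096;          // cached from /build_config/max_num_tokens
    const std::vector<int32_t> tokens(1000, 1);  // stand-in prompt
    const uint32_t maxNewTokens = 8192;          // caller over-asks

    // Generation may not push prompt + new tokens past the engine budget.
    // (Assumes tokens.size() <= maxNumTokens; the subtraction is unsigned.)
    const auto maxNewTokensChecked =
            std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
    assert(maxNewTokensChecked == 3096);  // 4096 - 1000
    return 0;
}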