mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-07 00:10:17 +00:00
feat(trtllm): cache maxNumTokens to avoid calling JSON everytime
This commit is contained in:
parent
31747163e7
commit
e6da212431
@ -24,6 +24,10 @@ namespace huggingface::tgi::backends {
|
|||||||
using TokenId = tle::TokenIdType;
|
using TokenId = tle::TokenIdType;
|
||||||
|
|
||||||
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
|
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
|
||||||
|
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
|
||||||
|
"Submitting inference [{}] to the executor ({:d} already in-flight)");
|
||||||
|
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
|
||||||
|
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize all the components required by TRTLLM.
|
* Initialize all the components required by TRTLLM.
|
||||||
@ -50,12 +54,12 @@ namespace huggingface::tgi::backends {
|
|||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
tle::SamplingConfig GetSamplingConfig(
|
tle::SamplingConfig GetSamplingConfig(
|
||||||
const uint32_t topK,
|
uint32_t topK,
|
||||||
const float_t topP,
|
float_t topP,
|
||||||
const float_t temperature,
|
float_t temperature,
|
||||||
const float_t repetition_penalty,
|
float_t repetition_penalty,
|
||||||
const float_t frequency_penalty,
|
float_t frequency_penalty,
|
||||||
const uint64_t seed
|
uint64_t seed
|
||||||
) noexcept;
|
) noexcept;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -66,6 +70,9 @@ namespace huggingface::tgi::backends {
|
|||||||
const json config;
|
const json config;
|
||||||
tle::Executor executor;
|
tle::Executor executor;
|
||||||
|
|
||||||
|
/** Frequently accessed variables cached here **/
|
||||||
|
uint32_t maxNumTokens;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit TensorRtLlmBackend(
|
explicit TensorRtLlmBackend(
|
||||||
const std::filesystem::path &engineFolder,
|
const std::filesystem::path &engineFolder,
|
||||||
|
@ -75,6 +75,7 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
|||||||
const float_t repetition_penalty,
|
const float_t repetition_penalty,
|
||||||
const float_t frequency_penalty,
|
const float_t frequency_penalty,
|
||||||
const uint64_t seed) noexcept {
|
const uint64_t seed) noexcept {
|
||||||
|
|
||||||
return tle::SamplingConfig(
|
return tle::SamplingConfig(
|
||||||
1, // TGI only use a single beam
|
1, // TGI only use a single beam
|
||||||
topK,
|
topK,
|
||||||
@ -100,6 +101,9 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
|||||||
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||||
GetExecutorConfig(config, executorWorker.string())) {
|
GetExecutorConfig(config, executorWorker.string())) {
|
||||||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
||||||
|
|
||||||
|
// Cache variables
|
||||||
|
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||||
@ -113,29 +117,22 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
|||||||
const float_t frequency_penalty,
|
const float_t frequency_penalty,
|
||||||
const uint64_t seed
|
const uint64_t seed
|
||||||
) {
|
) {
|
||||||
|
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
SPDLOG_DEBUG(
|
{
|
||||||
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
const auto &iterations = executor.getLatestIterationStats();
|
||||||
fmt::join(tokens, ", "),
|
const auto &lastIteration = iterations.front();
|
||||||
executor.getLatestIterationStats().front().numActiveRequests
|
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
|
||||||
);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint64_t>();
|
|
||||||
const auto maxNewTokensChecked = static_cast<tle::SizeType32>(
|
|
||||||
std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size())));
|
|
||||||
|
|
||||||
#ifndef NDEBUG
|
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
SPDLOG_INFO(
|
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
|
||||||
FMT_STRING(
|
}
|
||||||
"Sampling config: topK={:d}, topP={:d}, temperature={:d}, repetition_penalty={:d}, frequency_penalty={:d}, seed={:d}"),
|
|
||||||
topK, topP, temperature, repetition_penalty, frequency_penalty, seed
|
|
||||||
)
|
|
||||||
SPDLOG_INFO(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
|
const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
|
||||||
|
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
|
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
|
||||||
|
Loading…
Reference in New Issue
Block a user