mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
feat(trtllm): cache maxNumTokens to avoid calling JSON everytime
This commit is contained in:
parent
31747163e7
commit
e6da212431
@ -24,6 +24,10 @@ namespace huggingface::tgi::backends {
|
||||
using TokenId = tle::TokenIdType;
|
||||
|
||||
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
|
||||
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
|
||||
"Submitting inference [{}] to the executor ({:d} already in-flight)");
|
||||
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
|
||||
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
|
||||
|
||||
/**
|
||||
* Initialize all the components required by TRTLLM.
|
||||
@ -50,12 +54,12 @@ namespace huggingface::tgi::backends {
|
||||
* @return
|
||||
*/
|
||||
tle::SamplingConfig GetSamplingConfig(
|
||||
const uint32_t topK,
|
||||
const float_t topP,
|
||||
const float_t temperature,
|
||||
const float_t repetition_penalty,
|
||||
const float_t frequency_penalty,
|
||||
const uint64_t seed
|
||||
uint32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed
|
||||
) noexcept;
|
||||
|
||||
/**
|
||||
@ -66,6 +70,9 @@ namespace huggingface::tgi::backends {
|
||||
const json config;
|
||||
tle::Executor executor;
|
||||
|
||||
/** Frequently accessed variables cached here **/
|
||||
uint32_t maxNumTokens;
|
||||
|
||||
public:
|
||||
explicit TensorRtLlmBackend(
|
||||
const std::filesystem::path &engineFolder,
|
||||
|
@ -75,6 +75,7 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||
const float_t repetition_penalty,
|
||||
const float_t frequency_penalty,
|
||||
const uint64_t seed) noexcept {
|
||||
|
||||
return tle::SamplingConfig(
|
||||
1, // TGI only use a single beam
|
||||
topK,
|
||||
@ -100,6 +101,9 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||
GetExecutorConfig(config, executorWorker.string())) {
|
||||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
||||
|
||||
// Cache variables
|
||||
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
|
||||
}
|
||||
|
||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||
@ -113,29 +117,22 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||
const float_t frequency_penalty,
|
||||
const uint64_t seed
|
||||
) {
|
||||
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
|
||||
#ifndef NDEBUG
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
||||
fmt::join(tokens, ", "),
|
||||
executor.getLatestIterationStats().front().numActiveRequests
|
||||
);
|
||||
#endif
|
||||
{
|
||||
const auto &iterations = executor.getLatestIterationStats();
|
||||
const auto &lastIteration = iterations.front();
|
||||
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
|
||||
|
||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint64_t>();
|
||||
const auto maxNewTokensChecked = static_cast<tle::SizeType32>(
|
||||
std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size())));
|
||||
|
||||
#ifndef NDEBUG
|
||||
SPDLOG_INFO(
|
||||
FMT_STRING(
|
||||
"Sampling config: topK={:d}, topP={:d}, temperature={:d}, repetition_penalty={:d}, frequency_penalty={:d}, seed={:d}"),
|
||||
topK, topP, temperature, repetition_penalty, frequency_penalty, seed
|
||||
)
|
||||
SPDLOG_INFO(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
|
||||
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
|
||||
}
|
||||
#endif
|
||||
|
||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
|
||||
const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
|
||||
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
|
||||
}
|
||||
|
||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
|
||||
|
Loading…
Reference in New Issue
Block a user