mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 20:42:06 +00:00
chore(trtllm): define a macro for SizeType cast
This commit is contained in:
parent
7217cafadb
commit
d5c8bdc53b
@ -20,6 +20,9 @@
|
|||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
namespace tle = tensorrt_llm::executor;
|
namespace tle = tensorrt_llm::executor;
|
||||||
|
|
||||||
|
|
||||||
|
#define CAST_SIZETYPE(x) static_cast<tle::SizeType32>(x)
|
||||||
|
|
||||||
namespace huggingface::tgi::backends {
|
namespace huggingface::tgi::backends {
|
||||||
using RequestId = tle::IdType;
|
using RequestId = tle::IdType;
|
||||||
using TokenId = tle::TokenIdType;
|
using TokenId = tle::TokenIdType;
|
||||||
|
@ -164,10 +164,9 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
|
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
|
||||||
const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
|
|
||||||
|
|
||||||
// Build the request
|
// Build the request
|
||||||
auto request = tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG};
|
auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
|
||||||
request.setStopWords(stopWords);
|
request.setStopWords(stopWords);
|
||||||
|
|
||||||
// Submit to the executor for batching
|
// Submit to the executor for batching
|
||||||
|
Loading…
Reference in New Issue
Block a user