diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h index 7990e76b..3f89677c 100644 --- a/backends/trtllm/include/backend.h +++ b/backends/trtllm/include/backend.h @@ -48,12 +48,12 @@ namespace huggingface::tgi::backends { * @return */ tle::SamplingConfig GetSamplingConfig( - uint32_t topK, - float_t topP, - float_t temperature, - float_t repetition_penalty, - float_t frequency_penalty, - uint64_t seed + const uint32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed ); /** @@ -94,13 +94,14 @@ namespace huggingface::tgi::backends { * @return Request id related to this generation for reference */ [[nodiscard]] RequestId Submit( - const std::vector &tokens, - int32_t topK, - float_t topP, - float_t temperature, - float_t repetition_penalty, - float_t frequency_penalty, - uint64_t seed + const std::vector &tokens, + const uint32_t maxNewTokens, + const int32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed ); /** @@ -108,7 +109,7 @@ namespace huggingface::tgi::backends { * @param requestId The request id to poll the generation results * @return */ - std::vector Poll(RequestId requestId); + std::vector Poll(RequestId requestId); /** * Stop the underlying executor diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp index c066a6d6..788b7674 100644 --- a/backends/trtllm/lib/backend.cpp +++ b/backends/trtllm/lib/backend.cpp @@ -55,12 +55,12 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co } tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( - uint32_t topK, - float_t topP, - float_t temperature, - float_t repetition_penalty, - float_t frequency_penalty, - uint64_t seed) { + const uint32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed) { return tle::SamplingConfig( 1, // TGI only use a single beam topK,