(ffi) use const for GetSamplingConfig

This commit is contained in:
Morgan Funtowicz 2024-08-01 07:49:37 +00:00 committed by Morgan Funtowicz
parent cea64e234f
commit 0cd7538a48
2 changed files with 21 additions and 20 deletions

View File

@ -48,12 +48,12 @@ namespace huggingface::tgi::backends {
* @return * @return
*/ */
tle::SamplingConfig GetSamplingConfig( tle::SamplingConfig GetSamplingConfig(
uint32_t topK, const uint32_t topK,
float_t topP, const float_t topP,
float_t temperature, const float_t temperature,
float_t repetition_penalty, const float_t repetition_penalty,
float_t frequency_penalty, const float_t frequency_penalty,
uint64_t seed const uint64_t seed
); );
/** /**
@ -94,13 +94,14 @@ namespace huggingface::tgi::backends {
* @return Request id related to this generation for reference * @return Request id related to this generation for reference
*/ */
[[nodiscard]] RequestId Submit( [[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens, const std::vector <TokenId> &tokens,
int32_t topK, const uint32_t maxNewTokens,
float_t topP, const int32_t topK,
float_t temperature, const float_t topP,
float_t repetition_penalty, const float_t temperature,
float_t frequency_penalty, const float_t repetition_penalty,
uint64_t seed const float_t frequency_penalty,
const uint64_t seed
); );
/** /**
@ -108,7 +109,7 @@ namespace huggingface::tgi::backends {
* @param requestId The request id to poll the generation results * @param requestId The request id to poll the generation results
* @return * @return
*/ */
std::vector<tle::Response> Poll(RequestId requestId); std::vector <tle::Response> Poll(RequestId requestId);
/** /**
* Stop the underlying executor * Stop the underlying executor

View File

@ -55,12 +55,12 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
} }
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
uint32_t topK, const uint32_t topK,
float_t topP, const float_t topP,
float_t temperature, const float_t temperature,
float_t repetition_penalty, const float_t repetition_penalty,
float_t frequency_penalty, const float_t frequency_penalty,
uint64_t seed) { const uint64_t seed) {
return tle::SamplingConfig( return tle::SamplingConfig(
1, // TGI only uses a single beam 1, // TGI only uses a single beam
topK, topK,