Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-10 19:32:06 +00:00)

Commit 38b5263c61 (parent f6f689f509): (ffi) add max_new_tokens parameters
@@ -94,7 +94,7 @@ namespace huggingface::tgi::backends {
         * @return Request id related to this generation for reference
         */
        [[nodiscard]] RequestId Submit(
-               const std::vector <TokenId> &tokens,
+               const std::vector<TokenId> &tokens,
                const uint32_t maxNewTokens,
                const int32_t topK,
                const float_t topP,
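For orientation, here is a minimal call-site sketch against the declaration above. It assumes the parameter list continues with temperature, repetition_penalty, frequency_penalty and seed, matching the implementation hunks further down; the header name, prompt tokens and sampling values are illustrative and not part of this commit.

    #include <cstdint>
    #include <vector>
    #include "backend.h"  // assumed: project header declaring TensorRtLlmBackend and the TokenId/RequestId aliases

    // Hypothetical helper showing how a caller would cap generation length per request.
    huggingface::tgi::backends::RequestId EnqueueExample(
            huggingface::tgi::backends::TensorRtLlmBackend &backend,
            const std::vector<huggingface::tgi::backends::TokenId> &prompt) {
        return backend.Submit(
                prompt,
                128,        // maxNewTokens: upper bound on tokens generated for this request
                50,         // topK
                0.95f,      // topP
                0.8f,       // temperature
                1.0f,       // repetition_penalty
                0.0f,       // frequency_penalty
                2024);      // seed
    }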
@@ -37,6 +37,7 @@ namespace huggingface::tgi::backends {
        /***
         *
         * @param tokens
+        * @param maxNewTokens
         * @param topK
         * @param topP
         * @param temperature
@@ -47,7 +48,8 @@ namespace huggingface::tgi::backends {
         */
        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
        uint64_t
-       Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
+       Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
+              int32_t topK, float_t topP, float_t temperature,
               float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);

        /***
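The @param tags in the comment block above are left empty in the source. A hedged sketch of how the documented declaration could read once filled in; the descriptions are my own wording, not taken from the repository:

    /***
     * Submit a new generation request to the executor.
     * @param tokens Prompt token ids, already produced by the tokenizer
     * @param maxNewTokens Upper bound on the number of tokens to generate for this request
     * @param topK Top-k sampling cutoff
     * @param topP Top-p (nucleus) sampling cutoff
     * @param temperature Sampling temperature
     * @param repetition_penalty Penalty applied to already generated tokens
     * @param frequency_penalty Penalty proportional to token frequency
     * @param seed Seed for the sampler, for reproducibility
     * @return Request id related to this generation for reference
     */
    [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
    uint64_t
    Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
           int32_t topK, float_t topP, float_t temperature,
           float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);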
@@ -103,6 +103,7 @@ size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const
 [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
 tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const std::vector<tle::TokenIdType> &tokens,
+        const uint32_t maxNewTokens,
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
@@ -124,19 +125,12 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     );
 #endif

-    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
-    const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
+    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint64_t>();
+    const auto maxNewTokensChecked = static_cast<tle::SizeType32>(
+            std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size())));

     const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
-    const auto output = tle::OutputConfig(true, false, false, true, false);
-    return executor.enqueueRequest(
-            tle::Request{tokens, maxNewTokens, true, sampling, output});
-}
-
-[[nodiscard("Generated tokens result must be used")]]
-std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
-    SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
-    return executor.awaitResponses(requestId);
+    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
 }

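The substance of this hunk is the clamp: the caller-supplied maxNewTokens is now honoured but capped by the engine budget (max_num_tokens minus the prompt length), whereas the old code always granted the full remaining budget. A self-contained sketch of that arithmetic; the engine limit, prompt length and requested values are made up for illustration:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors the clamping performed in Submit(): never generate more tokens than
    // the engine budget leaves after the prompt, and never more than requested.
    // Like the changed code, this assumes promptLen <= maxNumTokens.
    static uint32_t ClampMaxNewTokens(uint32_t requested, uint64_t maxNumTokens, size_t promptLen) {
        return std::min(requested, static_cast<uint32_t>(maxNumTokens - promptLen));
    }

    int main() {
        const uint64_t maxNumTokens = 1024;          // e.g. build_config/max_num_tokens from the engine config
        const std::vector<int32_t> prompt(200, 1);   // a 200-token prompt (dummy ids)

        // Request fits into the budget: honoured as-is, prints 128.
        std::cout << ClampMaxNewTokens(128, maxNumTokens, prompt.size()) << "\n";
        // Request exceeds the budget: capped at 1024 - 200, prints 824.
        std::cout << ClampMaxNewTokens(4096, maxNumTokens, prompt.size()) << "\n";
        return 0;
    }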
@@ -25,8 +25,9 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
 }

 uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
-        rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
-        float_t frequency_penalty, uint64_t seed) {
+        rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
+        int32_t topK, float_t topP, float_t temperature,
+        float_t repetition_penalty, float_t frequency_penalty, uint64_t seed) {

     // This will copy all the items from the initial slice
     std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
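The rest of this wrapper's body is not shown in the hunk. A hedged sketch of the forwarding it presumably performs, assuming the impl can reach the C++ TensorRtLlmBackend::Submit from the earlier hunks (that relationship is not visible in this diff):

    // Sketch only: hand the copied prompt and the new cap over to the backend.
    // The qualified call assumes the impl derives from (or otherwise owns) TensorRtLlmBackend.
    return TensorRtLlmBackend::Submit(
            tokens_, maxNewTokens,
            topK, topP, temperature, repetition_penalty, frequency_penalty, seed);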
@@ -56,6 +56,7 @@ mod ffi {
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             tokens: &[u32],
+            max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
             temperature: f32,