diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h index 935ab539..c25f06d5 100644 --- a/backends/trtllm/include/backend.h +++ b/backends/trtllm/include/backend.h @@ -42,6 +42,8 @@ namespace huggingface::tgi::backends { * @param topK * @param topP * @param temperature + * @param repetition_penalty + * @param frequency_penalty * @param seed * @return */ @@ -49,6 +51,8 @@ namespace huggingface::tgi::backends { uint32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed ); @@ -84,6 +88,8 @@ namespace huggingface::tgi::backends { * @param topK * @param topP * @param temperature + * @param repetition_penalty + * @param frequency_penalty * @param seed * @return Request id related to this generation for reference */ @@ -92,6 +98,8 @@ namespace huggingface::tgi::backends { int32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed ); diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h index 9895382f..63bb4f19 100644 --- a/backends/trtllm/include/ffi.h +++ b/backends/trtllm/include/ffi.h @@ -40,12 +40,15 @@ namespace huggingface::tgi::backends { * @param topK * @param topP * @param temperature + * @param repetition_penalty + * @param frequency_penalty * @param seed * @return */ [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] uint64_t - Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed); + Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); /*** * diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp index 2b552113..26728241 100644 --- a/backends/trtllm/lib/backend.cpp +++ b/backends/trtllm/lib/backend.cpp @@ -57,6 +57,8 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( uint32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed) { return tle::SamplingConfig( 1, // TGI only use a single beam @@ -66,9 +68,12 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( std::nullopt, std::nullopt, seed, - std::nullopt, temperature, - std::nullopt + temperature, + std::nullopt, + repetition_penalty, + std::nullopt, + frequency_penalty ); } @@ -99,6 +104,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( const int32_t topK, const float_t topP, const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, const uint64_t seed ) { #ifdef NDEBUG @@ -118,7 +125,7 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>(); const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size())); - const auto sampling = GetSamplingConfig(topK, topP, temperature, seed); + const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed); const auto output = tle::OutputConfig(true, false, false, true, false); return executor.enqueueRequest( tle::Request{tokens, maxNewTokens, true, sampling, output}); diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs index c874ca64..1f94fd3a 100644 --- a/backends/trtllm/src/backend.rs +++ b/backends/trtllm/src/backend.rs @@ -140,6 +140,8 @@ impl TensorRtLlmBackend { top_k: u32, top_p: f32, temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, seed: u64, ) { let tokenizer = Arc::clone(&self.tokenizer); @@ -174,10 +176,15 @@ impl TensorRtLlmBackend { .in_scope(|| async { debug!("Acquiring lock for submit"); let mut handle = executor.write().await; - let request_id = - handle - .pin_mut() - .submit(&tokens, top_k as i32, top_p, temperature, seed); + let request_id = handle.pin_mut().submit( + &tokens, + top_k as i32, + top_p, + temperature, + repetition_penalty, + frequency_penalty, + seed, + ); debug!("Releasing lock for submit"); request_id diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp index f3b68da8..a4433f2d 100644 --- a/backends/trtllm/src/ffi.cpp +++ b/backends/trtllm/src/ffi.cpp @@ -24,11 +24,13 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const { } uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( - rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) { + rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed) { // This will copy all the items from the initial slice std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end())); - return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed); + return TensorRtLlmBackend::Submit( + std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed); } size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens( diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 6506406d..d47a4b43 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -50,6 +50,8 @@ mod ffi { top_k: i32, top_p: f32, temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, seed: u64, ) -> u64;