From f4a74be384725980b80dc6c1da797c1ee9ba820c Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Fri, 2 Aug 2024 22:16:28 +0000
Subject: [PATCH] (backend) expose PullNewTokens

---
 backends/trtllm/include/backend.h | 20 +---------
 backends/trtllm/include/ffi.h     | 11 ++----
 backends/trtllm/lib/backend.cpp   | 17 ++------
 backends/trtllm/src/ffi.cpp       | 65 ++++++++++++++-----------------
 backends/trtllm/src/lib.rs        | 28 ++++---------
 5 files changed, 47 insertions(+), 94 deletions(-)

diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index 9fda8f87b..83e862c55 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -56,7 +56,7 @@ namespace huggingface::tgi::backends {
             const float_t repetition_penalty,
             const float_t frequency_penalty,
             const uint64_t seed
-    );
+    ) noexcept;
 
     /**
      *
@@ -72,12 +72,6 @@ namespace huggingface::tgi::backends {
                 const std::filesystem::path &executorWorker
         );
 
-        /**
-         * Indicate if the backend is ready to accept incoming request
-         * @return true if ready, false otherwise
-         */
-        [[nodiscard]] bool IsReady() const;
-
        /**
         * Query the executor for the number of token available for pulling
         * @return
@@ -106,17 +100,7 @@ namespace huggingface::tgi::backends {
                 const uint64_t seed
         );
 
-        /**
-         *
-         * @param requestId The request id to poll the generation results
-         * @return
-         */
-        std::vector<tle::Response> Poll(RequestId requestId);
-
-        /**
-         * Stop the underlying executor
-         */
-        void Shutdown();
+        [[nodiscard]] std::vector<tle::Response> PullNewTokens();
     };
 }
diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
index df2969185..6127d29ac 100644
--- a/backends/trtllm/include/ffi.h
+++ b/backends/trtllm/include/ffi.h
@@ -54,18 +54,13 @@ namespace huggingface::tgi::backends {
 
         /***
          *
-         * @param requestId
-         * @param ctx
-         * @param callback
          * @return
          */
-        size_t StreamTokens(
-                const RequestId requestId,
-                huggingface::tgi::backends::GenerationContext *ctx,
-                rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-                              huggingface::tgi::backends::GenerationStep)> callback);
+        std::unique_ptr<std::vector<GenerationStep>> PullTokens();
     };
 
+    GenerationStep ConvertResponseToGenerationStep(const tle::Response &response);
+
     /***
      *
      * @param engineFolder
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 2eca477f5..9c9c5dff5 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -84,18 +84,11 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         const std::filesystem::path &executorWorker
 ) :
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
-        executor(
-                enginesFolder,
-                tensorrt_llm::executor::ModelType::kDECODER_ONLY,
-                GetExecutorConfig(config, executorWorker.string()
-                )) {
+        executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
+                 GetExecutorConfig(config, executorWorker.string())) {
     SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
 }
 
-bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
-    return executor.canEnqueueRequests();
-}
-
 [[nodiscard("Returned number of requests needs to be consumed")]]
 size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
     return executor.getNumResponsesReady();
@@ -134,8 +127,6 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
 }
 
-
-void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
-    SPDLOG_INFO("Shutting down executor");
-    executor.shutdown();
+std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
+    return executor.awaitResponses();
 }
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index beca88ad9..e55204ab1 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -35,47 +35,42 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
             std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
 }
 
-size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
-        const uint64_t requestId,
-        huggingface::tgi::backends::GenerationContext *ctx,
-        rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-                      huggingface::tgi::backends::GenerationStep)> callback) {
+std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
+huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
+    const auto responses = TensorRtLlmBackend::PullNewTokens();
+    auto steps = std::make_unique<std::vector<GenerationStep>>();
+    steps->reserve(responses.size());
+    std::ranges::copy(std::views::transform(responses, ConvertResponseToGenerationStep), std::back_inserter(*steps));
+    return steps;
+}
 
-    size_t numTokens = 0;
-    for (const auto &item: Poll(requestId)) {
-        GenerationStep step;
-        if (!item.hasError()) {
-            SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
-            const auto decoded = item.getResult();
-
-            const auto token = decoded.outputTokenIds[0][0];
-            const auto isFinal = decoded.isFinal;
-            const auto logProb = decoded.logProbs.value()[0][0];
-
-            ++numTokens;
-
-            SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            step = huggingface::tgi::backends::GenerationStep{
-                    static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
-            };
-            SPDLOG_DEBUG("\tStreamTokens -> Post callback");
-        } else {
-            // TODO : Return rest::Result with error
-            const auto what = item.getErrorMsg();
-            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
-            step = huggingface::tgi::backends::GenerationStep{
-                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
-            };
-        }
-
-        callback(std::move(ctx), std::move(step));
+huggingface::tgi::backends::GenerationStep
+huggingface::tgi::backends::ConvertResponseToGenerationStep(const tle::Response &response) {
+    const auto reqId = response.getRequestId();
+    if (!response.hasError()) {
+        const auto result = response.getResult();
+        return GenerationStep{
+                reqId,
+                static_cast<uint32_t>(result.outputTokenIds[0][0]),
+                result.logProbs.value()[0][0],
+                result.isFinal,
+                false,
+                std::string()
+        };
+    } else {
+        return GenerationStep{
+                reqId,
+                0,
+                0.0,
+                true,
+                true,
+                response.getErrorMsg()
+        };
     }
-
-    return numTokens;
 }
 
 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
 huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
+    SPDLOG_INFO("Creating TensorRT-LLM Backend");
     // Unconditionally call this to initialize and discover TRTLLM plugins
     InitializeBackend();
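The conversion above tags every GenerationStep with the id of the request that produced it, so one PullTokens() call can carry tokens for many in-flight requests at once and the caller has to demultiplex the batch itself. A minimal Rust sketch of that demultiplexing, assuming a local GenerationStep struct that merely mirrors the cxx-generated bridge type (hypothetical stand-in, not the real ffi module):

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the cxx-generated `ffi::GenerationStep`,
/// mirroring the fields declared in the bridge.
#[derive(Clone, Debug)]
struct GenerationStep {
    request_id: u64,
    token_id: u32,
    log_prob: f32,
    is_final: bool,
    has_error: bool,
    error_msg: String,
}

/// Group a drained batch of steps by the request that produced them,
/// so each in-flight generation can be advanced independently.
fn group_by_request(steps: Vec<GenerationStep>) -> HashMap<u64, Vec<GenerationStep>> {
    let mut per_request: HashMap<u64, Vec<GenerationStep>> = HashMap::new();
    for step in steps {
        per_request.entry(step.request_id).or_default().push(step);
    }
    per_request
}

fn main() {
    // Two steps pulled in the same batch, belonging to different requests.
    let batch = vec![
        GenerationStep { request_id: 1, token_id: 42, log_prob: -0.12, is_final: false, has_error: false, error_msg: String::new() },
        GenerationStep { request_id: 2, token_id: 7, log_prob: -1.30, is_final: true, has_error: false, error_msg: String::new() },
    ];
    for (request_id, steps) in group_by_request(batch) {
        println!("request {request_id}: {} new step(s)", steps.len());
    }
}
```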
include!("backends/trtllm/src/ffi.cpp"); @@ -44,10 +41,7 @@ mod ffi { fn CreateTensorRtLlmBackend( engine_folder: &str, executor_worker: &str, - ) -> UniquePtr; - - // #[rust_name = "is_ready"] - // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool; + ) -> Result>; #[rust_name = "num_responses_ready"] fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; @@ -63,17 +57,11 @@ mod ffi { repetition_penalty: f32, frequency_penalty: f32, seed: u64, - ) -> u64; + ) -> Result; - #[rust_name = "stream_tokens"] - unsafe fn StreamTokens( + #[rust_name = "pull_tokens"] + fn PullTokens( self: Pin<&mut TensorRtLlmBackendImpl>, - request_id: u64, - ctx: *mut GenerationContext, - cb: unsafe fn(*mut GenerationContext, GenerationStep), - ) -> usize; - - // #[rust_name = "shutdown"] - // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); + ) -> Result>>; } }