From 344f33f398a3d9d2ae7afd31ab783396f360ef77 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Fri, 12 Jul 2024 19:25:40 +0000
Subject: [PATCH] end to end ffi flow working

---
 backends/trtllm/include/ffi.h |  69 +++++++++++++++++
 backends/trtllm/src/ffi.cpp   | 137 ++++++++++++++--------------------
 backends/trtllm/src/lib.rs    |  23 ++++--
 3 files changed, 139 insertions(+), 90 deletions(-)
 create mode 100644 backends/trtllm/include/ffi.h

diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
new file mode 100644
index 00000000..6b77b7b4
--- /dev/null
+++ b/backends/trtllm/include/ffi.h
@@ -0,0 +1,69 @@
+//
+// Created by mfuntowicz on 7/11/24.
+//
+
+#ifndef TGI_TRTLLM_BACKEND_FFI_H
+#define TGI_TRTLLM_BACKEND_FFI_H
+
+//#include "rust/cxx.h"
+#include "backend.h"
+
+namespace huggingface::tgi::backends {
+    class TensorRtLlmBackendImpl;
+}
+
+#include "backends/trtllm/src/lib.rs.h"
+
+
+namespace huggingface::tgi::backends {
+
+    struct GenerationContext;
+
+    class TensorRtLlmBackendImpl : TensorRtLlmBackend {
+    public:
+        /***
+         *
+         * @param engineFolder
+         * @param executorWorker
+         */
+        TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
+
+        /***
+         *
+         * @return
+         */
+        bool IsReady() const;
+
+        /***
+         *
+         * @param tokens
+         * @param maxNewTokens
+         * @param topK
+         * @param topP
+         * @param temperature
+         * @param seed
+         * @return
+         */
+        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
+        uint64_t Submit(rust::Slice<const uint32_t> tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
+
+        /***
+         *
+         * @param requestId
+         * @param handler
+         * @return
+         */
+        uint32_t Stream(rust::Box<GenerationContext> ctx,
+                        uint64_t requestId,
+                        rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
+    };
+
+    /***
+     *
+     * @param engineFolder
+     * @return
+     */
+    std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
+}
+
+#endif //TGI_TRTLLM_BACKEND_FFI_H
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index 06fc3623..f1030c8f 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -7,93 +7,66 @@
 #include
 #include

-#include "rust/cxx.h"
-#include "backends/trtllm/include/backend.h"
+//#include "rust/cxx.h"
+//#include "../include/ffi.h"
+#include "backends/trtllm/include/ffi.h"

-namespace huggingface::tgi::backends {
-    class TensorRtLlmBackendImpl : TensorRtLlmBackend {
-    public:
-        /***
-         *
-         * @param engineFolder
-         * @param executorWorker
-         */
-        TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker) :
-                TensorRtLlmBackend(std::move(engineFolder), std::move(executorWorker)) {}

-        /***
-         *
-         * @return
-         */
-        bool IsReady() const { return TensorRtLlmBackend::IsReady(); }
+huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
+        const std::string_view &engineFolder,
+        const std::string_view &executorWorker
+) : TensorRtLlmBackend(engineFolder, executorWorker) {}

-        /***
-         *
-         * @param tokens
-         * @param maxNewTokens
-         * @param topK
-         * @param topP
-         * @param temperature
-         * @param seed
-         * @return
-         */
-        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
-        RequestId Submit(rust::Slice<const uint32_t> tokens,
-                         int32_t maxNewTokens,
-                         int32_t topK,
-                         float_t topP,
-                         float_t temperature,
-                         uint64_t seed) {
-            // This will copy all the items from the initial slice
-            std::vector<int32_t> tokens_(tokens.size());
-            tokens_.assign(tokens.begin(), tokens.end());
-            return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
+bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
+    return TensorRtLlmBackend::IsReady();
+}
+
+uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
+        rust::Slice<const uint32_t> tokens,
+        int32_t maxNewTokens, int32_t topK, float_t topP,
+        float_t temperature, uint64_t seed) {
+
+    // This will copy all the items from the initial slice
+    std::vector<int32_t> tokens_(tokens.size());
+    tokens_.assign(tokens.begin(), tokens.end());
+
+    return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
+}
+
+uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
+        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
+        uint64_t requestId,
+        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
+    bool isDone = false;
+    uint32_t numGeneratedTokens = 0;
+
+    do {
+        const auto responses = Poll(requestId);
+        for (const auto &response: responses) {
+            if (response.hasError()) {
+                isDone = true;
+                // TODO : bubble up the error to rust
+            } else {
+                const auto generation = response.getResult();
+                const auto token = generation.outputTokenIds[0][0];
+                isDone = generation.isFinal;
+
+                // Propagate through the handler
+                handler(std::move(ctx), token, numGeneratedTokens, isDone);
+            }
         }
+    } while (!isDone);

-        /***
-         *
-         * @param requestId
-         * @param handler
-         * @return
-         */
-//        uint32_t
-//        Stream(RequestId requestId, rust::Box<GenerationContext>, rust::Fn<void(uint32_t, uint32_t, bool)> handler) {
-//            bool isDone = false;
-//            uint32_t numGeneratedTokens = 0;
-//
-//            do {
-//                const auto responses = Poll(requestId);
-//                for (const auto &response: responses) {
-//                    if (response.hasError()) {
-//                        isDone = true;
-//                        // TODO : bubble up the error to rust
-//                    } else {
-//                        const auto generation = response.getResult();
-//                        const auto token = generation.outputTokenIds[0][0];
-//                        isDone = generation.isFinal;
-//
-//                        // Propagate through the handler
-//                        handler(token, numGeneratedTokens, isDone);
-//                    }
-//                }
-//            } while (!isDone);
-//
-//            return numGeneratedTokens;
-//        }
-    };
+    return numGeneratedTokens;
+}

-    /***
-     *
-     * @param engineFolder
-     * @return
-     */
-    std::unique_ptr<TensorRtLlmBackendImpl> create_trtllm_backend(rust::Str engineFolder, rust::Str executorWorker) {
-        // Unconditionally call this to initialize and discover TRTLLM plugins
-        InitializeBackend();
+std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
+huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
+    // Unconditionally call this to initialize and discover TRTLLM plugins
+    InitializeBackend();

-        const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
-        const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
-        return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
-    }
-}
\ No newline at end of file
+    const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
+    const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
+    return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
+}
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index d2838099..fbc174b2 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -1,10 +1,16 @@
 pub use backend::TrtLLmBackend;

+use crate::backend::GenerationContext;
+
 mod backend;
 pub mod errors;

 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
+    extern "Rust" {
+        type GenerationContext;
+    }
+
     unsafe extern "C++" {
         include!("backends/trtllm/src/ffi.cpp");

@@ -25,7 +31,8 @@ mod ffi {
         /// ```
         ///
         /// ```
-        fn create_trtllm_backend(
+        #[rust_name = "create_tensorrt_llm_backend"]
+        fn CreateTensorRtLlmBackend(
             engine_folder: &str,
             executor_worker: &str,
         ) -> UniquePtr<TensorRtLlmBackendImpl>;
@@ -44,12 +51,12 @@ mod ffi {
             seed: u64,
         ) -> u64;

-        // #[rust_name = "stream"]
-        // fn Stream(
-        //     self: Pin<&mut TensorRtLlmBackendImpl>,
-        //     request_id: u64,
-        //     ctx: Box<GenerationContext>,
-        //     callback: fn(u32, u32, bool),
-        // ) -> u32;
+        #[rust_name = "stream"]
+        fn Stream(
+            self: Pin<&mut TensorRtLlmBackendImpl>,
+            ctx: Box<GenerationContext>,
+            request_id: u64,
+            callback: fn(Box<GenerationContext>, u32, u32, bool),
+        ) -> u32;
     }
 }
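
Below is a minimal sketch of how the bridge introduced by this patch could be driven from the Rust side. It assumes it lives in the same crate as `lib.rs` and that the bridge items are reachable from that module; the `generate` and `on_token` names, the sampling values, and the exact parameter list of the `submit` binding (inferred from the C++ `Submit` signature, since that part of the bridge sits outside the hunks above) are illustrative assumptions, not part of the patch.

```rust
use cxx::UniquePtr;

use crate::backend::GenerationContext;
use crate::ffi;

// The bridge only accepts a plain `fn` pointer, not a closure, so any per-request
// state has to travel inside the boxed GenerationContext that C++ hands back on
// every invocation of the handler.
fn on_token(_ctx: Box<GenerationContext>, token: u32, n_generated: u32, is_final: bool) {
    // A real context would e.g. forward the token over a channel; here we only log it.
    println!("token={token} generated={n_generated} final={is_final}");
}

// Hypothetical end-to-end driver: create the backend, submit one request, stream it back.
fn generate(engine_folder: &str, executor_worker: &str, prompt: &[u32], ctx: Box<GenerationContext>) {
    // Instantiates the C++ TensorRtLlmBackendImpl; CreateTensorRtLlmBackend also calls
    // InitializeBackend() to discover the TRT-LLM plugins.
    let mut backend: UniquePtr<ffi::TensorRtLlmBackendImpl> =
        ffi::create_tensorrt_llm_backend(engine_folder, executor_worker);

    // Submit returns the request id that refers to this generation later on.
    // Illustrative values: max_new_tokens=128, top_k=40, top_p=0.95, temperature=0.7, seed=2014.
    let request_id = backend.pin_mut().submit(prompt, 128, 40, 0.95, 0.7, 2014);

    // Blocks while the C++ side polls the executor and invokes `on_token` once per
    // generated token until `isFinal` is reported.
    let n_tokens = backend.pin_mut().stream(ctx, request_id, on_token);
    println!("generation finished after {n_tokens} tokens");
}
```

Because cxx cannot marshal closures through this signature, the boxed `GenerationContext` is the only carrier for per-request state, which is why the C++ `Stream` loop moves the context into `handler` on every iteration.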