From 31d9f4d5dcfaad055b3ec6bbe25f737378966f83 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Mon, 15 Jul 2024 07:36:01 +0000 Subject: [PATCH] expose shutdown function at ffi layer --- backends/trtllm/include/backend.h | 5 +++++ backends/trtllm/lib/backend.cpp | 6 ++++++ backends/trtllm/src/ffi.cpp | 2 -- backends/trtllm/src/lib.rs | 5 ++++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h index 12347ef4..90df0fae 100644 --- a/backends/trtllm/include/backend.h +++ b/backends/trtllm/include/backend.h @@ -92,6 +92,11 @@ namespace huggingface::tgi::backends { * @return Global number of generated tokens for this request id */ uint32_t Stream(RequestId reqId, std::function &cb); + + /*** + * Stop the underlying executor + */ + void Shutdown(); }; } diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp index b7a0bb4e..1db73651 100644 --- a/backends/trtllm/lib/backend.cpp +++ b/backends/trtllm/lib/backend.cpp @@ -136,3 +136,9 @@ uint32_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdTyp std::vector huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) { return executor.awaitResponses(requestId); } + + +void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() { + SPDLOG_INFO("Shutting down executor"); + executor.shutdown(); +} \ No newline at end of file diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp index f1030c8f..47e73a6f 100644 --- a/backends/trtllm/src/ffi.cpp +++ b/backends/trtllm/src/ffi.cpp @@ -7,8 +7,6 @@ #include #include -//#include "rust/cxx.h" -//#include "../include/ffi.h" #include "backends/trtllm/include/ffi.h" diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index fbc174b2..ef4b2907 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -1,4 +1,4 @@ -pub use backend::TrtLLmBackend; +pub use backend::TensorRtLlmBackend; use crate::backend::GenerationContext; @@ -58,5 +58,8 @@ mod ffi { request_id: u64, callback: fn(Box, u32, u32, bool), ) -> u32; + + #[rust_name = "shutdown"] + fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); } }