diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h index 12347ef4..90df0fae 100644 --- a/backends/trtllm/include/backend.h +++ b/backends/trtllm/include/backend.h @@ -92,6 +92,11 @@ namespace huggingface::tgi::backends { * @return Global number of generated tokens for this request id */ uint32_t Stream(RequestId reqId, std::function &cb); + + /*** + * Stop the underlying executor + */ + void Shutdown(); }; } diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp index b7a0bb4e..1db73651 100644 --- a/backends/trtllm/lib/backend.cpp +++ b/backends/trtllm/lib/backend.cpp @@ -136,3 +136,9 @@ uint32_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdTyp std::vector huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) { return executor.awaitResponses(requestId); } + + +void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() { + SPDLOG_INFO("Shutting down executor"); + executor.shutdown(); +} \ No newline at end of file diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp index f1030c8f..47e73a6f 100644 --- a/backends/trtllm/src/ffi.cpp +++ b/backends/trtllm/src/ffi.cpp @@ -7,8 +7,6 @@ #include #include -//#include "rust/cxx.h" -//#include "../include/ffi.h" #include "backends/trtllm/include/ffi.h" diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index fbc174b2..ef4b2907 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -1,4 +1,4 @@ -pub use backend::TrtLLmBackend; +pub use backend::TensorRtLlmBackend; use crate::backend::GenerationContext; @@ -58,5 +58,8 @@ mod ffi { request_id: u64, callback: fn(Box, u32, u32, bool), ) -> u32; + + #[rust_name = "shutdown"] + fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); } }