(backend) expose PullNewTokens

Authored by Morgan Funtowicz on 2024-08-02 22:16:28 +00:00, committed by Morgan Funtowicz
parent b8a40a0af3
commit f4a74be384
5 changed files with 47 additions and 94 deletions

View File

@@ -56,7 +56,7 @@ namespace huggingface::tgi::backends {
                 const float_t repetition_penalty,
                 const float_t frequency_penalty,
                 const uint64_t seed
-        );
+        ) noexcept;

         /**
          *
@@ -72,12 +72,6 @@
                 const std::filesystem::path &executorWorker
         );

-        /**
-         * Indicate if the backend is ready to accept incoming request
-         * @return true if ready, false otherwise
-         */
-        [[nodiscard]] bool IsReady() const;
-
         /**
          * Query the executor for the number of token available for pulling
          * @return
@@ -106,17 +100,7 @@ namespace huggingface::tgi::backends {
                 const uint64_t seed
         );

-        /**
-         *
-         * @param requestId The request id to poll the generation results
-         * @return
-         */
-        std::vector <tle::Response> Poll(RequestId requestId);
-
-        /**
-         * Stop the underlying executor
-         */
-        void Shutdown();
+        [[nodiscard]] std::vector<tle::Response> PullNewTokens();
     };
 }
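
Not part of the diff: a minimal sketch of how a caller might drive the reshaped interface above. Requests are still submitted one by one, but tokens are now drained through the single PullNewTokens() call instead of the removed per-request Poll(). The helper name, parameter types, and the polling loop itself are illustrative assumptions; the Submit() argument order follows the forwarding shown in the ffi.cpp hunk further down.

    // Illustrative sketch only; assumes the backend header above is included.
    void DrainUntilDone(huggingface::tgi::backends::TensorRtLlmBackend &backend,
                        const std::vector<tle::TokenIdType> &tokens,
                        int32_t topK, float_t topP, float_t temperature,
                        float_t repetitionPenalty, float_t frequencyPenalty, uint64_t seed) {
        // Submit still enqueues a single request and returns its executor-assigned id.
        const auto requestId = backend.Submit(tokens, topK, topP, temperature,
                                              repetitionPenalty, frequencyPenalty, seed);
        for (;;) {
            // PullNewTokens() hands back whatever responses the executor has produced,
            // across all in-flight requests, so the caller routes them by request id.
            for (const auto &response: backend.PullNewTokens()) {
                if (response.getRequestId() != requestId) continue;  // belongs to another request
                if (response.hasError() || response.getResult().isFinal) return;  // finished
            }
        }
    }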

View File

@@ -54,18 +54,13 @@ namespace huggingface::tgi::backends {
         /***
          *
-         * @param requestId
-         * @param ctx
-         * @param callback
          * @return
          */
-        size_t StreamTokens(
-                const RequestId requestId,
-                huggingface::tgi::backends::GenerationContext *ctx,
-                rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-                              huggingface::tgi::backends::GenerationStep)> callback);
+        std::unique_ptr<std::vector<GenerationStep>> PullTokens();
     };

+    GenerationStep ConvertResponseToGenerationStep(const tle::Response &response);
+
     /***
      *
      * @param engineFolder

View File

@@ -84,18 +84,11 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         const std::filesystem::path &executorWorker
 ) :
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
-        executor(
-                enginesFolder,
-                tensorrt_llm::executor::ModelType::kDECODER_ONLY,
-                GetExecutorConfig(config, executorWorker.string()
-        )) {
+        executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
+                 GetExecutorConfig(config, executorWorker.string())) {
     SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
 }

-bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
-    return executor.canEnqueueRequests();
-}
-
 [[nodiscard("Returned number of requests needs to be consumed")]]
 size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
     return executor.getNumResponsesReady();
@@ -134,8 +127,6 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
 }

-void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
-    SPDLOG_INFO("Shutting down executor");
-    executor.shutdown();
+std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
+    return std::move(executor.awaitResponses());
 }
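
Not part of the diff: PullNewTokens() above forwards to executor.awaitResponses(), which, called without a timeout, waits until the executor has at least one response to hand back. NumResponsesReady() is kept alongside it, so a caller that must not block can gate the pull; a small sketch with a hypothetical helper name:

    // Illustrative sketch only; non-blocking variant of the pull.
    std::vector<tle::Response> TryPullNewTokens(huggingface::tgi::backends::TensorRtLlmBackend &backend) {
        if (backend.NumResponsesReady() == 0) {
            return {};                   // nothing decoded yet; don't block the caller
        }
        return backend.PullNewTokens();  // at least one response is already available
    }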

View File

@@ -35,47 +35,42 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
             std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
 }

-size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
-        const uint64_t requestId,
-        huggingface::tgi::backends::GenerationContext *ctx,
-        rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-                      huggingface::tgi::backends::GenerationStep)> callback) {
-
-    size_t numTokens = 0;
-    for (const auto &item: Poll(requestId)) {
-        GenerationStep step;
-        if (!item.hasError()) {
-            SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
-            const auto decoded = item.getResult();
-
-            const auto token = decoded.outputTokenIds[0][0];
-            const auto isFinal = decoded.isFinal;
-            const auto logProb = decoded.logProbs.value()[0][0];
-
-            ++numTokens;
-
-            SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            step = huggingface::tgi::backends::GenerationStep{
-                    static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
-            };
-            SPDLOG_DEBUG("\tStreamTokens -> Post callback");
-        } else {
-            // TODO : Return rest::Result with error
-            const auto what = item.getErrorMsg();
-            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
-            step = huggingface::tgi::backends::GenerationStep{
-                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
-            };
-        }
-        callback(std::move(ctx), std::move(step));
+std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
+huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
+    const auto responses = TensorRtLlmBackend::PullNewTokens();
+    auto steps = std::make_unique<std::vector<GenerationStep>>(responses.size());
+    std::ranges::copy(std::views::transform(responses, ConvertResponseToGenerationStep), std::back_inserter(*steps));
+    return steps;
+}
+
+huggingface::tgi::backends::GenerationStep
+huggingface::tgi::backends::ConvertResponseToGenerationStep(const tle::Response &response) {
+    const auto reqId = response.getRequestId();
+    if (!response.hasError()) {
+        const auto result = response.getResult();
+        return std::move(GenerationStep{
+                reqId,
+                result.outputTokenIds[0][0],
+                result.logProbs.value()[0][0],
+                result.isFinal,
+                false,
+                std::string()
+        });
+    } else {
+        return std::move(GenerationStep{
+                reqId,
+                0,
+                0.0,
+                true,
+                true,
+                std::move(response.getErrorMsg())
+        });
     }
-    return numTokens;
 }

 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
 huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
+    SPDLOG_INFO("Creating TensorRT-LLM Backend");
     // Unconditionally call this to initialize and discover TRTLLM plugins
     InitializeBackend();
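
Not part of the diff, but worth spelling out for the PullTokens body above: std::make_unique<std::vector<GenerationStep>>(n) value-initializes n elements, and std::back_inserter then appends after them. A variant that reserves capacity instead keeps the size at zero, so the transformed steps start at index 0; the helper name is hypothetical and the project headers declaring GenerationStep, tle::Response, and ConvertResponseToGenerationStep are assumed to be in scope.

    #include <iterator>
    #include <memory>
    #include <ranges>
    #include <vector>

    // Illustrative sketch only: reserve rather than pre-size before appending.
    std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
    ConvertAllResponses(const std::vector<tle::Response> &responses) {
        auto steps = std::make_unique<std::vector<huggingface::tgi::backends::GenerationStep>>();
        steps->reserve(responses.size());  // capacity only; size stays 0
        std::ranges::copy(
                std::views::transform(responses, huggingface::tgi::backends::ConvertResponseToGenerationStep),
                std::back_inserter(*steps));
        return steps;
    }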

View File

@@ -1,7 +1,7 @@
-pub use backend::{GenerationContext, TensorRtLlmBackend};
+pub use looper::TensorRtLlmBackendV2;

-mod backend;
 pub mod errors;
+mod looper;

 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
@@ -9,6 +9,7 @@ mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     pub struct GenerationStep {
+        request_id: u64,
         token_id: u32,
         log_prob: f32,
         is_final: bool,
@@ -16,10 +17,6 @@ mod ffi {
         error_msg: String,
     }

-    extern "Rust" {
-        type GenerationContext;
-    }
-
     unsafe extern "C++" {
         include!("backends/trtllm/src/ffi.cpp");
@@ -44,10 +41,7 @@ mod ffi {
         fn CreateTensorRtLlmBackend(
             engine_folder: &str,
             executor_worker: &str,
-        ) -> UniquePtr<TensorRtLlmBackendImpl>;
-
-        // #[rust_name = "is_ready"]
-        // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
+        ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;

         #[rust_name = "num_responses_ready"]
         fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
@@ -63,17 +57,11 @@ mod ffi {
             repetition_penalty: f32,
             frequency_penalty: f32,
             seed: u64,
-        ) -> u64;
+        ) -> Result<u64>;

-        #[rust_name = "stream_tokens"]
-        unsafe fn StreamTokens(
+        #[rust_name = "pull_tokens"]
+        fn PullTokens(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            request_id: u64,
-            ctx: *mut GenerationContext,
-            cb: unsafe fn(*mut GenerationContext, GenerationStep),
-        ) -> usize;
-
-        // #[rust_name = "shutdown"]
-        // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
+        ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
     }
 }
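
Not part of the diff: the bindings above whose return types changed to Result<...> get try/catch glue generated by cxx, so a C++ exception thrown behind them reaches the Rust caller as Err(cxx::Exception) instead of aborting the process. A hypothetical C++ illustration; the guard below is invented for this sketch, not taken from the backend:

    #include <stdexcept>
    #include <string>

    // Invented guard: if a binding declared with a Result<...> return type ends up
    // calling this and the check fails, the exception is converted by the cxx glue
    // into an Err(...) on the Rust side rather than terminating the process.
    void EnsureEngineFolderProvided(const std::string &engineFolder) {
        if (engineFolder.empty()) {
            throw std::invalid_argument("engine folder must not be empty");
        }
    }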