(backend) expose PullNewTokens

Morgan Funtowicz 2024-08-02 22:16:28 +00:00 committed by Morgan Funtowicz
parent b8a40a0af3
commit f4a74be384
5 changed files with 47 additions and 94 deletions

View File

@@ -56,7 +56,7 @@ namespace huggingface::tgi::backends {
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed
);
) noexcept;
/**
*
@@ -72,12 +72,6 @@ namespace huggingface::tgi::backends {
const std::filesystem::path &executorWorker
);
/**
* Indicate if the backend is ready to accept incoming requests
* @return true if ready, false otherwise
*/
[[nodiscard]] bool IsReady() const;
/**
* Query the executor for the number of tokens available for pulling
* @return
@@ -106,17 +100,7 @@ namespace huggingface::tgi::backends {
const uint64_t seed
);
/**
*
* @param requestId The request id to poll the generation results
* @return
*/
std::vector <tle::Response> Poll(RequestId requestId);
/**
* Stop the underlying executor
*/
void Shutdown();
[[nodiscard]] std::vector<tle::Response> PullNewTokens();
};
}

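With IsReady(), Poll(requestId) and Shutdown() gone, the header now exposes a single pull-based entry point. A rough caller-side sketch of how it is meant to be driven (not part of the commit; backend stands for an already-constructed TensorRtLlmBackend, stillRunning is a placeholder loop condition, and requests are assumed to have been enqueued earlier through Submit()):

// Sketch only: drive the pull-based API exposed by this header.
while (stillRunning) {
    if (backend.NumResponsesReady() == 0)
        continue;                                   // nothing decoded yet
    // A single pull returns responses for every in-flight request, not just one id.
    for (const auto &response: backend.PullNewTokens()) {
        // dispatch on response.getRequestId(); errors surface through response.hasError()
    }
}

Unlike the removed Poll(requestId), the caller no longer names a specific request; demultiplexing happens on the receiving side.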
View File

@@ -54,18 +54,13 @@ namespace huggingface::tgi::backends {
/***
*
* @param requestId
* @param ctx
* @param callback
* @return
*/
size_t StreamTokens(
const RequestId requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback);
std::unique_ptr<std::vector<GenerationStep>> PullTokens();
};
GenerationStep ConvertResponseToGenerationStep(const tle::Response &response);
/***
*
* @param engineFolder

View File

@@ -84,18 +84,11 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &executorWorker
) :
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
executor(
enginesFolder,
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string()
)) {
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string())) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
}
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
return executor.canEnqueueRequests();
}
[[nodiscard("Returned number of requests needs to be consumed")]]
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
return executor.getNumResponsesReady();
@@ -134,8 +127,6 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
}
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
SPDLOG_INFO("Shutting down executor");
executor.shutdown();
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
return std::move(executor.awaitResponses());
}

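PullNewTokens() is a thin forwarder over Executor::awaitResponses(), so one call may carry tokens belonging to several requests at once. A hedged grouping sketch (assuming the tle alias used throughout these sources and the same backend instance as above; perRequest is a hypothetical bucket map, not something this commit introduces):

// Sketch, not part of the commit: bucket freshly pulled responses by request id so
// each generation stream can be advanced independently afterwards.
#include <unordered_map>
#include <vector>

std::unordered_map<tle::IdType, std::vector<tle::Response>> perRequest;
for (auto &&response: backend.PullNewTokens()) {
    perRequest[response.getRequestId()].emplace_back(std::move(response));
}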
View File

@@ -35,47 +35,42 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
}
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
const uint64_t requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback) {
std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
const auto responses = TensorRtLlmBackend::PullNewTokens();
auto steps = std::make_unique<std::vector<GenerationStep>>(responses.size());
std::ranges::copy(std::views::transform(responses, ConvertResponseToGenerationStep), std::back_inserter(*steps));
return steps;
}
size_t numTokens = 0;
for (const auto &item: Poll(requestId)) {
GenerationStep step;
if (!item.hasError()) {
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
const auto token = decoded.outputTokenIds[0][0];
const auto isFinal = decoded.isFinal;
const auto logProb = decoded.logProbs.value()[0][0];
++numTokens;
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
step = huggingface::tgi::backends::GenerationStep{
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
};
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
huggingface::tgi::backends::GenerationStep
huggingface::tgi::backends::ConvertResponseToGenerationStep(const tle::Response &response) {
const auto reqId = response.getRequestId();
if (!response.hasError()) {
const auto result = response.getResult();
return std::move(GenerationStep{
reqId,
result.outputTokenIds[0][0],
result.logProbs.value()[0][0],
result.isFinal,
false,
std::string()
});
} else {
// TODO : Return rest::Result with error
const auto what = item.getErrorMsg();
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
step = huggingface::tgi::backends::GenerationStep{
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
};
return std::move(GenerationStep{
reqId,
0,
0.0,
true,
true,
std::move(response.getErrorMsg())
});
}
callback(std::move(ctx), std::move(step));
}
return numTokens;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
SPDLOG_INFO("Creating TensorRT-LLM Backend");
// Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend();

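On the FFI layer the callback plumbing disappears entirely: instead of pushing each GenerationStep through a rust::Fn, PullTokens() hands the whole batch back as a vector the Rust side can iterate at its own pace. A small illustrative loop (a sketch only; backendImpl stands for a reachable TensorRtLlmBackendImpl, and the field names follow the bridge struct shown in the Rust module below):

// Sketch only: walk the steps produced by one PullTokens() call and trace each token
// back to the request that produced it.
const auto steps = backendImpl.PullTokens();        // std::unique_ptr<std::vector<GenerationStep>>
for (const auto &step: *steps) {
    SPDLOG_DEBUG(FMT_STRING("request={:d} token={:d} logprob={:.2f} final={}"),
                 step.request_id, step.token_id, step.log_prob, step.is_final);
}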
View File

@@ -1,7 +1,7 @@
pub use backend::{GenerationContext, TensorRtLlmBackend};
pub use looper::TensorRtLlmBackendV2;
mod backend;
pub mod errors;
mod looper;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
@@ -9,6 +9,7 @@ mod ffi {
/// Struct used as shared type between Rust and C++ to represent the result
/// of a single decoding iteration
pub struct GenerationStep {
request_id: u64,
token_id: u32,
log_prob: f32,
is_final: bool,
@@ -16,10 +17,6 @@ mod ffi {
error_msg: String,
}
extern "Rust" {
type GenerationContext;
}
unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp");
@@ -44,10 +41,7 @@ mod ffi {
fn CreateTensorRtLlmBackend(
engine_folder: &str,
executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
@@ -63,17 +57,11 @@ mod ffi {
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) -> u64;
) -> Result<u64>;
#[rust_name = "stream_tokens"]
unsafe fn StreamTokens(
#[rust_name = "pull_tokens"]
fn PullTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
request_id: u64,
ctx: *mut GenerationContext,
cb: unsafe fn(*mut GenerationContext, GenerationStep),
) -> usize;
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
}
}