mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00

(backend) expose PullNewTokens

This commit is contained in:
parent b8a40a0af3
commit f4a74be384
@@ -56,7 +56,7 @@ namespace huggingface::tgi::backends {
 const float_t repetition_penalty,
 const float_t frequency_penalty,
 const uint64_t seed
-);
+) noexcept;

 /**
  *
@@ -72,12 +72,6 @@ namespace huggingface::tgi::backends {
 const std::filesystem::path &executorWorker
 );

-/**
- * Indicate if the backend is ready to accept incoming request
- * @return true if ready, false otherwise
- */
-[[nodiscard]] bool IsReady() const;
-
 /**
  * Query the executor for the number of token available for pulling
  * @return
@@ -106,17 +100,7 @@ namespace huggingface::tgi::backends {
 const uint64_t seed
 );

-/**
- *
- * @param requestId The request id to poll the generation results
- * @return
- */
-std::vector<tle::Response> Poll(RequestId requestId);
-
-/**
- * Stop the underlying executor
- */
-void Shutdown();
+[[nodiscard]] std::vector<tle::Response> PullNewTokens();
 };
 }
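The hunks above reshape the public TensorRtLlmBackend API, presumably in backends/trtllm/include/backend.h: the per-request Poll(requestId) and the IsReady() probe go away, and a single PullNewTokens() hands back whatever responses the executor has produced for every in-flight request. The caller therefore has to route each step back to its originating request by id. A minimal Rust sketch of that demultiplexing, with one std::sync::mpsc sender per request; the router, the field visibility, and the is_error flag (whose declaration falls between the hunks further down) are assumptions for illustration, not part of this commit:

    use std::collections::HashMap;
    use std::sync::mpsc::Sender;

    // Rust-side mirror of the shared struct declared in the cxx bridge below.
    // The is_error field is assumed; the diff hunk boundary hides that line.
    pub struct GenerationStep {
        pub request_id: u64,
        pub token_id: u32,
        pub log_prob: f32,
        pub is_final: bool,
        pub is_error: bool,
        pub error_msg: String,
    }

    // Hypothetical router: one channel per in-flight request, keyed by id.
    fn route_steps(
        in_flight: &mut HashMap<u64, Sender<GenerationStep>>,
        batch: Vec<GenerationStep>,
    ) {
        for step in batch {
            let id = step.request_id;
            let done = step.is_final || step.is_error;
            if let Some(tx) = in_flight.get(&id) {
                let _ = tx.send(step); // receiver gone: drop the step silently
            }
            if done {
                in_flight.remove(&id); // final token or error ends the stream
            }
        }
    }

Batching the pull this way amortizes each FFI crossing over many tokens and many requests, instead of paying one callback crossing per token.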
@@ -54,18 +54,13 @@ namespace huggingface::tgi::backends {

 /***
  *
- * @param requestId
- * @param ctx
- * @param callback
  * @return
  */
-size_t StreamTokens(
-const RequestId requestId,
-huggingface::tgi::backends::GenerationContext *ctx,
-rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-huggingface::tgi::backends::GenerationStep)> callback);
+std::unique_ptr<std::vector<GenerationStep>> PullTokens();
 };

+GenerationStep ConvertResponseToGenerationStep(const tle::Response &response);
+
 /***
  *
  * @param engineFolder
@@ -84,18 +84,11 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
 const std::filesystem::path &executorWorker
 ) :
 config(json::parse(std::ifstream(enginesFolder / "config.json"))),
-executor(
-enginesFolder,
-tensorrt_llm::executor::ModelType::kDECODER_ONLY,
-GetExecutorConfig(config, executorWorker.string()
-)) {
+executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
+GetExecutorConfig(config, executorWorker.string())) {
 SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
 }

-bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
-return executor.canEnqueueRequests();
-}
-
 [[nodiscard("Returned number of requests needs to be consumed")]]
 size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
 return executor.getNumResponsesReady();
@@ -134,8 +127,6 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
 return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
 }

-
-void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
-SPDLOG_INFO("Shutting down executor");
-executor.shutdown();
+std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
+return std::move(executor.awaitResponses());
 }
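On the implementation side (presumably backends/trtllm/lib/backend.cpp), the removed Shutdown() and IsReady() bodies match the header changes above, and the new PullNewTokens() forwards straight to executor.awaitResponses(), which blocks until the executor has at least one response to hand back. That suggests keeping the call off any async runtime on the Rust side. A sketch of a dedicated pulling thread that forwards flattened batches over a channel; the ffi module path, the binding names, and the cross-thread (Send) declaration the opaque backend type would need are all assumptions here, not something this commit establishes:

    use std::sync::mpsc;
    use std::thread;

    // Assumed: `ffi` is the cxx bridge module from src/lib.rs at the end of
    // this diff, and TensorRtLlmBackendImpl is declared movable across threads.
    fn spawn_puller(
        mut backend: cxx::UniquePtr<ffi::TensorRtLlmBackendImpl>,
        tx: mpsc::Sender<Vec<(u64, u32, bool)>>, // (request_id, token_id, is_final)
    ) -> thread::JoinHandle<()> {
        thread::spawn(move || loop {
            // pull_tokens() blocks inside C++ until responses are ready, so
            // this loop waits on the executor instead of spinning on the CPU.
            match backend.pin_mut().pull_tokens() {
                Ok(steps) => {
                    let batch: Vec<_> = steps
                        .iter()
                        .map(|s| (s.request_id, s.token_id, s.is_final))
                        .collect();
                    if tx.send(batch).is_err() {
                        break; // consumer dropped the receiver: stop pulling
                    }
                }
                Err(why) => {
                    eprintln!("pull_tokens failed: {why}");
                    break;
                }
            }
        })
    }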
@@ -35,47 +35,42 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
 }

-size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
-const uint64_t requestId,
-huggingface::tgi::backends::GenerationContext *ctx,
-rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-huggingface::tgi::backends::GenerationStep)> callback) {
-size_t numTokens = 0;
-for (const auto &item: Poll(requestId)) {
-GenerationStep step;
-if (!item.hasError()) {
-SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
-const auto decoded = item.getResult();
-
-const auto token = decoded.outputTokenIds[0][0];
-const auto isFinal = decoded.isFinal;
-const auto logProb = decoded.logProbs.value()[0][0];
-
-++numTokens;
-
-SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-step = huggingface::tgi::backends::GenerationStep{
-static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
-};
-SPDLOG_DEBUG("\tStreamTokens -> Post callback");
-} else {
-// TODO : Return rest::Result with error
-const auto what = item.getErrorMsg();
-SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
-step = huggingface::tgi::backends::GenerationStep{
-std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
-};
-}
-
-callback(std::move(ctx), std::move(step));
-}
-
-return numTokens;
+std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
+huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
+const auto responses = TensorRtLlmBackend::PullNewTokens();
+auto steps = std::make_unique<std::vector<GenerationStep>>(responses.size());
+std::ranges::copy(std::views::transform(responses, ConvertResponseToGenerationStep), std::back_inserter(*steps));
+return steps;
+}
+
+huggingface::tgi::backends::GenerationStep
+huggingface::tgi::backends::ConvertResponseToGenerationStep(const tle::Response &response) {
+const auto reqId = response.getRequestId();
+if (!response.hasError()) {
+const auto result = response.getResult();
+return std::move(GenerationStep{
+reqId,
+result.outputTokenIds[0][0],
+result.logProbs.value()[0][0],
+result.isFinal,
+false,
+std::string()
+});
+} else {
+// TODO : Return rest::Result with error
+return std::move(GenerationStep{
+reqId,
+0,
+0.0,
+true,
+true,
+std::move(response.getErrorMsg())
+});
+}
 }

 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
 huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
 SPDLOG_INFO("Creating TensorRT-LLM Backend");
 // Unconditionally call this to initialize and discover TRTLLM plugins
 InitializeBackend();
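This hunk is the heart of the commit (the file is backends/trtllm/src/ffi.cpp, per the include! directive in the bridge below): the callback-driven StreamTokens is replaced by PullTokens, which drains PullNewTokens() and maps each tle::Response through ConvertResponseToGenerationStep. That helper flattens both outcomes into the one shared struct: a decoded token with its log-probability, or a row with is_error set and the message carried in error_msg. On the Rust side the flat encoding unfolds naturally back into a Result; a small sketch, where the Token type and the field visibility are illustrative assumptions:

    // Hypothetical consumer-side token type; not part of this diff.
    pub struct Token {
        pub id: u32,
        pub log_prob: f32,
        pub is_final: bool,
    }

    // Unfold the flat FFI struct back into a Result, mirroring the two
    // branches of ConvertResponseToGenerationStep above.
    fn step_to_result(step: &ffi::GenerationStep) -> Result<Token, String> {
        if step.is_error {
            Err(step.error_msg.clone())
        } else {
            Ok(Token {
                id: step.token_id,
                log_prob: step.log_prob,
                is_final: step.is_final,
            })
        }
    }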
@@ -1,7 +1,7 @@
-pub use backend::{GenerationContext, TensorRtLlmBackend};
+pub use looper::TensorRtLlmBackendV2;

-mod backend;
 pub mod errors;
+mod looper;

 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
@@ -9,6 +9,7 @@ mod ffi {
 /// Struct used as shared type between rust and C++ to represent the result
 /// of a single decoding iteration
 pub struct GenerationStep {
+request_id: u64,
 token_id: u32,
 log_prob: f32,
 is_final: bool,
@@ -16,10 +17,6 @@ mod ffi {
 error_msg: String,
 }

-extern "Rust" {
-type GenerationContext;
-}
-
 unsafe extern "C++" {
 include!("backends/trtllm/src/ffi.cpp");

@@ -44,10 +41,7 @@ mod ffi {
 fn CreateTensorRtLlmBackend(
 engine_folder: &str,
 executor_worker: &str,
-) -> UniquePtr<TensorRtLlmBackendImpl>;
-
-// #[rust_name = "is_ready"]
-// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
+) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;

 #[rust_name = "num_responses_ready"]
 fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
@@ -63,17 +57,11 @@ mod ffi {
 repetition_penalty: f32,
 frequency_penalty: f32,
 seed: u64,
-) -> u64;
+) -> Result<u64>;

-#[rust_name = "stream_tokens"]
-unsafe fn StreamTokens(
+#[rust_name = "pull_tokens"]
+fn PullTokens(
 self: Pin<&mut TensorRtLlmBackendImpl>,
-request_id: u64,
-ctx: *mut GenerationContext,
-cb: unsafe fn(*mut GenerationContext, GenerationStep),
-) -> usize;
-
-// #[rust_name = "shutdown"]
-// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
+) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
 }
 }
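Taken together, the revised bridge (backends/trtllm/src/lib.rs) describes a submit-then-drain backend: fallible C++ entry points now return Result, so exceptions cross the FFI as recoverable errors instead of aborting the process, and the pointer-plus-callback StreamTokens binding gives way to a safe, batched pull_tokens. A hedged sketch of a consuming loop against these bindings; backend creation and request submission are elided, and field visibility on GenerationStep is again an assumption:

    // Drain loop over the revised bridge. `backend` is assumed to come from
    // CreateTensorRtLlmBackend, with requests already submitted via Submit.
    fn drain(mut backend: cxx::UniquePtr<ffi::TensorRtLlmBackendImpl>) -> Result<(), cxx::Exception> {
        loop {
            // Cheap readiness probe: skip the blocking pull while idle.
            if backend.num_responses_ready() == 0 {
                std::thread::sleep(std::time::Duration::from_millis(1));
                continue;
            }
            // One FFI call returns every step the executor has buffered;
            // a C++ exception surfaces here as Err(cxx::Exception).
            let steps = backend.pin_mut().pull_tokens()?;
            for step in steps.iter() {
                if step.is_error {
                    eprintln!("request {} failed: {}", step.request_id, step.error_msg);
                } else {
                    println!(
                        "request {} -> token {} (log_prob {:.3}, final {})",
                        step.request_id, step.token_id, step.log_prob, step.is_final
                    );
                }
            }
        }
    }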