//
// Created by Morgan Funtowicz on 6/30/24.
//
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <functional>
#include <optional>
#include <span>
#include <string>
#include <vector>

#include <fmt/format.h>
#include <nlohmann/json.hpp>

#include <tensorrt_llm/executor/executor.h>
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
#include <tensorrt_llm/runtime/common.h>

using json = nlohmann::json;
namespace tle = tensorrt_llm::executor;
namespace huggingface::tgi::backends {
|
2024-07-03 21:12:24 +00:00
|
|
|
|
2024-07-09 13:46:48 +00:00
|
|
|
using TokenStreamingCallback = void(tle::TokenIdType);
|
|
|
|
|
2024-07-08 22:08:49 +00:00
|
|
|
/**
|
|
|
|
* Initialize all the components required by TRTLLM.
|
|
|
|
* It is required to call this function before attempting to load any engine
|
|
|
|
*/
|
|
|
|
void InitializeBackend();
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
2024-07-09 13:46:48 +00:00
|
|
|
* @param config TensorRT-LLM configuration object
|
|
|
|
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
|
2024-07-08 22:08:49 +00:00
|
|
|
* @return
|
|
|
|
*/
|
2024-07-03 21:12:24 +00:00
|
|
|
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
|
|
|
|
2024-07-08 22:08:49 +00:00
|
|
|
/**
|
|
|
|
*
|
|
|
|
*/
|
2024-07-01 13:53:23 +00:00
|
|
|
class TensorRtLlmBackend {
|
2024-06-30 21:37:20 +00:00
|
|
|
private:
|
2024-07-03 21:12:24 +00:00
|
|
|
const json config;
|
2024-07-03 08:27:53 +00:00
|
|
|
tle::Executor executor;
|
2024-06-30 21:37:20 +00:00
|
|
|
|
|
|
|
public:
|
2024-07-03 21:12:24 +00:00
|
|
|
explicit TensorRtLlmBackend(
|
|
|
|
const std::filesystem::path &engineFolder,
|
|
|
|
const std::filesystem::path &executorWorker
|
|
|
|
);
|
2024-07-03 08:27:53 +00:00
|
|
|
|
|
|
|
/***
|
|
|
|
* Indicate if the backend is ready to accept incoming request
|
|
|
|
* @return true if ready, false otherwise
|
|
|
|
*/
|
|
|
|
[[nodiscard]] bool IsReady() const {
|
|
|
|
return executor.canEnqueueRequests();
|
|
|
|
}
|
|
|
|
|
|
|
|
/***
|
2024-07-09 13:46:48 +00:00
|
|
|
* Submit a new generation task to the executor
|
2024-07-03 08:27:53 +00:00
|
|
|
* @param tokens
|
|
|
|
* @param maxNewTokens
|
|
|
|
* @param topK
|
|
|
|
* @param topP
|
|
|
|
* @param temperature
|
|
|
|
* @param minLength
|
|
|
|
* @param repetitionPenalty
|
2024-07-08 22:08:49 +00:00
|
|
|
* @param frequencyPenalty
|
2024-07-03 08:27:53 +00:00
|
|
|
* @param seed
|
|
|
|
* @param nTopTokens
|
2024-07-09 13:46:48 +00:00
|
|
|
* @return Request id related to this generation for reference
|
2024-07-03 08:27:53 +00:00
|
|
|
*/
|
|
|
|
[[nodiscard]] tle::IdType Submit(
|
2024-07-08 22:08:49 +00:00
|
|
|
const std::vector<tle::TokenIdType> &tokens,
|
2024-07-03 08:27:53 +00:00
|
|
|
int32_t maxNewTokens,
|
2024-07-08 22:08:49 +00:00
|
|
|
int32_t topK,
|
2024-07-03 08:27:53 +00:00
|
|
|
float_t topP,
|
|
|
|
float_t temperature,
|
|
|
|
int32_t minLength,
|
|
|
|
std::optional<float_t> repetitionPenalty = std::nullopt,
|
2024-07-08 22:08:49 +00:00
|
|
|
std::optional<float_t> frequencyPenalty = std::nullopt,
|
2024-07-03 08:27:53 +00:00
|
|
|
std::optional<uint32_t> seed = std::nullopt,
|
|
|
|
std::optional<uint32_t> nTopTokens = std::nullopt
|
|
|
|
);
|
2024-07-08 22:08:49 +00:00
|
|
|
|
|
|
|
/***
|
2024-07-09 13:46:48 +00:00
|
|
|
* Unroll the token generation until end of stream is reached.
|
|
|
|
* Every generated token is streamed back through the provided callback for further processing
|
|
|
|
* @param reqId The request id to unroll
|
|
|
|
* @param cb The callback to stream token back
|
|
|
|
* @return Global number of generated tokens for this request id
|
2024-07-08 22:08:49 +00:00
|
|
|
*/
|
2024-07-09 13:46:48 +00:00
|
|
|
size_t Stream(tle::IdType reqId, const std::function<TokenStreamingCallback>& cb);
|
2024-06-30 21:37:20 +00:00
|
|
|
};
|
|
|
|
}
#endif //TGI_TRTLLM_BACKEND_H