// text-generation-inference/backends/trtllm/include/backend.h

//
// Created by Morgan Funtowicz on 6/30/24.
//
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H
#include <cmath>
#include <filesystem>
#include <span>
#include <vector>
#include <nlohmann/json.hpp>
#include <tensorrt_llm/runtime/common.h>
#include <tensorrt_llm/executor/executor.h>
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
using json = nlohmann::json;
namespace tle = tensorrt_llm::executor;
namespace huggingface::tgi::backends {
// Aliases over TensorRT-LLM executor types used throughout the backend API.
using RequestId = tle::IdType;
using TokenId = tle::TokenIdType;
// Function-type alias for a per-token callback; presumably invoked for each
// token streamed out of the executor — not used in this header, confirm in backend.cpp.
using TokenStreamingCallback = void(tle::TokenIdType);
/**
 * Initialize all the components required by TRTLLM.
 * It is required to call this function before attempting to load any engine.
 * NOTE(review): given the tllmPlugin.h include above, this presumably registers
 * the TensorRT-LLM plugins — confirm against the definition in backend.cpp.
 */
void InitializeBackend();
/**
 * Build the executor configuration from the engine's JSON configuration.
 * @param config TensorRT-LLM configuration object
 * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
 * @return Executor configuration ready to be used to instantiate a tle::Executor
 */
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
/**
*
*/
class TensorRtLlmBackend {
private:
const json config;
2024-07-03 08:27:53 +00:00
tle::Executor executor;
public:
explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder,
const std::filesystem::path &executorWorker
);
2024-07-03 08:27:53 +00:00
/***
* Indicate if the backend is ready to accept incoming request
* @return true if ready, false otherwise
*/
2024-07-11 21:24:32 +00:00
[[nodiscard]] bool IsReady() const;
2024-07-03 08:27:53 +00:00
2024-07-16 20:11:59 +00:00
/***
* Query the executor for the number of token available for pulling
* @return
*/
[[nodiscard]] size_t NumResponsesReady() const;
2024-07-03 08:27:53 +00:00
/***
* Submit a new generation task to the executor
2024-07-03 08:27:53 +00:00
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature
* @param seed
* @return Request id related to this generation for reference
2024-07-03 08:27:53 +00:00
*/
2024-07-11 21:24:32 +00:00
[[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens,
2024-07-08 22:08:49 +00:00
int32_t topK,
2024-07-03 08:27:53 +00:00
float_t topP,
float_t temperature,
2024-07-11 21:24:32 +00:00
uint64_t seed
2024-07-03 08:27:53 +00:00
);
2024-07-08 22:08:49 +00:00
2024-07-11 21:24:32 +00:00
/***
*
* @param requestId The request id to poll the generation results
* @return
*/
std::vector<tle::Response> Poll(RequestId requestId);
2024-07-15 07:36:01 +00:00
/***
* Stop the underlying executor
*/
void Shutdown();
};
}
#endif //TGI_TRTLLM_BACKEND_H