// text-generation-inference/backends/trtllm/src/ffi.cpp

//
// Created by mfuntowicz on 6/30/24.
//
// NOTE(review): `#pragma once` is an include-guard directive for headers; it has
// no effect in a .cpp implementation file and should be removed.
#pragma once
#include <cmath>
#include <filesystem>
#include <vector>
#include "backends/trtllm/include/ffi.h"
/// Construct the FFI-facing wrapper by forwarding both paths directly to the
/// underlying TensorRtLlmBackend base class.
/// @param engineFolder   Path to the folder holding the compiled TRT-LLM engine.
/// @param executorWorker Path to the executor worker binary/library.
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
const std::string_view &engineFolder,
const std::string_view &executorWorker
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
/// Report whether the underlying TensorRT-LLM backend is ready to accept
/// new generation requests; pure delegation, no additional state is consulted.
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
    const auto backendReady = TensorRtLlmBackend::IsReady();
    return backendReady;
}
/// Submit a generation request to the underlying executor.
/// @param tokens       Prompt token ids coming from the Rust side (borrowed slice).
/// @param maxNewTokens Maximum number of tokens to generate.
/// @param topK         Top-k sampling parameter.
/// @param topP         Top-p (nucleus) sampling parameter.
/// @param temperature  Sampling temperature.
/// @param seed         RNG seed for reproducible sampling.
/// @return The request id assigned by the backend, used to poll for results.
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens,
int32_t maxNewTokens, int32_t topK, float_t topP,
float_t temperature, uint64_t seed) {
// Copy the items from the borrowed Rust slice in a single pass. Constructing
// from the iterator pair avoids the previous version's wasteful pattern of
// value-initializing tokens.size() elements and then overwriting them with
// assign(). Note this also widens uint32_t -> int32_t, as the backend expects.
std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
}
/// Poll the executor for tokens generated for `requestId`, forwarding each one
/// to `handler`, until the request finishes or an error response is seen.
/// @param ctx       Opaque Rust-side generation context, handed to the handler.
/// @param requestId Id previously returned by Submit().
/// @param handler   Rust callback invoked as (ctx, token, position, isDone).
/// @return The number of tokens that were forwarded to the handler.
uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
uint64_t requestId,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
bool isDone = false;
uint32_t numGeneratedTokens = 0;
do {
const auto responses = Poll(requestId);
for (const auto &response: responses) {
if (response.hasError()) {
isDone = true;
// TODO : bubble up the error to rust
} else {
const auto generation = response.getResult();
const auto token = generation.outputTokenIds[0][0];
isDone = generation.isFinal;
// Propagate through the handler; the position argument is the
// 0-based index of this token within the stream.
handler(std::move(ctx), token, numGeneratedTokens, isDone);
// Bug fix: this counter was never incremented before, so the
// handler always saw position 0 and the function always returned 0.
++numGeneratedTokens;
// NOTE(review): ctx is moved into the handler on EVERY iteration;
// after the first call it is moved-from, which is only safe if the
// handler contract returns/does not retain ownership of the Box.
// Confirm against the Rust side before relying on multi-token streams.
}
}
} while (!isDone);

return numGeneratedTokens;
}
/// Factory exposed to Rust: initialize the TRT-LLM runtime and build a backend.
/// @param engineFolder   UTF-8 path (borrowed from Rust) to the engine folder.
/// @param executorWorker UTF-8 path (borrowed from Rust) to the executor worker.
/// @return Owning pointer to the newly constructed backend implementation.
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
// Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend();

// std::move on a const local is a silent no-op (it still copies), so the
// previous moves were dead weight; string_view is a trivially copyable,
// non-owning view and is simply passed by value. The views borrow the Rust
// strings, which outlive this call.
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
return std::make_unique<TensorRtLlmBackendImpl>(enginePath, executorPath);
}