mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-10-09 15:05:24 +00:00)

working setup of the ffi layer

This commit is contained in:
parent 5aede911f8
commit 50e9fc89c8
@@ -5,8 +5,10 @@
 #ifndef TGI_TRTLLM_BACKEND_H
 #define TGI_TRTLLM_BACKEND_H
 
+#include <cmath>
 #include <filesystem>
 #include <span>
+#include <vector>
 
 #include <spdlog/fmt/fmt.h>
 #include <nlohmann/json.hpp>
@@ -19,7 +21,8 @@ using json = nlohmann::json;
 namespace tle = tensorrt_llm::executor;
 
 namespace huggingface::tgi::backends {
+    using RequestId = tle::IdType;
+    using TokenId = tle::TokenIdType;
     using TokenStreamingCallback = void(tle::TokenIdType);
 
     /**
@@ -54,9 +57,7 @@ namespace huggingface::tgi::backends {
         * Indicate if the backend is ready to accept incoming request
         * @return true if ready, false otherwise
         */
-        [[nodiscard]] bool IsReady() const {
-            return executor.canEnqueueRequests();
-        }
+        [[nodiscard]] bool IsReady() const;
 
        /***
        * Submit a new generation task to the executor
@@ -65,26 +66,25 @@ namespace huggingface::tgi::backends {
        * @param topK
        * @param topP
        * @param temperature
-       * @param minLength
-       * @param repetitionPenalty
-       * @param frequencyPenalty
        * @param seed
-       * @param nTopTokens
        * @return Request id related to this generation for reference
        */
-        [[nodiscard]] tle::IdType Submit(
-                const std::vector<tle::TokenIdType> &tokens,
+        [[nodiscard]] RequestId Submit(
+                const std::vector<TokenId> &tokens,
                 int32_t maxNewTokens,
                 int32_t topK,
                 float_t topP,
                 float_t temperature,
-                int32_t minLength,
-                std::optional<float_t> repetitionPenalty = std::nullopt,
-                std::optional<float_t> frequencyPenalty = std::nullopt,
-                std::optional<uint32_t> seed = std::nullopt,
-                std::optional<uint32_t> nTopTokens = std::nullopt
+                uint64_t seed
        );
 
+       /***
+        *
+        * @param requestId The request id to poll the generation results
+        * @return
+        */
+       std::vector<tle::Response> Poll(RequestId requestId);
+
       /***
        * Unroll the token generation until end of stream is reached.
        * Every generated token is streamed back through the provided callback for further processing
@@ -92,7 +92,7 @@ namespace huggingface::tgi::backends {
        * @param cb The callback to stream token back
        * @return Global number of generated tokens for this request id
        */
-        size_t Stream(tle::IdType reqId, const std::function<TokenStreamingCallback> &cb);
+        uint32_t Stream(RequestId reqId, std::function<TokenStreamingCallback> &cb);
    };
 }
 
@@ -1,4 +1,5 @@
-#include <fmt/std.h>
+#include <fstream>
 
 #include <nvml.h>
 #include <spdlog/spdlog.h>
 
@@ -17,15 +18,17 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
     // Get the compute capabilities of the current hardware
     nvmlDevice_t device;
     int32_t cudaComputeCapabilitiesMajor = 0, cudaComputeCapabilitiesMinor = 0;
-    if(nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
+    if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
         SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
-        if(nvmlDeviceGetCudaComputeCapability(device, &cudaComputeCapabilitiesMajor, &cudaComputeCapabilitiesMinor) == NVML_SUCCESS) {
-            SPDLOG_INFO(FMT_STRING("Detected sm_{:d}{:d} compute capabilities"), cudaComputeCapabilitiesMajor, cudaComputeCapabilitiesMinor);
+        if (nvmlDeviceGetCudaComputeCapability(device, &cudaComputeCapabilitiesMajor, &cudaComputeCapabilitiesMinor) ==
+            NVML_SUCCESS) {
+            SPDLOG_INFO(FMT_STRING("Detected sm_{:d}{:d} compute capabilities"), cudaComputeCapabilitiesMajor,
+                        cudaComputeCapabilitiesMinor);
         }
     }
 
     // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
-    if(config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1){
+    if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
         SPDLOG_INFO("Detected single engine deployment, using leader mode");
         execConfig.setParallelConfig(tle::ParallelConfig(
                 tle::CommunicationType::kMPI,
@@ -54,15 +57,18 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
 huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         const std::filesystem::path &enginesFolder,
         const std::filesystem::path &executorWorker
-):
+) :
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
         executor(
                 enginesFolder,
                 tensorrt_llm::executor::ModelType::kDECODER_ONLY,
                 GetExecutorConfig(config, executorWorker.string()
-        ))
-{
-    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string&>());
+        )) {
+    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+}
+
+bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
+    return executor.canEnqueueRequests();
 }
 
 [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
@@ -72,11 +78,7 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
-        const int32_t minLength,
-        std::optional<float_t> repetitionPenalty,
-        std::optional<float_t> frequencyPenalty,
-        std::optional<uint32_t> seed,
-        std::optional<uint32_t> nTopTokens
+        const uint64_t seed
 ) {
     SPDLOG_DEBUG(
             FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
@@ -92,27 +94,23 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
             std::nullopt,
             std::nullopt,
             seed,
+            std::nullopt,
             temperature,
-            minLength,
             std::nullopt,
-            repetitionPenalty,
-            std::nullopt,
-            frequencyPenalty,
     };
-    const auto output = tle::OutputConfig{false, false, nTopTokens.value_or(1) > 1};
-    const auto request = tle::Request{tokens, maxNewTokens, true, sampling, output};
-
-    return executor.enqueueRequest(request);
+    const auto output = tle::OutputConfig{false, false, false};
+    return executor.enqueueRequest(tle::Request{tokens, maxNewTokens, true, sampling, output});
 }
 
-size_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId, const std::function<TokenStreamingCallback>& cb) {
+uint32_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId,
+                                                                std::function<TokenStreamingCallback> &cb) {
     bool isFinal = false;
     size_t generatedTokens = 0;
 
     do {
         const auto responses = executor.awaitResponses(reqId);
-        for (const auto &response: responses){
-            if(response.hasError()) {
+        for (const auto &response: responses) {
+            if (response.hasError()) {
                 SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
                 isFinal = true;
             } else {
@@ -128,8 +126,12 @@ size_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType
             }
         }
 
-    } while(!isFinal);
+    } while (!isFinal);
 
     // Return the number of generated tokens
     return generatedTokens;
 }
+
+std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
+    return executor.awaitResponses(requestId);
+}
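The `Stream` implementation above drives a blocking do/while loop around `executor.awaitResponses` and hands every token to a caller-supplied callback, while the Rust side (next file) returns an async `UnboundedReceiverStream` to the router. A common way to bridge the two is to run the blocking loop on a dedicated thread and forward each callback invocation through an unbounded channel, which is also what the still-commented-out `spawn_blocking` block in the Rust backend gestures at. The sketch below only illustrates that pattern; `fake_stream`, the token values, and the channel payload type are placeholders, not code from this commit.

```rust
// Hypothetical sketch: feeding an async stream from a blocking, callback-based
// token loop. `fake_stream` stands in for the FFI call into the C++ Stream().
use tokio::sync::mpsc;
use tokio::task::spawn_blocking;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

/// Stand-in for the C++ `Stream(reqId, cb)` call: invokes `cb` once per
/// "generated" token and returns how many tokens were produced.
fn fake_stream(mut cb: impl FnMut(u32, bool)) -> u32 {
    let tokens = [101u32, 202, 303];
    for (i, tok) in tokens.iter().enumerate() {
        cb(*tok, i == tokens.len() - 1);
    }
    tokens.len() as u32
}

#[tokio::main]
async fn main() {
    let (sender, receiver) = mpsc::unbounded_channel::<u32>();

    // The blocking loop runs on the blocking pool; every token is pushed
    // through the channel as soon as the callback fires.
    let producer = spawn_blocking(move || {
        fake_stream(|token, _is_final| {
            let _ = sender.send(token);
        })
    });

    // The async side consumes tokens as a stream, mirroring what the router
    // does with the UnboundedReceiverStream returned by schedule().
    let mut stream = UnboundedReceiverStream::new(receiver);
    while let Some(token) = stream.next().await {
        println!("received token id {token}");
    }

    let generated = producer.await.expect("blocking task panicked");
    println!("total generated tokens: {generated}");
}
```

The useful property here is that the producer side never needs to `await`: `UnboundedSender::send` is synchronous, so it can be called directly from inside a per-token callback.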
@@ -1,20 +1,24 @@
+use std::cell::RefCell;
 use std::path::Path;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc;
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{Chunk, ValidGenerateRequest};
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_trtllm_backend, TensorRtLlmBackend};
+use crate::ffi::{create_trtllm_backend, TensorRtLlmBackendImpl};
+
+struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
 
 pub struct TrtLLmBackend {
     tokenizer: Tokenizer,
-    inner: UniquePtr<TensorRtLlmBackend>,
+    inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>,
 }
 
 unsafe impl Sync for TrtLLmBackend {}
@@ -26,9 +30,12 @@ impl TrtLLmBackend {
         engine_folder: P,
     ) -> Result<Self, TensorRtLlmBackendError> {
         let engine_folder = engine_folder.as_ref();
-        let inner = create_trtllm_backend(engine_folder.to_str().unwrap());
+        let inner = create_trtllm_backend(engine_folder.to_str().unwrap(), "");
 
-        Ok(Self { tokenizer, inner })
+        Ok(Self {
+            tokenizer,
+            inner: RefCell::new(inner),
+        })
     }
 }
 
@@ -39,12 +46,91 @@ impl Backend for TrtLLmBackend {
         request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
         let (sender, receiver) = mpsc::unbounded_channel();
-        let request_id = self.inner.submit();
+        let ctx = Box::new(GenerationContext(sender));
 
+        // Unpack parameters
+        let params = request.parameters;
+
+        // Currently we handle single chunk of text
+        if request.inputs.len() == 1 {
+            match request
+                .inputs
+                .first()
+                .expect("Failed to access the first chunk")
+            {
+                Chunk::Text(text) => {
+                    let encoding = self
+                        .tokenizer
+                        .encode(&**text, true)
+                        .map_err(|e| InferError::ToolError(e.to_string()))?;
+
+                    let _start = Instant::now();
+                    let _request_id = self
+                        .inner
+                        .borrow_mut()
+                        .as_mut()
+                        .expect("Failed to retrieve pointer to TRTLLM backend")
+                        .submit(
+                            encoding.get_ids(),
+                            128,
+                            params.top_k as i32,
+                            params.top_p,
+                            params.temperature,
+                            params.seed,
+                        );
+
+                    // spawn_blocking(|| {
+                    //     // Stream generated tokens
+                    //     let num_generated_tokens = self
+                    //         .inner
+                    //         .borrow_mut()
+                    //         .as_mut()
+                    //         .expect("Failed to retrieve pointer to TRTLLM backend")
+                    //         .stream(request_id, ctx, |token, step, is_final| {
+                    //             // self.tokenizer.decode(&*[token], true).unwrap();
+                    //             let token = Token {
+                    //                 id: token,
+                    //                 text: String::from(""),
+                    //                 logprob: 1.0f32,
+                    //                 special: false,
+                    //             };
+                    //
+                    //             sender
+                    //                 .send(Ok(InferStreamResponse::Intermediate {
+                    //                     token,
+                    //                     top_tokens: vec![],
+                    //                 }))
+                    //                 .unwrap()
+                    //         });
+                    //
+                    //     // Notify the end
+                    //     Ok(InferStreamResponse::End {
+                    //         token: Token {
+                    //             id: 0,
+                    //             text: String::from(""),
+                    //             logprob: 1.0f32,
+                    //             special: false,
+                    //         },
+                    //         top_tokens: vec![],
+                    //         generated_text: GeneratedText {
+                    //             text: String::from(""),
+                    //             generated_tokens: num_generated_tokens,
+                    //             finish_reason: FinishReason::EndOfSequenceToken,
+                    //             seed: Some(params.seed),
+                    //         },
+                    //         start,
+                    //         queued: Instant::now(),
+                    //     })
+                    // });
+                }
+                Chunk::Image(_) => {}
+            }
+        };
+
         Ok(UnboundedReceiverStream::new(receiver))
     }
 
     async fn health(&self, _current_health: bool) -> bool {
-        self.inner.is_ready()
+        self.inner.borrow_mut().is_ready()
     }
 }
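The `inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>` field exists because the `Backend` trait only gives these methods `&self`, while the cxx bridge declares `submit` with a `Pin<&mut TensorRtLlmBackendImpl>` receiver; `borrow_mut()` followed by `UniquePtr::as_mut()` recovers that mutable, pinned handle at runtime. Note that `RefCell` itself is not `Sync`, so the pre-existing `unsafe impl Sync for TrtLLmBackend {}` is what asserts that sharing this across threads is sound. Below is a minimal stand-in (no cxx involved) showing the same interior-mutability shape; `FakeBackendHandle` and `FakeTrtLlmBackend` are illustrative names only, not part of the commit.

```rust
// Minimal sketch of the RefCell pattern used above: a &self method still needs
// to call a &mut method on the wrapped handle.
use std::cell::RefCell;

struct FakeBackendHandle {
    submitted: u64,
}

impl FakeBackendHandle {
    // Mirrors the shape of `submit`: requires exclusive access to mutate state.
    fn submit(&mut self, tokens: &[u32]) -> u64 {
        self.submitted += 1;
        println!("submitted {} tokens", tokens.len());
        self.submitted
    }

    fn is_ready(&self) -> bool {
        true
    }
}

struct FakeTrtLlmBackend {
    inner: RefCell<FakeBackendHandle>,
}

impl FakeTrtLlmBackend {
    // Trait-style method: only `&self` is available, yet a `&mut self` method
    // on the handle must be called, hence `borrow_mut()`.
    fn schedule(&self, tokens: &[u32]) -> u64 {
        self.inner.borrow_mut().submit(tokens)
    }

    fn health(&self) -> bool {
        self.inner.borrow_mut().is_ready()
    }
}

fn main() {
    let backend = FakeTrtLlmBackend {
        inner: RefCell::new(FakeBackendHandle { submitted: 0 }),
    };
    let request_id = backend.schedule(&[101, 202, 303]);
    println!("request id: {request_id}, healthy: {}", backend.health());
}
```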
@@ -1,23 +1,99 @@
 //
 // Created by mfuntowicz on 6/30/24.
 //
-#include <filesystem>
-#include "rust/cxx.h"
+#pragma once
+
+#include <cmath>
+#include <filesystem>
+#include <vector>
+
+#include "rust/cxx.h"
 #include "backends/trtllm/include/backend.h"
 
 namespace huggingface::tgi::backends {
+    class TensorRtLlmBackendImpl : TensorRtLlmBackend {
+    public:
+        /***
+         *
+         * @param engineFolder
+         * @param executorWorker
+         */
+        TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker) :
+                TensorRtLlmBackend(std::move(engineFolder), std::move(executorWorker)) {}
+
+        /***
+         *
+         * @return
+         */
+        bool IsReady() const { return TensorRtLlmBackend::IsReady(); }
+
+        /***
+         *
+         * @param tokens
+         * @param maxNewTokens
+         * @param topK
+         * @param topP
+         * @param temperature
+         * @param seed
+         * @return
+         */
+        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
+        RequestId Submit(rust::Slice<const uint32_t> tokens,
+                         int32_t maxNewTokens,
+                         int32_t topK,
+                         float_t topP,
+                         float_t temperature,
+                         uint64_t seed) {
+            // This will copy all the items from the initial slice
+            std::vector<int32_t> tokens_(tokens.size());
+            tokens_.assign(tokens.begin(), tokens.end());
+
+            return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
+        }
+
+        /***
+         *
+         * @param requestId
+         * @param handler
+         * @return
+         */
+        // uint32_t
+        // Stream(RequestId requestId, rust::Box <GenerationContext>, rust::Fn<void(uint32_t, uint32_t, bool)> handler) {
+        //     bool isDone = false;
+        //     uint32_t numGeneratedTokens = 0;
+        //
+        //     do {
+        //         const auto responses = Poll(requestId);
+        //         for (const auto &response: responses) {
+        //             if (response.hasError()) {
+        //                 isDone = true;
+        //                 // TODO : bubble up the error to rust
+        //             } else {
+        //                 const auto generation = response.getResult();
+        //                 const auto token = generation.outputTokenIds[0][0];
+        //                 isDone = generation.isFinal;
+        //
+        //                 // Propagate through the handler
+        //                 handler(token, numGeneratedTokens, isDone);
+        //             }
+        //         }
+        //     } while (!isDone);
+        //
+        //     return numGeneratedTokens;
+        // }
+    };
+
     /***
      *
      * @param engineFolder
      * @return
      */
-    std::unique_ptr<TensorRtLlmBackend> create_trtllm_backend(rust::Str engineFolder, rust::Str executorWorker) {
+    std::unique_ptr<TensorRtLlmBackendImpl> create_trtllm_backend(rust::Str engineFolder, rust::Str executorWorker) {
         // Unconditionally call this to initialize and discover TRTLLM plugins
         InitializeBackend();
 
         const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
         const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
-        return std::make_unique<TensorRtLlmBackend>(std::move(enginePath), std::move(executorPath));
+        return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
     }
 }
@@ -8,7 +8,8 @@ mod ffi {
     unsafe extern "C++" {
         include!("backends/trtllm/src/ffi.cpp");
 
-        type TensorRtLlmBackend;
+        /// Represent an instance of the underlying TensorRT-LLM backend
+        type TensorRtLlmBackendImpl;
 
         /// Create an instance backed behind an std::unique_ptr to manage the lifespan of the backend
         ///
@@ -24,12 +25,31 @@ mod ffi {
        /// ```
        ///
        /// ```
-        fn create_trtllm_backend(engine_folder: &str, executor_worker: &str) -> UniquePtr<TensorRtLlmBackend>;
+        fn create_trtllm_backend(
+            engine_folder: &str,
+            executor_worker: &str,
+        ) -> UniquePtr<TensorRtLlmBackendImpl>;
 
         #[rust_name = "is_ready"]
-        fn IsReady(&self) -> bool;
+        fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
 
         #[rust_name = "submit"]
-        fn Submit(&self) -> u64;
+        fn Submit(
+            self: Pin<&mut TensorRtLlmBackendImpl>,
+            tokens: &[u32],
+            max_new_tokens: i32,
+            top_k: i32,
+            top_p: f32,
+            temperature: f32,
+            seed: u64,
+        ) -> u64;
+
+        // #[rust_name = "stream"]
+        // fn Stream(
+        //     self: Pin<&mut TensorRtLlmBackendImpl>,
+        //     request_id: u64,
+        //     ctx: Box<GenerationContext>,
+        //     callback: fn(u32, u32, bool),
+        // ) -> u32;
     }
 }
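The bridge above fixes the calling convention seen in the Rust backend: `is_ready` binds to a shared `&TensorRtLlmBackendImpl`, while `submit` requires `Pin<&mut TensorRtLlmBackendImpl>`, which callers obtain from the `UniquePtr` returned by `create_trtllm_backend` through `as_mut()`. The short, self-contained sketch below pokes at those cxx types directly; `CxxString` is used as a stand-in only because `TensorRtLlmBackendImpl` exists solely once the C++ side of this commit is compiled in.

```rust
// Hedged illustration of the UniquePtr / Pin<&mut T> mechanics behind the
// bridge; CxxString ships with the cxx crate, so no hand-written C++ is needed.
use cxx::{CxxString, UniquePtr};

fn main() {
    // `create_trtllm_backend` hands back a UniquePtr. A UniquePtr can be null,
    // which is why the Rust backend goes through
    // `.as_mut().expect("Failed to retrieve pointer to TRTLLM backend")`
    // before calling `submit`.
    let mut handle: UniquePtr<CxxString> = UniquePtr::null();

    assert!(handle.is_null());
    // `as_mut` is the safe way to obtain the Pin<&mut T> receiver that pinned
    // methods (like the bridge's `submit`) require; on a null pointer it
    // returns None instead of panicking.
    assert!(handle.as_mut().is_none());

    println!("null UniquePtr handled without touching hand-written C++");
}
```

In the crate itself the same sequence appears as `self.inner.borrow_mut().as_mut().expect("Failed to retrieve pointer to TRTLLM backend").submit(...)` in the Rust backend above.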