feat(backend): use std::ranges to map uint32_t to llama_token

This commit is contained in:
Morgan Funtowicz 2024-11-13 00:07:59 +01:00
parent 488ba93898
commit 363d5e45de

View File

@ -8,8 +8,8 @@
#include <exception> #include <exception>
#include <filesystem> #include <filesystem>
#include <memory> #include <memory>
#include <ranges>
#include <string_view> #include <string_view>
#include <variant>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
@ -56,9 +56,16 @@ namespace huggingface::tgi::backends::llamacpp {
}; };
// Ask the compiler to create a view over the Rust slice, transmuting from uint32_t* to llama_token* // Ask the compiler to create a view over the Rust slice, transmuting from uint32_t* to llama_token*
auto input_tokens_v = std::vector<llama_token>(input_tokens.size()); static auto as_llama_token = [](const uint32_t x){ return static_cast<llama_token>(x); };
std::memcpy(input_tokens_v.data(), input_tokens.data(), input_tokens.size());
#ifdef __cpp_lib_ranges_to_container
auto input_tokens_v = input_tokens | std::views::transform(as_llama_token) | std::ranges::to<std::vector>();
#else
auto input_tokens_ = input_tokens | std::views::transform(as_llama_token);
auto input_tokens_v = std::vector<llama_token>(input_tokens_.begin(), input_tokens_.end());
#endif
// Defer the generation to the actual worker_t
const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v}; const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] { if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
return *result; return *result;