From 363d5e45de275b3c2739e2a4f9abad5cfa7e9baa Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz <funtowiczmo@gmail.com>
Date: Wed, 13 Nov 2024 00:07:59 +0100
Subject: [PATCH] feat(backend): use std::ranges to map uint32_t to llama_token

---
 backends/llamacpp/csrc/ffi.hpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index 70669b7c..948e96a0 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -8,8 +8,8 @@
 #include <exception>
 #include <filesystem>
 #include <memory>
+#include <ranges>
 #include <string_view>
-#include <cstring>
 #include <vector>

@@ -56,9 +56,16 @@ namespace huggingface::tgi::backends::llamacpp {
         };

         // Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
-        auto input_tokens_v = std::vector<llama_token>(input_tokens.size());
-        std::memcpy(input_tokens_v.data(), input_tokens.data(), input_tokens.size());
+        static auto as_llama_token = [](const uint32_t x){ return static_cast<llama_token>(x); };
+#ifdef __cpp_lib_ranges_to_container
+        auto input_tokens_v = input_tokens | std::views::transform(as_llama_token) | std::ranges::to<std::vector>();
+#else
+        auto input_tokens_ = input_tokens | std::views::transform(as_llama_token);
+        auto input_tokens_v = std::vector<llama_token>(input_tokens_.begin(), input_tokens_.end());
+#endif
+
+        // Defer the generation to the actual worker_t
         const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
         if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
             return *result;