Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-07-27 10:20:17 +00:00)
feat(backend): use std::ranges to map uint32_t to llama_token
This commit is contained in:
parent 488ba93898
commit 363d5e45de
@@ -8,8 +8,8 @@
 #include <exception>
 #include <filesystem>
 #include <memory>
+#include <ranges>
 #include <string_view>
 #include <variant>
 
 #include <spdlog/spdlog.h>
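The added #include <ranges> supplies std::views::transform and, on C++23 standard libraries, std::ranges::to; the hunk below selects between them with the __cpp_lib_ranges_to_container feature-test macro. As a standalone sketch (not part of the commit), this probe reports which branch a given toolchain would take:

// Hypothetical probe, not from the commit: prints which path the
// __cpp_lib_ranges_to_container #ifdef below would take on this toolchain.
#include <cstdio>
#include <version>

int main() {
#ifdef __cpp_lib_ranges_to_container
    std::printf("std::ranges::to available (%ldL)\n",
                static_cast<long>(__cpp_lib_ranges_to_container));
#else
    std::printf("std::ranges::to unavailable; iterator-pair fallback\n");
#endif
    return 0;
}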
@@ -56,9 +56,16 @@ namespace huggingface::tgi::backends::llamacpp {
         };
 
-        // Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
-        auto input_tokens_v = std::vector<llama_token>(input_tokens.size());
-        std::memcpy(input_tokens_v.data(), input_tokens.data(), input_tokens.size());
+        static auto as_llama_token = [](const uint32_t x){ return static_cast<llama_token>(x); };
+
+#ifdef __cpp_lib_ranges_to_container
+        auto input_tokens_v = input_tokens | std::views::transform(as_llama_token) | std::ranges::to<std::vector>();
+#else
+        auto input_tokens_ = input_tokens | std::views::transform(as_llama_token);
+        auto input_tokens_v = std::vector<llama_token>(input_tokens_.begin(), input_tokens_.end());
+#endif
 
         // Defer the generation to the actual worker_t
         const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
         if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
             return *result;
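For context, here is a minimal self-contained sketch of the same technique; the to_llama_tokens helper and the llama_token typedef are illustrative assumptions, not the backend's actual API:

#include <cstdint>
#include <ranges>
#include <span>
#include <vector>

using llama_token = std::int32_t;  // assumption: mirrors llama.cpp's token typedef

// Map a span of uint32_t token ids to llama_token, preferring
// std::ranges::to (C++23) and falling back to iterator-pair construction.
std::vector<llama_token> to_llama_tokens(std::span<const std::uint32_t> input_tokens) {
    static auto as_llama_token = [](const std::uint32_t x) { return static_cast<llama_token>(x); };
#ifdef __cpp_lib_ranges_to_container
    return input_tokens | std::views::transform(as_llama_token) | std::ranges::to<std::vector>();
#else
    auto mapped = input_tokens | std::views::transform(as_llama_token);
    return std::vector<llama_token>(mapped.begin(), mapped.end());
#endif
}

One incidental fix worth noting: the removed std::memcpy passed input_tokens.size() (an element count) as its byte count, so it seems to have copied only a quarter of the 4-byte-per-element token data, leaving the rest of the destination vector zero-initialized; the range pipeline involves no byte arithmetic at all.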