diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index f283b2ac..1f6dcfae 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -100,8 +100,15 @@ namespace huggingface::tgi::backends::llama {
         return std::make_unique(sampler);
     }
 
-    std::vector<llama_token> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
-            std::span<const llama_token> tokens, const uint32_t topK, const float_t topP, const uint32_t maxNewTokens) {
+    std::expected<std::vector<llama_token>, TgiLlamaCppBackendError> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
+            std::span<const llama_token> tokens,
+            const uint32_t topK,
+            const float_t topP,
+            const float_t frequencyPenalty,
+            const float_t repetitionPenalty,
+            const uint32_t maxNewTokens,
+            const uint64_t seed
+    ) {
         SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
 
         // Allocate generation result
@@ -110,7 +117,7 @@ namespace huggingface::tgi::backends::llama {
 
         // Retrieve decoding context
         auto batch = llama_batch_get_one(const_cast<llama_token*>(tokens.data()), static_cast<int32_t>(tokens.size()));
-        auto sampler = GetSamplerFromArgs(topK, topP, 1.0, 1.0, 2014);
+        auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
 
         // Decode
         for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index 26d690c8..5f356bc0 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -13,7 +13,7 @@
 #define LLAMA_SUCCESS(x) x == 0
 
 namespace huggingface::tgi::backends::llama {
-    enum TgiLlamaCppBackendError {
+    enum TgiLlamaCppBackendError: uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
 
@@ -43,24 +43,33 @@ namespace huggingface::tgi::backends::llama {
          * @param text
          * @return
          */
-        [[nodiscard]] std::vector<llama_token> Tokenize(const std::string& text) const;
+        [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
+        std::vector<llama_token> Tokenize(const std::string& text) const;
 
         /**
          *
          * @param tokens
          * @param topK
          * @param topP
+         * @param frequencyPenalty
+         * @param repetitionPenalty
          * @param maxNewTokens
+         * @param seed
          * @return
          */
-        [[nodiscard]] std::vector<llama_token> Generate(
+        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
+        std::expected<std::vector<llama_token>, TgiLlamaCppBackendError> Generate(
                 std::span<const llama_token> tokens,
                 uint32_t topK,
                 float_t topP = 1.0f,
-                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max()
+                float_t frequencyPenalty = 0.0f,
+                float_t repetitionPenalty = 0.0f,
+                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1,
+                uint64_t seed = 2014
         );
     };
 
+    [[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
     std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
     CreateLlamaCppBackend(const std::filesystem::path& root);
 }
diff --git a/backends/llamacpp/offline/main.cpp b/backends/llamacpp/offline/main.cpp
index 3165261f..c2ae05c7 100644
--- a/backends/llamacpp/offline/main.cpp
+++ b/backends/llamacpp/offline/main.cpp
@@ -27,8 +27,15 @@ int main(int argc, char** argv) {
 
         // Generate
         const auto promptTokens = backend->Tokenize(prompt);
-        const auto out = backend->Generate(promptTokens, 30, 1.0, 32);
-        fmt::print(FMT_STRING("Generated: {}"), out);
+        const auto out = backend->Generate(promptTokens, 30, 1.0, 2.0, 0.0, 32);
+
+        if(out.has_value())
+            fmt::print(FMT_STRING("Generated: {}"), *out);
+        else {
+            const auto err = out.error();
+            fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
+        }
+
     } else {
         switch (maybeBackend.error()) {
             case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST: