feat(backend): expose frequency and repetition penalties

Morgan Funtowicz 2024-10-23 14:12:52 +02:00
parent d4b5be10f9
commit 37faeb34b2
3 changed files with 32 additions and 9 deletions


@@ -100,8 +100,15 @@ namespace huggingface::tgi::backends::llama {
return std::make_unique<llama_sampler*>(sampler);
}
std::vector<TgiLlamaCppBackend::TokenId> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
std::span<const TokenId> tokens, const uint32_t topK, const float_t topP, const uint32_t maxNewTokens) {
std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
std::span<const TokenId> tokens,
const uint32_t topK,
const float_t topP,
const float_t frequencyPenalty,
const float_t repetitionPenalty,
const uint32_t maxNewTokens,
const uint64_t seed
) {
SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
// Allocate generation result
@@ -110,7 +117,7 @@ namespace huggingface::tgi::backends::llama {
// Retrieve decoding context
auto batch = llama_batch_get_one(const_cast<int32_t *>(tokens.data()), static_cast<int32_t>(tokens.size()));
auto sampler = GetSamplerFromArgs(topK, topP, 1.0, 1.0, 2014);
auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
// Decode
for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
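As a side note, the two new knobs follow the usual llama.cpp-style penalty scheme: the repetition penalty rescales the logit of any token already present in the recent context, while the frequency penalty subtracts an amount proportional to how often that token appeared. The following is a minimal illustrative sketch of that adjustment only, not the backend's actual GetSamplerFromArgs implementation; the count map and logits vector are hypothetical.

#include <cstdint>
#include <unordered_map>
#include <vector>

// Illustrative only: apply a CTRL-style repetition penalty and an
// OpenAI-style frequency penalty to raw logits before sampling.
void ApplyPenalties(std::vector<float>& logits,
                    const std::unordered_map<int32_t, uint32_t>& recentCounts,
                    const float repetitionPenalty,
                    const float frequencyPenalty) {
    for (const auto& [tokenId, count] : recentCounts) {
        auto& logit = logits[tokenId];

        // Repetition penalty: shrink the logit of already-seen tokens
        // (divide positive logits, multiply negative ones).
        if (repetitionPenalty != 1.0f) {
            logit = logit > 0.0f ? logit / repetitionPenalty
                                 : logit * repetitionPenalty;
        }

        // Frequency penalty: subtract proportionally to the occurrence count.
        logit -= static_cast<float>(count) * frequencyPenalty;
    }
}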


@@ -13,7 +13,7 @@
#define LLAMA_SUCCESS(x) x == 0
namespace huggingface::tgi::backends::llama {
enum TgiLlamaCppBackendError {
enum TgiLlamaCppBackendError: uint8_t {
MODEL_FILE_DOESNT_EXIST = 1
};
@@ -43,24 +43,33 @@ namespace huggingface::tgi::backends::llama {
* @param text
* @return
*/
[[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
[[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
/**
*
* @param tokens
* @param topK
* @param topP
* @param frequencyPenalty
* @param repetitionPenalty
* @param maxNewTokens
* @param seed
* @return
*/
[[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Generate(
[[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> Generate(
std::span<const TokenId> tokens,
uint32_t topK,
float_t topP = 1.0f,
uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max()
float_t frequencyPenalty = 0.0f,
float_t repetitionPenalty = 0.0f,
uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1,
uint64_t seed = 2014
);
};
[[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
CreateLlamaCppBackend(const std::filesystem::path& root);
}
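With this header, both Generate and CreateLlamaCppBackend surface failures through std::expected rather than exceptions, and the [[nodiscard("...")]] reason string is echoed by the compiler if a caller drops the result. A rough, self-contained caller-side sketch of that contract follows; the FakeGenerate function and BackendError enum are placeholders, not the backend's API.

#include <cstdint>
#include <expected>
#include <filesystem>
#include <vector>

// Hypothetical stand-ins mirroring the shape of the declarations above.
enum class BackendError : uint8_t { MODEL_FILE_DOESNT_EXIST = 1 };

[[nodiscard("Tokens would be silently dropped if the result is ignored")]]
std::expected<std::vector<int32_t>, BackendError>
FakeGenerate(const std::filesystem::path& model) {
    if (!std::filesystem::exists(model))
        return std::unexpected(BackendError::MODEL_FILE_DOESNT_EXIST);
    return std::vector<int32_t>{1, 2, 3};  // pretend generation succeeded
}

int main() {
    // Ignoring the return value would trigger a compiler warning quoting the
    // nodiscard reason (C++20 attribute); std::expected itself is C++23.
    const auto out = FakeGenerate("model.gguf");
    return out.has_value() ? 0 : static_cast<int>(out.error());
}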


@@ -27,8 +27,15 @@ int main(int argc, char** argv) {
// Generate
const auto promptTokens = backend->Tokenize(prompt);
const auto out = backend->Generate(promptTokens, 30, 1.0, 32);
fmt::print(FMT_STRING("Generated: {}"), out);
const auto out = backend->Generate(promptTokens, 30, 1.0, 2.0, 0.0, 32);
if(out.has_value())
fmt::print(FMT_STRING("Generated: {}"), *out);
else {
const auto err = out.error();
fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
}
} else {
switch (maybeBackend.error()) {
case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
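Because TgiLlamaCppBackendError now has a fixed uint8_t underlying type, the error value can be cast and printed directly with fmt's integer presentation. A small standalone sketch of that reporting pattern, using a hypothetical DemoError enum rather than the backend's own type:

#include <cstdint>
#include <fmt/color.h>
#include <fmt/format.h>

enum class DemoError : uint8_t { MODEL_FILE_DOESNT_EXIST = 1 };

int main() {
    const auto err = DemoError::MODEL_FILE_DOESNT_EXIST;

    // Same reporting style as the snippet above: bold red text, with the
    // numeric code obtained through the fixed uint8_t underlying type.
    fmt::print(fmt::emphasis::bold | fg(fmt::color::red),
               "Got an error: {:d}\n", static_cast<uint8_t>(err));
    return 0;
}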