mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
feat(backend): expose frequency and repetition penalties
This commit is contained in:
parent d4b5be10f9
commit 37faeb34b2
@@ -100,8 +100,15 @@ namespace huggingface::tgi::backends::llama {
         return std::make_unique<llama_sampler*>(sampler);
     }
 
-    std::vector<TgiLlamaCppBackend::TokenId> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
-            std::span<const TokenId> tokens, const uint32_t topK, const float_t topP, const uint32_t maxNewTokens) {
+    std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
+            std::span<const TokenId> tokens,
+            const uint32_t topK,
+            const float_t topP,
+            const float_t frequencyPenalty,
+            const float_t repetitionPenalty,
+            const uint32_t maxNewTokens,
+            const uint64_t seed
+    ) {
         SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
 
         // Allocate generation result
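The signature change is the heart of the commit: Generate now returns std::expected, so callers receive either the generated tokens or a TgiLlamaCppBackendError without exceptions. A minimal sketch of that pattern under assumed names (DECODE_FAILED is a hypothetical enumerator added here purely for illustration, not part of the diff):

    #include <cstdint>
    #include <expected>
    #include <vector>

    enum BackendError : uint8_t { MODEL_FILE_DOESNT_EXIST = 1, DECODE_FAILED = 2 };

    // Sketch only: the error path returns std::unexpected, the success value converts implicitly.
    std::expected<std::vector<int32_t>, BackendError> GenerateSketch(bool decodeOk) {
        if (!decodeOk)
            return std::unexpected(BackendError::DECODE_FAILED);  // hypothetical failure case
        return std::vector<int32_t>{42};                          // wrapped as the success value
    }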
@@ -110,7 +117,7 @@ namespace huggingface::tgi::backends::llama {
 
         // Retrieve decoding context
         auto batch = llama_batch_get_one(const_cast<int32_t *>(tokens.data()), static_cast<int32_t>(tokens.size()));
-        auto sampler = GetSamplerFromArgs(topK, topP, 1.0, 1.0, 2014);
+        auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
 
         // Decode
         for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
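GetSamplerFromArgs previously hard-coded both penalties to 1.0 and the seed to 2014; they now flow through from the caller. The sampler llama.cpp builds from these values is not shown in this diff, but the conventional semantics are: a frequency penalty subtracts from a token's logit in proportion to how often that token has already been generated, while a repetition penalty rescales the logit of any token seen at all. A self-contained sketch of that convention (an assumption about the underlying behavior, not the body of GetSamplerFromArgs):

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    // Sketch: typical logit adjustments behind frequency/repetition penalties.
    // counts maps each already-generated token id to its occurrence count.
    void ApplyPenalties(std::vector<float>& logits,
                        const std::unordered_map<int32_t, uint32_t>& counts,
                        float frequencyPenalty, float repetitionPenalty) {
        for (const auto& [token, count] : counts) {
            // Frequency penalty: grows with each prior occurrence of the token.
            logits[token] -= frequencyPenalty * static_cast<float>(count);
            // Repetition penalty: one-shot shrink toward zero for any seen token.
            if (repetitionPenalty != 0.0f) {
                logits[token] = logits[token] > 0.0f
                    ? logits[token] / repetitionPenalty
                    : logits[token] * repetitionPenalty;
            }
        }
    }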
@@ -13,7 +13,7 @@
 #define LLAMA_SUCCESS(x) x == 0
 
 namespace huggingface::tgi::backends::llama {
-    enum TgiLlamaCppBackendError {
+    enum TgiLlamaCppBackendError: uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
 
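Fixing the enum's underlying type to uint8_t gives the error code a guaranteed one-byte representation, which is what makes the static_cast<uint8_t> used when printing the error (see the main() hunk below) well defined for every enumerator. A standalone illustration:

    #include <cstdint>
    #include <cstdio>

    enum Err : uint8_t { MODEL_FILE_DOESNT_EXIST = 1 };

    int main() {
        static_assert(sizeof(Err) == 1);  // storage is pinned to one byte
        std::printf("%u\n", static_cast<unsigned>(MODEL_FILE_DOESNT_EXIST));  // prints 1
        return 0;
    }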
@@ -43,24 +43,33 @@ namespace huggingface::tgi::backends::llama {
          * @param text
          * @return
          */
-        [[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
+        [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
+        std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
 
         /**
          *
          * @param tokens
          * @param topK
          * @param topP
+         * @param frequencyPenalty
+         * @param repetitionPenalty
          * @param maxNewTokens
+         * @param seed
          * @return
          */
-        [[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Generate(
+        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
+        std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> Generate(
                 std::span<const TokenId> tokens,
                 uint32_t topK,
                 float_t topP = 1.0f,
-                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max()
+                float_t frequencyPenalty = 0.0f,
+                float_t repetitionPenalty = 0.0f,
+                uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1,
+                uint64_t seed = 2014
         );
     };
 
+    [[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
     std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
     CreateLlamaCppBackend(const std::filesystem::path& root);
 }
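The reason-string form of [[nodiscard]] (C++20) makes the compiler quote the message whenever a caller drops the result on the floor. A standalone sketch of what the attribute enforces (the function body here is a stand-in, not the real Tokenize):

    #include <cstdint>
    #include <string>
    #include <vector>

    [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
    std::vector<int32_t> Tokenize(const std::string&) { return {1, 2, 3}; }  // stand-in body

    int main() {
        Tokenize("hello");                   // warning: ignoring return value, message quoted
        const auto tokens = Tokenize("hi");  // fine: result is bound to an lvalue
        return static_cast<int>(tokens.size());
    }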
@@ -27,8 +27,15 @@ int main(int argc, char** argv) {
 
         // Generate
         const auto promptTokens = backend->Tokenize(prompt);
-        const auto out = backend->Generate(promptTokens, 30, 1.0, 32);
-        fmt::print(FMT_STRING("Generated: {}"), out);
+        const auto out = backend->Generate(promptTokens, 30, 1.0, 2.0, 0.0, 32);
+
+        if(out.has_value())
+            fmt::print(FMT_STRING("Generated: {}"), *out);
+        else {
+            const auto err = out.error();
+            fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
+        }
 
     } else {
         switch (maybeBackend.error()) {
             case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
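Since Generate now returns std::expected, the explicit has_value()/error() branch above is one option; std::expected also offers value_or when an empty result is an acceptable fallback. A sketch, assuming C++23 and standalone names:

    #include <cstdint>
    #include <cstdio>
    #include <expected>
    #include <vector>

    enum Err : uint8_t { MODEL_FILE_DOESNT_EXIST = 1 };

    int main() {
        std::expected<std::vector<int32_t>, Err> out = std::vector<int32_t>{7, 8, 9};
        // value_or collapses the error case into a default, here an empty token list.
        const auto tokens = out.value_or(std::vector<int32_t>{});
        std::printf("generated %zu tokens\n", tokens.size());
        return 0;
    }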