mirror of https://github.com/huggingface/text-generation-inference.git
feat(backend): expose frequency and repetition penalties
parent d4b5be10f9
commit 37faeb34b2
@@ -100,8 +100,15 @@ namespace huggingface::tgi::backends::llama {
         return std::make_unique<llama_sampler*>(sampler);
     }
 
-    std::vector<TgiLlamaCppBackend::TokenId> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
-            std::span<const TokenId> tokens, const uint32_t topK, const float_t topP, const uint32_t maxNewTokens) {
+    std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> huggingface::tgi::backends::llama::TgiLlamaCppBackend::Generate(
+            std::span<const TokenId> tokens,
+            const uint32_t topK,
+            const float_t topP,
+            const float_t frequencyPenalty,
+            const float_t repetitionPenalty,
+            const uint32_t maxNewTokens,
+            const uint64_t seed
+    ) {
         SPDLOG_DEBUG(FMT_STRING("Received {:d} tokens to schedule"), tokens.size());
 
         // Allocate generation result
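Widening the return type to std::expected<…, TgiLlamaCppBackendError> lets a failure inside Generate travel back to the caller as an error value instead of a silently truncated token vector. This diff does not add such an error path yet; the fragment below is only a sketch of the pattern it enables, where ctx stands in for the backend's decoding context and DECODE_FAILED is a hypothetical enumerator that does not exist in TgiLlamaCppBackendError today.

    // Sketch only (not part of this diff): surfacing a llama_decode() failure
    // through the std::expected channel. ctx and DECODE_FAILED are assumed names.
    if (llama_decode(ctx, batch) != 0) {
        return std::unexpected(TgiLlamaCppBackendError::DECODE_FAILED);
    }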
@@ -110,7 +117,7 @@ namespace huggingface::tgi::backends::llama {
 
         // Retrieve decoding context
         auto batch = llama_batch_get_one(const_cast<int32_t *>(tokens.data()), static_cast<int32_t>(tokens.size()));
-        auto sampler = GetSamplerFromArgs(topK, topP, 1.0, 1.0, 2014);
+        auto sampler = GetSamplerFromArgs(topK, topP, frequencyPenalty, repetitionPenalty, seed);
 
         // Decode
         for(auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
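GetSamplerFromArgs now receives the caller-supplied frequencyPenalty, repetitionPenalty and seed instead of the previously hard-coded 1.0 / 1.0 / 2014. For readers unfamiliar with the two penalties, the self-contained sketch below shows their usual definitions (OpenAI-style frequency penalty, CTRL-style repetition penalty). It is an illustration of the math only, not the backend's or llama.cpp's implementation, and ApplyPenalties is a made-up helper name.

    #include <cstdint>
    #include <span>
    #include <unordered_map>
    #include <vector>

    // Illustrative helper (hypothetical, not part of the backend): apply the two
    // penalties to raw logits before sampling.
    void ApplyPenalties(std::vector<float>& logits,
                        std::span<const int32_t> previousTokens,
                        const float frequencyPenalty,
                        const float repetitionPenalty) {
        // Count how often each token id already appeared in the context.
        std::unordered_map<int32_t, uint32_t> counts;
        for (const auto token : previousTokens) ++counts[token];

        for (const auto& [token, count] : counts) {
            auto& logit = logits[static_cast<size_t>(token)];

            // Frequency penalty: subtract proportionally to the occurrence count;
            // the new default of 0.0f leaves logits untouched.
            logit -= frequencyPenalty * static_cast<float>(count);

            // Repetition penalty (CTRL-style): shrink the logit of every token that
            // already occurred; 1.0f is neutral, values above 1.0f discourage repeats.
            if (repetitionPenalty != 0.0f)
                logit = logit > 0.0f ? logit / repetitionPenalty : logit * repetitionPenalty;
        }
    }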
@@ -13,7 +13,7 @@
 #define LLAMA_SUCCESS(x) x == 0
 
 namespace huggingface::tgi::backends::llama {
-    enum TgiLlamaCppBackendError {
+    enum TgiLlamaCppBackendError: uint8_t {
         MODEL_FILE_DOESNT_EXIST = 1
     };
 
@@ -43,24 +43,33 @@ namespace huggingface::tgi::backends::llama {
          * @param text
          * @return
          */
-        [[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
+        [[nodiscard("Tokens will be freed after this call if not assigned to an lvalue")]]
+        std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;
 
         /**
          *
          * @param tokens
          * @param topK
          * @param topP
+         * @param frequencyPenalty
+         * @param repetitionPenalty
          * @param maxNewTokens
+         * @param seed
          * @return
          */
-        [[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Generate(
+        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
+        std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> Generate(
             std::span<const TokenId> tokens,
             uint32_t topK,
             float_t topP = 1.0f,
-            uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max()
+            float_t frequencyPenalty = 0.0f,
+            float_t repetitionPenalty = 0.0f,
+            uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1,
+            uint64_t seed = 2014
         );
     };
 
+    [[nodiscard("Create backend will be freed after this call if not assigned to an lvalue")]]
     std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
     CreateLlamaCppBackend(const std::filesystem::path& root);
 }
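Since CreateLlamaCppBackend and the new Generate signature share the same error type, a caller can also compose them with std::expected's C++23 monadic interface instead of the explicit has_value() checks used in the main.cpp change below. A minimal sketch, assuming modelPath and prompt are already in scope and using an arbitrary topK of 40 (every other parameter falls back to the defaults declared above); this is an illustration, not code from the diff.

    using huggingface::tgi::backends::llama::CreateLlamaCppBackend;
    using huggingface::tgi::backends::llama::TgiLlamaCppBackend;
    using huggingface::tgi::backends::llama::TgiLlamaCppBackendError;

    // Any error from backend creation or generation short-circuits the chain.
    const std::expected<std::vector<TgiLlamaCppBackend::TokenId>, TgiLlamaCppBackendError> generated =
        CreateLlamaCppBackend(modelPath).and_then([&](auto backend) {
            const auto promptTokens = backend->Tokenize(prompt);
            return backend->Generate(promptTokens, 40);
        });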
@@ -27,8 +27,15 @@ int main(int argc, char** argv) {
 
         // Generate
         const auto promptTokens = backend->Tokenize(prompt);
-        const auto out = backend->Generate(promptTokens, 30, 1.0, 32);
-        fmt::print(FMT_STRING("Generated: {}"), out);
+        const auto out = backend->Generate(promptTokens, 30, 1.0, 2.0, 0.0, 32);
+
+        if(out.has_value())
+            fmt::print(FMT_STRING("Generated: {}"), *out);
+        else {
+            const auto err = out.error();
+            fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
+        }
+
     } else {
         switch (maybeBackend.error()) {
             case huggingface::tgi::backends::llama::TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST: