Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)
feat(backend): create llama_context_params with default factory
Parent: b1ebc8f73b
Commit: dc6435e3a5
@@ -43,6 +43,15 @@ namespace huggingface::tgi::backends::llamacpp {
         return std::shared_ptr<llama_model>(model, llama_model_deleter);
     };
 
+    auto get_llama_context_params = [](size_t num_threads) {
+        auto params = llama_context_default_params();
+        params.n_threads = num_threads;
+        params.n_threads_batch = num_threads;
+        params.flash_attn = true;
+        params.no_perf = false;
+        return params;
+    };
+
     /**
      * llama.cpp backend specific exception mapped from `backend_exception_t` to throw at the FFI level and
      * allow automatic implementation of Result<_, Exception> from C++ to Rust
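For readers following along outside the diff, the same factory can be exercised standalone. A minimal sketch, assuming only the llama.cpp C API already used in the hunk above (`llama_context_default_params()` and the `n_threads`, `n_threads_batch`, `flash_attn`, `no_perf` fields of `llama_context_params`); names here are illustrative, not part of the commit:

    #include <cstddef>
    #include <cstdio>
    #include <llama.h>

    // Same idea as the lambda added above: start from the library defaults and
    // override only the fields this backend cares about.
    static llama_context_params get_llama_context_params(size_t num_threads) {
        auto params = llama_context_default_params();
        params.n_threads = static_cast<int32_t>(num_threads);        // generation threads
        params.n_threads_batch = static_cast<int32_t>(num_threads);  // prompt-processing threads
        params.flash_attn = true;   // enable flash attention kernels
        params.no_perf = false;     // keep performance counters enabled
        return params;
    }

    int main() {
        const auto params = get_llama_context_params(8);
        std::printf("n_threads=%d n_threads_batch=%d flash_attn=%d\n",
                    params.n_threads, params.n_threads_batch, params.flash_attn);
        return 0;
    }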
@@ -64,7 +73,7 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param num_threads The number of threads the worker is allowed to spawn accross for its threadpool
          */
         explicit llama_cpp_worker_frontend_t(llama_model *model, int32_t num_threads):
-            model_{ make_shared_llama_model(model) }, worker_(model_, {.n_ubatch = 1, .n_threads = num_threads, .no_perf = true}) {}
+            model_{ make_shared_llama_model(model) }, worker_(model_, get_llama_context_params(num_threads)) {}
 
         /**
          * Generate a new set of tokens from the provided `input_tokens`, streaming each individual token generated
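The constructor change above is more than cosmetic: the old designated initializer `{.n_ubatch = 1, .n_threads = num_threads, .no_perf = true}` zero-fills every field it does not name, while the factory inherits the library defaults for everything it does not override. A quick comparison sketch, assuming only `llama_context_default_params()` and the `n_ctx`/`n_batch` fields from the llama.cpp C API:

    #include <cstdio>
    #include <llama.h>

    int main() {
        llama_context_params zeroed{};                                  // unnamed fields become 0/false
        llama_context_params defaults = llama_context_default_params(); // unnamed fields keep library defaults

        std::printf("zeroed:   n_ctx=%u n_batch=%u\n", zeroed.n_ctx, zeroed.n_batch);
        std::printf("defaults: n_ctx=%u n_batch=%u\n", defaults.n_ctx, defaults.n_batch);
        return 0;
    }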
@@ -27,24 +27,37 @@ int main(int argc, char **argv) {
            llama_model_deleter
    );
 
-   auto prompt = "My name is Morgan";
-   auto tokens = std::vector<llama_token>(16);
-   const auto nb_tokens = llama_tokenize(model.get(), prompt, sizeof(prompt), tokens.data(), tokens.size(), true,
+   auto prompt = std::string("My name is Morgan");
+   auto tokens = std::vector<llama_token>(128);
+   const auto nb_tokens = llama_tokenize(model.get(), prompt.c_str(), prompt.size(), tokens.data(), tokens.size(),
+                                         true,
                                          false);
    tokens.resize(nb_tokens);
-   auto backend = worker_t(std::move(model), {.n_batch = 1, .n_threads = 4});
+   llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISTRIBUTE);
+   auto backend = worker_t(model, llama_context_default_params());
 
    fmt::println("Tokenized: {}", tokens);
 
    // generate
    auto generated_tokens = std::vector<llama_token>(32);
    const auto n_generated_tokens = backend.generate(
-           {{.max_new_tokens = 32}, {.top_k = 40}, tokens},
+           {{.max_new_tokens = 32}, {.top_k = 40, .top_p = 0.95, .temperature = 0.8},
+            tokens},
            [&generated_tokens](llama_token new_token_id, float_t logit, bool is_eos, size_t step) -> bool {
                generated_tokens.emplace(generated_tokens.begin() + (step - 1), new_token_id);
                return false;
            }
    );
    generated_tokens.resize(n_generated_tokens.value());
-   fmt::println("Generated {} tokens", generated_tokens);
+
+   std::string decoded = std::string(256, 'a');
+   const size_t length = llama_detokenize(model.get(),
+                                          generated_tokens.data(),
+                                          generated_tokens.size(),
+                                          decoded.data(),
+                                          decoded.size(),
+                                          false, false);
+   decoded.resize(std::min(length, decoded.size()));
+   fmt::println("Generated tokens: {}", generated_tokens);
+   fmt::println("Generated text: {}", decoded);
 }
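The detokenization added to the example follows a common buffer pattern: allocate an oversized `std::string`, let `llama_detokenize` write into it, and shrink it to the returned length. A minimal helper in the same spirit (a sketch; it assumes the seven-argument `llama_detokenize(model, tokens, n_tokens, text, text_len_max, remove_special, unparse_special)` overload used in the hunk above, and the helper name is illustrative):

    #include <algorithm>
    #include <string>
    #include <vector>
    #include <llama.h>

    // Decode a token sequence into text using the llama_detokenize overload
    // taking a llama_model*, as in the example above.
    static std::string detokenize_to_string(const llama_model *model,
                                            const std::vector<llama_token> &tokens,
                                            size_t max_bytes = 256) {
        std::string decoded(max_bytes, '\0');
        const auto length = llama_detokenize(model,
                                             tokens.data(),
                                             static_cast<int32_t>(tokens.size()),
                                             decoded.data(),
                                             static_cast<int32_t>(decoded.size()),
                                             /*remove_special=*/false,
                                             /*unparse_special=*/false);
        // llama_detokenize reports a negative value when the buffer is too small;
        // clamp to zero to keep the sketch simple.
        decoded.resize(std::min<size_t>(std::max<int32_t>(length, 0), decoded.size()));
        return decoded;
    }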