misc(offline): update offline tester

This commit is contained in:
Morgan Funtowicz 2024-10-30 22:40:49 +01:00
parent b98c635781
commit 6a5f6b0755
2 changed files with 12 additions and 20 deletions

View File

@ -55,7 +55,7 @@ if (${LLAMA_CPP_BUILD_OFFLINE_RUNNER})
message(STATUS "Building llama.cpp offline runner")
add_executable(tgi_llamacpp_offline_runner offline/main.cpp)
target_link_libraries(tgi_llamacpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common)
target_link_libraries(tgi_llamacpp_offline_runner PUBLIC tgi_llamacpp_backend_impl llama common spdlog::spdlog fmt::fmt)
endif ()

View File

@ -22,27 +22,19 @@ int main(int argc, char **argv) {
const auto prompt = "My name is Morgan";
const auto modelPath = absolute(std::filesystem::path(argv[1]));
if (auto maybeBackend = TgiLlamaCppBackend::FromGGUF(modelPath); maybeBackend.has_value()) {
// Retrieve the backend
auto [model, context] = *maybeBackend;
auto backend = TgiLlamaCppBackend(model, context);
const auto params = llama_model_default_params();
auto *model = llama_load_model_from_file(modelPath.c_str(), params);
// Generate
const auto promptTokens = backend.Tokenize(prompt);
const auto out = backend.Generate(promptTokens, 30, 1.0, 2.0, 0.0, 32);
auto backend = single_worker_backend_t(model, {});
if (out.has_value())
fmt::print(FMT_STRING("Generated: {}"), *out);
else {
const auto err = out.error();
fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
}
// generate
const auto promptTokens = {128000, 9906, 856, 836, 374, 23809, 128001};
const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});
} else {
switch (maybeBackend.error()) {
case TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Specified file {} doesnt exist", modelPath);
return maybeBackend.error();
}
if (out.has_value())
fmt::print(FMT_STRING("Generated: {}"), *out);
else {
const auto err = out.error();
fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
}
}