Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-07 09:52:18 +00:00)
misc(offline): update offline tester
parent b98c635781
commit 6a5f6b0755
CMakeLists.txt
@@ -55,7 +55,7 @@ if (${LLAMA_CPP_BUILD_OFFLINE_RUNNER})
     message(STATUS "Building llama.cpp offline runner")
     add_executable(tgi_llamacpp_offline_runner offline/main.cpp)
 
-    target_link_libraries(tgi_llamacpp_offline_runner PUBLIC tgi_llama_cpp_backend_impl llama common)
+    target_link_libraries(tgi_llamacpp_offline_runner PUBLIC tgi_llamacpp_backend_impl llama common spdlog::spdlog fmt::fmt)
 endif ()
 
 
offline/main.cpp
@@ -22,27 +22,19 @@ int main(int argc, char **argv) {
     const auto prompt = "My name is Morgan";
 
     const auto modelPath = absolute(std::filesystem::path(argv[1]));
-    if (auto maybeBackend = TgiLlamaCppBackend::FromGGUF(modelPath); maybeBackend.has_value()) {
-        // Retrieve the backend
-        auto [model, context] = *maybeBackend;
-        auto backend = TgiLlamaCppBackend(model, context);
-
-        // Generate
-        const auto promptTokens = backend.Tokenize(prompt);
-        const auto out = backend.Generate(promptTokens, 30, 1.0, 2.0, 0.0, 32);
-
-        if (out.has_value())
-            fmt::print(FMT_STRING("Generated: {}"), *out);
-        else {
-            const auto err = out.error();
-            fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
-        }
-
-    } else {
-        switch (maybeBackend.error()) {
-            case TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST:
-                fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Specified file {} doesnt exist", modelPath);
-                return maybeBackend.error();
-        }
-    }
-}
+    const auto params = llama_model_default_params();
+    auto *model = llama_load_model_from_file(modelPath.c_str(), params);
+
+    auto backend = single_worker_backend_t(model, {});
+
+    // generate
+    const auto promptTokens = {128000, 9906, 856, 836, 374, 23809, 128001};
+    const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});
+
+    if (out.has_value())
+        fmt::print(FMT_STRING("Generated: {}"), *out);
+    else {
+        const auto err = out.error();
+        fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
+    }
+}
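For reference, here is a rough, self-contained sketch of what offline/main.cpp could look like after this commit. Only the statements that appear in the hunk above are taken from the commit; the include list, the backend.hpp header name, and the argument check are assumptions added so the sketch stands on its own.

// Hedged reconstruction of the updated offline tester. Header names and the
// argc guard are assumptions; the backend calls are copied from the hunk above.
#include <filesystem>

#include <fmt/format.h>
#include <fmt/color.h>
#include <fmt/ranges.h>   // in case the generation result is a token container
#include <llama.h>

#include "backend.hpp"    // assumed header exposing single_worker_backend_t

int main(int argc, char **argv) {
    if (argc < 2) {
        fmt::print("Usage: {} <path/to/model.gguf>\n", argv[0]);
        return 1;
    }

    const auto prompt = "My name is Morgan";  // now unused: the prompt ids below are hard-coded

    const auto modelPath = absolute(std::filesystem::path(argv[1]));
    const auto params = llama_model_default_params();
    auto *model = llama_load_model_from_file(modelPath.c_str(), params);

    // Wrap the raw llama.cpp model in the single-worker backend
    auto backend = single_worker_backend_t(model, {});

    // Generate from a fixed list of prompt token ids (the commit drops runtime tokenization)
    const auto promptTokens = {128000, 9906, 856, 836, 374, 23809, 128001};
    const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});

    if (out.has_value())
        fmt::print(FMT_STRING("Generated: {}"), *out);
    else {
        const auto err = out.error();
        fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
    }
}

Compared to the previous version, the tester no longer goes through the expected-returning TgiLlamaCppBackend::FromGGUF factory: it loads the model directly with llama_load_model_from_file, hands it to single_worker_backend_t, and passes generation and sampling options as designated-initializer structs. The CMake hunk matches this by linking the renamed tgi_llamacpp_backend_impl target plus spdlog::spdlog and fmt::fmt. To try it, configure the backend with LLAMA_CPP_BUILD_OFFLINE_RUNNER enabled and run the resulting tgi_llamacpp_offline_runner binary with a GGUF model path as its only argument.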