From a6ac2741a3c994653781f35d368d1a2aaedf99a5 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 21 Oct 2024 23:40:38 +0200
Subject: [PATCH] chore(trtllm): validate there are enough GPUs on the system
 for the desired model

---
 backends/trtllm/lib/backend.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 8a4ac4c04..a3ac05343 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -88,7 +88,16 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
         executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
                  GetExecutorConfig(config, executorWorker.string())) {
-    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+
+    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string>());
+
+    // Ensure we have enough GPUs on the system
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
+    if (numGpus < worldSize) {
+        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
+        // TODO: raise an exception to catch on the Rust side
+    }
 
     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
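
--
Note: the check relies on huggingface::hardware::cuda::GetNumDevices() returning
an optional device count that callers can value_or(0). A minimal sketch of what
such a helper could look like, assuming the CUDA runtime API is available; the
real helper lives elsewhere in the backend and may use a different mechanism:

    #include <cstdint>
    #include <optional>
    #include <cuda_runtime_api.h>

    namespace huggingface::hardware::cuda {
        // Hypothetical implementation: ask the CUDA runtime for the number of
        // visible devices; report "no value" when no driver or device responds,
        // so the caller's value_or(0) treats that as zero GPUs.
        std::optional<uint32_t> GetNumDevices() {
            int count = 0;
            if (cudaGetDeviceCount(&count) != cudaSuccess) {
                return std::nullopt;
            }
            return static_cast<uint32_t>(count);
        }
    }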
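On the remaining TODO: one way to surface the failure to the Rust caller is to
throw from the constructor instead of only logging. Assuming the backend crosses
the FFI boundary through the cxx bridge, a C++ exception thrown here becomes a
Result<_, cxx::Exception> on the Rust side. A sketch of that variant, not the
committed behavior (requires <stdexcept>; fmt::format comes with spdlog's
bundled fmt):

    if (numGpus < worldSize) {
        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
        // Throwing aborts construction; the bridging layer is assumed to
        // translate the exception into a Rust error for the caller to handle.
        throw std::runtime_error(
                fmt::format("Not enough GPUs to run this model: {} available, {} required",
                            numGpus, worldSize));
    }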