chore(trtllm): validate there are enough GPUs on the system for the desired model

Morgan Funtowicz 2024-10-21 23:40:38 +02:00
parent 848b8ad554
commit a6ac2741a3

@@ -88,7 +88,16 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
         executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
                  GetExecutorConfig(config, executorWorker.string())) {
-    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
+
+    // Ensure we have enough GPUs on the system
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
+    if (numGpus < worldSize) {
+        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
+        // todo : raise exception to catch on rust side
+    }
+
     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
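
For reference, huggingface::hardware::cuda::GetNumDevices() is not shown in this diff; the value_or(0) at the call site suggests it returns std::optional<size_t>. Below is a minimal sketch of such a helper, assuming it wraps cudaGetDeviceCount from the CUDA runtime API — the body is an assumption for illustration, not the repository's actual implementation.

#include <cstddef>
#include <optional>
#include <cuda_runtime_api.h>

namespace huggingface::hardware::cuda {
    // Sketch (assumed implementation): report the number of CUDA-capable
    // devices, or std::nullopt if the runtime cannot be queried at all
    // (e.g. no driver installed).
    std::optional<std::size_t> GetNumDevices() {
        int numDevices = 0;
        if (cudaGetDeviceCount(&numDevices) != cudaSuccess)
            return std::nullopt;
        return static_cast<std::size_t>(numDevices);
    }
}

Returning std::nullopt on any runtime error lets the call site's value_or(0) treat an unprobeable system as having zero usable GPUs, which conservatively trips the world_size check above.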