Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)
chore(trtllm): validate there are enough GPUs on the system for the desired model
This commit is contained in:
parent
848b8ad554
commit
a6ac2741a3
@@ -88,7 +88,16 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
         executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
                  GetExecutorConfig(config, executorWorker.string())) {
-    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
+
+    // Ensure we have enough GPUs on the system
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
+    if (numGpus < worldSize) {
+        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
+        // todo : raise exception to catch on rust side
+    }
+
     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
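For context, the check added above depends on two symbols that are not defined in this hunk: huggingface::hardware::cuda::GetNumDevices() and the FMT_NOT_ENOUGH_GPUS format string. The following is a minimal, self-contained sketch of how such a validation could be wired up, assuming GetNumDevices() wraps cudaGetDeviceCount() and returns a std::optional<size_t>, and using an illustrative wording for FMT_NOT_ENOUGH_GPUS; these are assumptions for illustration, not the backend's actual definitions.

// Sketch only: shows the GPU-count validation pattern added by this commit.
// GetNumDevices() and FMT_NOT_ENOUGH_GPUS are assumed definitions here, not
// necessarily the ones used by text-generation-inference.
#include <cstddef>
#include <optional>

#include <cuda_runtime_api.h>   // cudaGetDeviceCount
#include <fmt/format.h>         // FMT_STRING
#include <spdlog/spdlog.h>      // SPDLOG_CRITICAL

namespace huggingface::hardware::cuda {
    // Assumed helper: number of visible CUDA devices, or std::nullopt if the
    // CUDA runtime call fails (e.g. no driver available).
    inline std::optional<size_t> GetNumDevices() {
        int count = 0;
        if (cudaGetDeviceCount(&count) != cudaSuccess)
            return std::nullopt;
        return static_cast<size_t>(count);
    }
}

// Hypothetical wording; the argument order matches the call in the diff,
// SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize).
#define FMT_NOT_ENOUGH_GPUS \
    FMT_STRING("Only {:d} GPU(s) detected while the engine was built for world_size={:d}")

// Mirrors the constructor body: compare visible GPUs against the world size
// declared in the engine's config.json.
inline void ValidateGpuCount(size_t worldSize) {
    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
    if (numGpus < worldSize) {
        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
        // The commit leaves a TODO here: surface this as an exception that can
        // be caught on the Rust side of the backend instead of only logging.
    }
}

As in the commit, the failure path currently only logs at critical level; the TODO records the intent to later propagate an exception across the FFI boundary so the Rust side can handle the error.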