Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-21 14:52:20 +00:00
chore(trtllm): validate there are enough GPUs on the system for the desired model
This commit is contained in: parent 848b8ad554, commit a6ac2741a3
@@ -88,7 +88,16 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
     config(json::parse(std::ifstream(enginesFolder / "config.json"))),
     executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
              GetExecutorConfig(config, executorWorker.string())) {
-    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
+
+    // Ensure we have enough GPUs on the system
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
+    if (numGpus < worldSize) {
+        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
+        // todo : raise exception to catch on rust side
+    }
+
 
     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
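For context, below is a minimal, self-contained sketch of the validation pattern this commit introduces. It is an illustration, not the backend's actual code: it assumes huggingface::hardware::cuda::GetNumDevices() is a thin wrapper around cudaGetDeviceCount() returning std::optional<uint32_t>, hard-codes a world size instead of reading /pretrained_config/mapping/world_size from config.json, and uses plain stderr output in place of SPDLOG_CRITICAL with the FMT_NOT_ENOUGH_GPUS format string.

#include <cstdint>
#include <cstdio>
#include <optional>
#include <cuda_runtime_api.h>

namespace huggingface::hardware::cuda {
    // Hypothetical helper: number of visible CUDA devices, or std::nullopt if the query fails.
    std::optional<uint32_t> GetNumDevices() {
        int count = 0;
        if (cudaGetDeviceCount(&count) != cudaSuccess) {
            return std::nullopt;
        }
        return static_cast<uint32_t>(count);
    }
}

int main() {
    // In the backend this value comes from the engine's config.json
    // ("/pretrained_config/mapping/world_size"); hard-coded here for illustration.
    const std::size_t worldSize = 2;

    // value_or(0) mirrors the diff: if the device query fails, treat it as zero GPUs.
    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);

    if (numGpus < worldSize) {
        std::fprintf(stderr, "Not enough GPUs: %u available, %zu required by the model\n",
                     numGpus, worldSize);
        return 1; // the real backend logs SPDLOG_CRITICAL and, per the TODO, should raise an error for the Rust side
    }
    return 0;
}

As the TODO in the diff notes, the check currently only logs; a natural follow-up would be to throw a C++ exception at that point so the FFI layer can surface it as an error on the Rust side instead of letting the executor fail later.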