updated logic and comment to detect cuda compute capabilities

This commit is contained in:
Morgan Funtowicz 2024-07-09 12:15:41 +00:00
parent bec188ff73
commit 09292b06a0

View File

@ -13,18 +13,17 @@ void huggingface::tgi::backends::InitializeBackend() {
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(1);
// TODO : Need to check for >= sm_80 (ampere)
// Get the compute capabilities of the current hardware
nvmlDevice_t device;
int32_t cudaComputeCapabilitiesMajor, cudaComputeCapabilitiesMinor;
int32_t cudaComputeCapabilitiesMajor = 0, cudaComputeCapabilitiesMinor = 0;
if(nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
if(nvmlDeviceGetCudaComputeCapability(device, &cudaComputeCapabilitiesMajor, &cudaComputeCapabilitiesMinor) == NVML_SUCCESS) {
SPDLOG_INFO(FMT_STRING("Detected sm_{:d}{:d} compute capabilities"), cudaComputeCapabilitiesMajor, cudaComputeCapabilitiesMinor);
execConfig.setEnableChunkedContext(cudaComputeCapabilitiesMajor >= 8);
}
}
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
if(config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1){
SPDLOG_INFO("Detected single engine deployment, using leader mode");
execConfig.setParallelConfig(tle::ParallelConfig(
@ -34,7 +33,7 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
std::nullopt,
std::nullopt
));
} else {
} else { // Multiple engines -> using orchestrator mode (MPI involved)
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
execConfig.setParallelConfig(tle::ParallelConfig(
tle::CommunicationType::kMPI,
@ -44,6 +43,10 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
tle::OrchestratorConfig(true, workerPath)
));
}
// Define some configuration variables
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
execConfig.setEnableChunkedContext(cudaComputeCapabilitiesMajor >= 8);
return execConfig;
}