diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index ee8171bc..6a4a7be7 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -47,25 +47,8 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
     const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
 
     // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
-    if (config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>() == 1) {
-        SPDLOG_INFO("Detected single engine deployment, using leader mode");
-        execConfig.setParallelConfig(tle::ParallelConfig(
-                tle::CommunicationType::kMPI,
-                tle::CommunicationMode::kLEADER,
-                std::nullopt,
-                std::nullopt,
-                std::nullopt
-        ));
-    } else { // Multiple engines -> using orchestrator mode (MPI involved)
-        SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
-        execConfig.setParallelConfig(tle::ParallelConfig(
-                tle::CommunicationType::kMPI,
-                tle::CommunicationMode::kORCHESTRATOR,
-                std::nullopt,
-                std::nullopt,
-                tle::OrchestratorConfig(true, workerPath, nullptr, true)
-        ));
-    }
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));
 
     // Define some configuration variables
     execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
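
Note for reviewers: the GetParallelConfig helper called at the new call site is not shown in this hunk. A minimal sketch of what it presumably does, assuming it simply hoists the deleted if/else into a free function in the huggingface::tgi::backends namespace; the signature (size_t worldSize, std::string workerPath) and the noexcept qualifier are guesses inferred from the call site, not confirmed by this patch:

    // Sketch only -- signature inferred from the call site above, body from the
    // deleted branches; the actual helper may differ.
    tle::ParallelConfig huggingface::tgi::backends::GetParallelConfig(
            const size_t worldSize, const std::string workerPath) noexcept {
        auto mode = tle::CommunicationMode::kLEADER;
        std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;

        if (worldSize > 1) {
            // Multiple engines -> orchestrator mode (MPI involved)
            SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
            mode = tle::CommunicationMode::kORCHESTRATOR;
            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, workerPath, nullptr, true);
        } else {
            // Single engine (TP = PP = 1) -> leader mode (no MPI involved)
            SPDLOG_INFO("Detected single engine deployment, using leader mode");
        }

        return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
    }

Whatever the exact shape of the helper, the refactor removes the duplicated tle::ParallelConfig construction and keeps GetExecutorConfig focused on assembling the executor configuration.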