diff --git a/server/requirements.txt b/server/requirements.txt
index 7c81c5f9..49bfbf40 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+--extra-index-url https://developer.intel.com/ipex-whl-stable-cpu
+oneccl_bind_pt==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index d02bfc5b..505c7d40 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -57,7 +57,15 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
+
         options = None
 
     if WORLD_SIZE == 1:
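
For reviewers, a minimal standalone sketch of the backend-selection fallback this dist.py hunk introduces: prefer oneCCL ("ccl") when the Intel bindings pulled in via requirements.txt are importable, otherwise keep the existing "gloo" backend. The helper name pick_distributed_backend is hypothetical and only illustrates the patched logic; it is not part of the change.

import os


def pick_distributed_backend() -> str:
    # Hypothetical helper, not part of the patch: it mirrors the fallback
    # logic the dist.py hunk adds inside initialize_torch_distributed().
    try:
        # oneccl_bind_pt (added to requirements.txt) provides this module.
        import oneccl_bindings_for_pytorch  # noqa: F401

        # Default to a single oneCCL worker thread unless the operator has
        # already tuned CCL_WORKER_COUNT in the environment.
        if os.getenv("CCL_WORKER_COUNT", None) is None:
            os.environ["CCL_WORKER_COUNT"] = str(1)
        return "ccl"
    except ImportError:
        # Without the oneCCL bindings installed, fall back to gloo.
        return "gloo"


if __name__ == "__main__":
    print(pick_distributed_backend())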