From 3a05bac2254537c61b70879b7d814d399862a8b3 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Sat, 7 Oct 2023 00:13:46 -0700
Subject: [PATCH] add torch ccl to support TP for bfloat16, since gloo does not support bfloat16

Signed-off-by: Wang, Yi A
---
 server/requirements.txt                     |  2 ++
 server/text_generation_server/utils/dist.py | 10 +++++++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/server/requirements.txt b/server/requirements.txt
index 7c81c5f9..49bfbf40 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+--extra-index-url https://developer.intel.com/ipex-whl-stable-cpu
+oneccl_bind_pt==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index d02bfc5b..505c7d40 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -57,7 +57,15 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
+
         options = None
 
     if WORLD_SIZE == 1: