add torch ccl to support TP for bfloat16; gloo does not support bfloat16

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A  2023-10-07 00:13:46 -07:00
parent 00b8f36fba
commit 3a05bac225
2 changed files with 11 additions and 1 deletion


@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+--extra-index-url https://developer.intel.com/ipex-whl-stable-cpu
+oneccl_bind_pt==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
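The --extra-index-url line points pip at Intel's wheel channel, which hosts the oneccl_bind_pt package; the wheel installs a module named oneccl_bindings_for_pytorch, not oneccl_bind_pt. A quick sanity check that the bindings resolve after installation, a sketch rather than part of the commit:

try:
    # The wheel is named oneccl_bind_pt, but the importable module is
    # oneccl_bindings_for_pytorch; importing it registers the "ccl"
    # backend with torch.distributed as a side effect.
    import oneccl_bindings_for_pytorch  # noqa: F401
    print("oneCCL bindings found; torch.distributed can use backend='ccl'")
except ImportError:
    print("oneccl_bind_pt is not installed; tensor parallelism falls back to gloo")

This mirrors the import-or-fallback logic the commit adds to initialize_torch_distributed below.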


@@ -57,7 +57,15 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
         options = None
 
     if WORLD_SIZE == 1:
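
With the ccl backend registered, a bfloat16 all-reduce works on CPU, which is exactly the operation gloo rejects. A minimal end-to-end sketch, not taken from the repository, assuming oneccl_bind_pt is installed and the script is launched under torchrun (which sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT):

import os

import torch
import torch.distributed as dist
import oneccl_bindings_for_pytorch  # noqa: F401 -- registers the "ccl" backend

# Same default the commit applies: one oneCCL worker thread per process.
os.environ.setdefault("CCL_WORKER_COUNT", "1")

dist.init_process_group(backend="ccl")

# A bfloat16 reduction: supported by ccl, rejected by the gloo backend.
t = torch.full((4,), float(dist.get_rank()), dtype=torch.bfloat16)
dist.all_reduce(t)  # in-place sum across all ranks
print(f"rank {dist.get_rank()}: {t}")

dist.destroy_process_group()

Launched with, for example, torchrun --nproc_per_node=2 check_ccl.py (a hypothetical filename), each rank prints the summed tensor; switching backend="ccl" to backend="gloo" makes the all_reduce fail on the bfloat16 dtype, which is why the commit prefers ccl whenever the bindings import cleanly.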