diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index d2e36a8d..11a14c2f 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -33,6 +33,7 @@ from text_generation_server.utils.adapter import (
 
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
+SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 1))
 
 # Disable gradients
 torch.set_grad_enabled(False)
@@ -49,6 +50,8 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     adapt_transformers_to_gaudi()
+    if SDP_ON_BF16 == 1:
+        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
 
     if speculate is not None:
         set_speculate(speculate)
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 2353fff2..fab25b0c 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -698,6 +698,7 @@ class CausalLM(Model):
         htorch.core.hpu_set_env()
 
         if world_size > 1:
+            os.environ.setdefault("DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1")
             model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
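
The first hunk makes the bf16 reduction behavior an opt-out env-var switch rather than a hard-coded setting. A minimal standalone sketch of the gating logic, assuming only the private PyTorch toggle that the patch itself calls:

```python
import os

import torch

# SDP_ON_BF16 defaults to 1 (enabled); launch with SDP_ON_BF16=0 to opt out.
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 1))

if SDP_ON_BF16 == 1:
    # Let the math SDPA backend keep fp16/bf16 matmul reductions in
    # reduced precision instead of upcasting the accumulation to fp32,
    # trading a small amount of accuracy for throughput.
    torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
```

In the causal_lm.py hunk, `os.environ.setdefault` only writes "1" when `DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API` is unset, so an explicitly exported value still wins: the deterministic API becomes the default for the multi-card (DeepSpeed) path without overriding user configuration.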