mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
Fix mistralai/Mistral-7B-Instruct failure (#284)
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
c35810d6f0
commit
20ea73c6d4
@@ -33,6 +33,7 @@ from text_generation_server.utils.adapter import (
|
||||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
|
||||
|
||||
|
||||
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 1))
|
||||
# Disable gradients
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
@@ -49,6 +50,8 @@ def get_model(
|
||||
max_input_tokens: int,
|
||||
) -> Model:
|
||||
adapt_transformers_to_gaudi()
|
||||
if SDP_ON_BF16 == 1:
|
||||
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
|
||||
|
||||
if speculate is not None:
|
||||
set_speculate(speculate)
|
||||
|
@@ -698,6 +698,7 @@ class CausalLM(Model):
|
||||
htorch.core.hpu_set_env()
|
||||
|
||||
if world_size > 1:
|
||||
os.environ.setdefault("DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1")
|
||||
model = self.get_deepspeed_model(model_id, dtype, revision)
|
||||
model = hq_env.prepare_model_for_quantization(model)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user