mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 19:02:09 +00:00
Fix mistralai/Mistral-7B-Instruct failed issue (#284)
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
c35810d6f0
commit
20ea73c6d4
@ -33,6 +33,7 @@ from text_generation_server.utils.adapter import (
|
|||||||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
|
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
|
||||||
|
|
||||||
|
|
||||||
|
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 1))
|
||||||
# Disable gradients
|
# Disable gradients
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
@ -49,6 +50,8 @@ def get_model(
|
|||||||
max_input_tokens: int,
|
max_input_tokens: int,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
adapt_transformers_to_gaudi()
|
adapt_transformers_to_gaudi()
|
||||||
|
if SDP_ON_BF16 == 1:
|
||||||
|
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
|
||||||
|
|
||||||
if speculate is not None:
|
if speculate is not None:
|
||||||
set_speculate(speculate)
|
set_speculate(speculate)
|
||||||
|
@ -698,6 +698,7 @@ class CausalLM(Model):
|
|||||||
htorch.core.hpu_set_env()
|
htorch.core.hpu_set_env()
|
||||||
|
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
|
os.environ.setdefault("DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1")
|
||||||
model = self.get_deepspeed_model(model_id, dtype, revision)
|
model = self.get_deepspeed_model(model_id, dtype, revision)
|
||||||
model = hq_env.prepare_model_for_quantization(model)
|
model = hq_env.prepare_model_for_quantization(model)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user