Merge branch 'habana-main' into 2.3.0

This commit is contained in:
yuanwu2017 2024-10-28 04:37:07 +08:00 committed by GitHub
commit c23584f626
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -731,13 +731,11 @@ class CausalLM(Model):
} }
if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]: if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
if model.config.model_type not in ["falcon"]: if model.config.model_type not in ["falcon"]:
self.kwargs["attn_softmax_bf16"] = True self.kwargs["attn_softmax_bf16"] = True
if model.config.model_type not in ["gemma"]: self.kwargs["trim_logits"] = True
self.kwargs["trim_logits"] = True
if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
self.kwargs["use_flash_attention"] = True self.kwargs["use_flash_attention"] = True