Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
Merge branch 'habana-main' into 2.3.0
Commit: c23584f626
Changes to `class CausalLM(Model)`:

```diff
@@ -731,13 +731,11 @@ class CausalLM(Model):
         }
 
-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]:
+        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
             if model.config.model_type not in ["falcon"]:
                 self.kwargs["attn_softmax_bf16"] = True
 
-            if model.config.model_type not in ["gemma"]:
-                self.kwargs["trim_logits"] = True
+            self.kwargs["trim_logits"] = True
 
             if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
                 self.kwargs["use_flash_attention"] = True
```
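As a rough standalone illustration of what the merged code configures, here is a minimal sketch. The helper name `build_hpu_kwargs` is hypothetical; the model-type list, the kwarg names (`attn_softmax_bf16`, `trim_logits`, `use_flash_attention`), and the `USE_FLASH_ATTENTION` environment variable are taken from the hunk above.

```python
import os

# Minimal sketch of the kwargs selection after the merge; the helper name is
# hypothetical, but the flags and checks mirror the hunk above.
def build_hpu_kwargs(model_type: str) -> dict:
    kwargs = {}
    if model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
        # bf16 attention softmax for every listed type except falcon
        if model_type not in ["falcon"]:
            kwargs["attn_softmax_bf16"] = True
        # trim_logits is now set unconditionally for the listed model types
        kwargs["trim_logits"] = True
        # flash attention stays opt-in via an environment variable
        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
            kwargs["use_flash_attention"] = True
    return kwargs

print(build_hpu_kwargs("llama"))   # {'attn_softmax_bf16': True, 'trim_logits': True}
print(build_hpu_kwargs("falcon"))  # {'trim_logits': True}
```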