Add flash_attention argument options for Mistral (#145)
Co-authored-by: Karol Damaszke <karol.damaszke@intel.com>
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Parent: 2eb2da5f02 · Commit: 1023de8048
```diff
@@ -137,6 +137,9 @@ Support for other models from Optimum Habana will be added successively.
 | SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
 | WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
 | QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which requests are considered overdue and handled with priority. Otherwise, shorter requests are prioritized. | add -e in docker run command |
+| USE_FLASH_ATTENTION | True/False | False | Whether to enable Habana Flash Attention, provided that the model supports it. Currently only llama and mistral support this feature. Please refer to https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa | add -e in docker run command |
+| FLASH_ATTENTION_RECOMPUTE | True/False | False | Whether to enable Habana Flash Attention in recompute mode on first token generation. | add -e in docker run command |
 
 </div>
 
 ## Profiler
```
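Both new variables follow the same convention as the existing boolean switches: the server only checks for the literal string "true" (case-insensitive), so any other value leaves the feature disabled. Below is a minimal sketch of that parsing pattern; the helper name `env_flag` is illustrative and not part of the repository:

```python
import os

def env_flag(name: str, default: str = "false") -> bool:
    # Mirrors the server-side check: only a case-insensitive "true" enables the flag.
    return os.getenv(name, default).lower() == "true"

# e.g. container started with:
#   docker run ... -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true
use_flash_attention = env_flag("USE_FLASH_ATTENTION")
flash_attention_recompute = env_flag("FLASH_ATTENTION_RECOMPUTE")
print(use_flash_attention, flash_attention_recompute)
```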
```diff
@@ -662,10 +662,15 @@ class CausalLM(Model):
             "return_dict": True,
         }
 
-        if model.config.model_type == "llama":
+        if model.config.model_type in ["llama", "mistral"]:
             kwargs["attn_softmax_bf16"] = True
             kwargs["trim_logits"] = True
 
+            if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
+                kwargs["use_flash_attention"] = True
+            if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
+                kwargs["flash_attention_recompute"] = True
+
         self.speculate = get_speculate()
 
         super(CausalLM, self).__init__(
```
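Taken together, the change means that for llama and mistral models the two environment variables are turned into extra entries in the generation kwargs dictionary. The following self-contained sketch reconstructs that flag-to-kwargs pattern; the function `build_generation_kwargs` is illustrative and not the repository's code:

```python
import os

def build_generation_kwargs(model_type: str) -> dict:
    # Start from the common kwargs, then add model- and environment-dependent flags,
    # following the same structure as the patch above.
    kwargs = {"return_dict": True}
    if model_type in ["llama", "mistral"]:
        kwargs["attn_softmax_bf16"] = True
        kwargs["trim_logits"] = True
        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
            kwargs["use_flash_attention"] = True
        if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
            kwargs["flash_attention_recompute"] = True
    return kwargs

if __name__ == "__main__":
    os.environ["USE_FLASH_ATTENTION"] = "true"
    print(build_generation_kwargs("mistral"))
    # {'return_dict': True, 'attn_softmax_bf16': True, 'trim_logits': True, 'use_flash_attention': True}
```

Keeping the options as plain dictionary entries means they can be forwarded as `**kwargs` later on, so models that do not support the feature are simply never handed the extra arguments.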