Add flash_attention argument options for Mistral (#145)

Co-authored-by: Karol Damaszke <karol.damaszke@intel.com>
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Jimin Ha 2024-05-27 11:00:42 -07:00 committed by GitHub
parent 2eb2da5f02
commit 1023de8048
2 changed files with 9 additions and 1 deletion


@@ -137,6 +137,9 @@ Support for other models from Optimum Habana will be added successively.
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which requests are considered overdue and handled with priority. Otherwise, shorter requests are prioritized. | add -e in docker run command |
| USE_FLASH_ATTENTION | True/False | False | Whether to enable Habana Flash Attention, provided that the model supports it. Currently only llama and mistral support this feature. Please refer to https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa | add -e in docker run command |
| FLASH_ATTENTION_RECOMPUTE | True/False | False | Whether to enable Habana Flash Attention in recompute mode on first token generation (see the usage sketch after this table). | add -e in docker run command |
</div>
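
Below is a minimal usage sketch (not part of this change) showing how the two new flags could be supplied when starting the server from Python rather than with `docker run -e`; the launcher invocation and model id are illustrative assumptions.

```python
import os
import subprocess

# Equivalent of `-e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true`
# in a docker run command; both variables are read as case-insensitive booleans.
env = dict(os.environ)
env["USE_FLASH_ATTENTION"] = "true"        # Habana Flash Attention (FusedSDPA); llama and mistral only
env["FLASH_ATTENTION_RECOMPUTE"] = "true"  # recompute mode on first-token generation

# Illustrative launcher call; adjust the binary, model id, and arguments to your setup.
subprocess.run(
    ["text-generation-launcher", "--model-id", "mistralai/Mistral-7B-Instruct-v0.1"],
    env=env,
    check=True,
)
```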
## Profiler


@@ -662,10 +662,15 @@ class CausalLM(Model):
             "return_dict": True,
         }
-        if model.config.model_type == "llama":
+        if model.config.model_type in ["llama", "mistral"]:
             kwargs["attn_softmax_bf16"] = True
             kwargs["trim_logits"] = True
+            if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
+                kwargs["use_flash_attention"] = True
+            if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
+                kwargs["flash_attention_recompute"] = True
         self.speculate = get_speculate()
         super(CausalLM, self).__init__(
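
For context, the pattern added above can be read as the following self-contained sketch; `build_generation_kwargs` and `_env_flag` are hypothetical helpers used here for illustration, whereas the real code assembles the dict inline in `CausalLM.__init__` and forwards it to Optimum Habana's generation call.

```python
import os


def _env_flag(name: str) -> bool:
    # The server treats these variables as case-insensitive booleans that default to "false".
    return os.getenv(name, "false").lower() == "true"


def build_generation_kwargs(model_type: str) -> dict:
    # Mirror of the hunk above: HPU-specific generation options are enabled
    # only for model families that support them (currently llama and mistral).
    kwargs = {"return_dict": True}
    if model_type in ["llama", "mistral"]:
        kwargs["attn_softmax_bf16"] = True
        kwargs["trim_logits"] = True
        if _env_flag("USE_FLASH_ATTENTION"):
            kwargs["use_flash_attention"] = True
        if _env_flag("FLASH_ATTENTION_RECOMPUTE"):
            kwargs["flash_attention_recompute"] = True
    return kwargs


if __name__ == "__main__":
    os.environ["USE_FLASH_ATTENTION"] = "true"
    print(build_generation_kwargs("mistral"))
```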