Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Add flash_attention argument options for Mistral (#145)
Co-authored-by: Karol Damaszke <karol.damaszke@intel.com>
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Parent: 2eb2da5f02
Commit: 1023de8048
```diff
@@ -137,6 +137,9 @@ Support for other models from Optimum Habana will be added successively.
 | SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
 | WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
 | QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which requests are considered overdue and handled with priority. Shorter requests are prioritized otherwise. | add -e in docker run command |
+| USE_FLASH_ATTENTION | True/False | False | Whether to enable Habana Flash Attention, provided that the model supports it. Currently only llama and mistral support this feature. Please refer to https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa |
+| FLASH_ATTENTION_RECOMPUTE | True/False | False | Whether to enable Habana Flash Attention in recompute mode on first token generation. |
 
 </div>
 
 ## Profiler
```
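The table's last column notes that these variables are passed with `-e` on the `docker run` command line. A minimal sketch of enabling both new options when starting the server (the image name and model id below are illustrative placeholders, not taken from this commit, and Gaudi-specific runtime flags such as device mounts are omitted):

```bash
# Sketch only: replace the image and model id with the ones you actually use.
docker run -p 8080:80 \
  -e USE_FLASH_ATTENTION=true \
  -e FLASH_ATTENTION_RECOMPUTE=true \
  tgi-gaudi:latest \
  --model-id mistralai/Mistral-7B-Instruct-v0.1
```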
```diff
@@ -662,10 +662,15 @@ class CausalLM(Model):
             "return_dict": True,
         }
 
-        if model.config.model_type == "llama":
+        if model.config.model_type in ["llama", "mistral"]:
             kwargs["attn_softmax_bf16"] = True
             kwargs["trim_logits"] = True
 
+            if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
+                kwargs["use_flash_attention"] = True
+            if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
+                kwargs["flash_attention_recompute"] = True
+
         self.speculate = get_speculate()
 
         super(CausalLM, self).__init__(
```
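Because the server reads these flags with `os.getenv`, they can also be exported in the environment of a non-Docker launch. A sketch, assuming a local install where `text-generation-launcher` is on the PATH and using an illustrative model id:

```bash
# Sketch: enable Habana Flash Attention (and recompute mode) for a local launch.
export USE_FLASH_ATTENTION=true
export FLASH_ATTENTION_RECOMPUTE=true
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.1 --port 8080
```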