diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 21aa1f8b..42ec1b3f 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -37,6 +37,7 @@ def tgi_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],  # This needs to stay as it is passed as a positional arg in transformers
     kv_cache: List[KVCache],
     kv_head_mapping: torch.Tensor,
     slots: torch.Tensor,
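
For context on the comment in the added line: recent transformers releases resolve the configured attention backend (e.g. via `ALL_ATTENTION_FUNCTIONS`) and call it with the attention mask in a fixed positional slot, so any registered forward must keep that parameter even if it never uses it. Below is a minimal, self-contained sketch of that call pattern; the backend body and tensor shapes are toy stand-ins and not TGI's actual implementation.

```python
from typing import Optional
import torch


def flash_backend(  # toy stand-in for tgi_flash_attention_forward
    module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],  # unused here, but must keep this positional slot
    **kwargs,
):
    # The flash path ignores the dense mask; its extra metadata arrives via kwargs.
    return query_states, None


# Transformers-style dispatch: the mask is passed positionally, so removing the
# parameter from the backend signature would misalign every argument after it.
q = k = v = torch.zeros(1, 4, 2, 8)
attn_output, attn_weights = flash_backend(None, q, k, v, None, dropout=0.0)
```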