small fix

This commit is contained in:
Cyril Vallez 2025-01-17 16:05:47 +00:00
parent b03d7ae951
commit b40c889360
No known key found for this signature in database

View File

@@ -37,6 +37,7 @@ def tgi_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor], # This needs to stay as it is passed as a positional arg in transformers
     kv_cache: List[KVCache],
     kv_head_mapping: torch.Tensor,
     slots: torch.Tensor,