From e72897004ac8794d4be309c777507f259704dcbc Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Tue, 23 Apr 2024 15:05:39 +0000
Subject: [PATCH] black

---
 server/text_generation_server/models/flash_causal_lm.py | 4 +++-
 server/text_generation_server/utils/flash_attn.py       | 9 +++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 9a3db958..31d965e0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -834,7 +834,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
 
         # TODO: is this correct?
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
+        input_lengths = (
+            torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
+        )
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
             .repeat(bs)
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 45245357..c4a70ce5 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -85,6 +85,7 @@ except ImportError as e:
     logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -93,7 +94,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     total_tokens, num_key_value_heads, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :].expand(total_tokens, num_key_value_heads, n_rep, head_dim)
+    hidden_states = hidden_states[:, :, None, :].expand(
+        total_tokens, num_key_value_heads, n_rep, head_dim
+    )
     return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
 
 
@@ -229,4 +232,6 @@ def attention(
         )
         return output
     else:
-        raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
+        raise NotImplementedError(
+            f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})"
+        )
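
Review note: the repeat_kv helper reformatted above is, per its own docstring, equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep) applied to a (total_tokens, num_key_value_heads, head_dim) tensor. A minimal standalone sketch of that equivalence follows, reusing the body from the diff; the tensor shapes in the self-check are illustrative assumptions, not values taken from the server code.

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (total_tokens, num_key_value_heads, head_dim)
        #   -> (total_tokens, num_key_value_heads * n_rep, head_dim)
        total_tokens, num_key_value_heads, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        # expand adds an n_rep axis as a zero-copy view; reshape then materializes
        # the repeated heads contiguously, matching repeat_interleave on dim=1.
        hidden_states = hidden_states[:, :, None, :].expand(
            total_tokens, num_key_value_heads, n_rep, head_dim
        )
        return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)

    # Self-check with assumed example shapes: 5 tokens, 2 KV heads, head_dim 64.
    x = torch.randn(5, 2, 64)
    assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))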