From c6565e825901cde5e0da307596c6500b08d9677a Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 17 May 2024 16:37:38 +0000
Subject: [PATCH] format

---
 .../text_generation_server/models/__init__.py | 32 ++++++++++++-------
 .../models/flash_causal_lm.py                 |  5 ++-
 .../models/flash_mistral.py                   |  5 ++-
 3 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 8878ad15..ff75a635 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -79,12 +79,14 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
-
+    from text_generation_server.utils.flash_attn import (
+        HAS_FLASH_ATTN_V2_CUDA,
+        HAS_FLASH_ATTN_V2_ROCM,
+    )
 except ImportError as e:
-    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
     HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
 
 if FLASH_ATTENTION:
     __all__.append(FlashGPT2)
@@ -539,8 +541,10 @@ def get_model(
     if model_type == "mistral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMistral(
                 model_id,
                 revision,
@@ -564,8 +568,10 @@ def get_model(
     if model_type == "mixtral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -589,8 +595,10 @@ def get_model(
     if model_type == "starcoder2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -615,8 +623,10 @@ def get_model(
     if model_type == "qwen2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashQwen2(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 333efe33..45ddd856 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 30ae95c9..e6125e29 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,