fxmarty 2024-05-17 16:37:38 +00:00
parent b5f1c9de06
commit c6565e8259
3 changed files with 29 additions and 13 deletions

View File

@@ -79,12 +79,14 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
+    from text_generation_server.utils.flash_attn import (
+        HAS_FLASH_ATTN_V2_CUDA,
+        HAS_FLASH_ATTN_V2_ROCM,
+    )
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
     HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
 if FLASH_ATTENTION:
     __all__.append(FlashGPT2)
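Note on the import change above: both capability flags are assigned in the `except ImportError` branch so that the `or HAS_FLASH_ATTN_V2_ROCM` checks added below never raise a NameError on machines where the Flash Attention models cannot be imported. A minimal standalone sketch of that fallback pattern (setting `FLASH_ATTENTION = True` in the `try` branch is an assumption for this sketch, not part of the diff):

# Illustrative sketch of the optional-import fallback used above.
try:
    from text_generation_server.utils.flash_attn import (
        HAS_FLASH_ATTN_V2_CUDA,
        HAS_FLASH_ATTN_V2_ROCM,
    )

    FLASH_ATTENTION = True  # assumption: placed here only for this sketch
except ImportError:
    # Without these fallbacks, later `... or HAS_FLASH_ATTN_V2_ROCM` checks
    # would raise NameError instead of falling back to non-flash models.
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False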
@@ -539,8 +541,10 @@ def get_model(
     if model_type == "mistral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMistral(
                 model_id,
                 revision,
@@ -564,8 +568,10 @@ def get_model(
     if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -589,8 +595,10 @@ def get_model(
     if model_type == "starcoder2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -615,8 +623,10 @@ def get_model(
     if model_type == "qwen2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashQwen2(
                 model_id,
                 revision,
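The same gating change is repeated above for mistral, mixtral, starcoder2 and qwen2: the flash implementation is chosen when the model has no sliding window (and Flash Attention is available), or when a Flash Attention v2 kernel is present, now on either CUDA or ROCm. A hedged sketch of that predicate as a standalone helper (the helper name and signature are hypothetical, not part of this diff):

from typing import Optional


def use_flash_model(
    sliding_window: Optional[int],
    flash_attention: bool,
    has_flash_attn_v2_cuda: bool,
    has_flash_attn_v2_rocm: bool,
) -> bool:
    # "No sliding window" is encoded as either None or -1 in the model config.
    no_sliding_window = sliding_window is None or sliding_window == -1
    # Sliding-window models need a Flash Attention v2 kernel; this commit
    # accepts the ROCm kernel in addition to the CUDA one.
    return (
        (no_sliding_window and flash_attention)
        or has_flash_attn_v2_cuda
        or has_flash_attn_v2_rocm
    )


# e.g. a sliding-window model (window=4096) on a ROCm box with FA2 kernels:
assert use_flash_model(4096, True, False, True)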

View File

@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
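The change above replaces `input_lengths=None` in the warmup forward with a dummy tensor of ones, since some model forwards (starcoder2, per the comment) index into `input_lengths` and do not accept `None`. A minimal sketch of how dummy warmup inputs for a single sequence of `seqlen` tokens could be assembled (the `input_ids`/`position_ids` construction and the `[0, seqlen]` value are assumptions for illustration; only `slots` and `input_lengths` appear in the hunk):

import torch


def build_dummy_warmup_inputs(seqlen: int, device: torch.device):
    """Hypothetical helper: dummy inputs for a warmup forward over one sequence."""
    input_ids = torch.zeros(seqlen, dtype=torch.int64, device=device)  # assumption
    position_ids = torch.arange(seqlen, dtype=torch.int32, device=device)  # assumption
    slots = torch.arange(seqlen, dtype=torch.int64, device=device)
    # Dummy per-token lengths instead of None, so models that index into
    # `input_lengths` (e.g. starcoder2) don't fail during warmup.
    input_lengths = torch.ones(seqlen, dtype=torch.int32, device=device)
    # One packed prefill sequence: cumulative lengths [0, seqlen] (assumed value).
    cu_seqlen_prefill = torch.tensor([0, seqlen], dtype=torch.int32, device=device)
    return input_ids, position_ids, slots, input_lengths, cu_seqlen_prefill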

View File

@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
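For context on the `cu_seqlen_prefill` comment: in the varlen flash-attention prefill path, the batch is described by cumulative sequence lengths over packed tokens, so the warmup can run without allocating paged KV-cache blocks. A small illustrative helper (not part of TGI) showing that layout:

from typing import List

import torch


def cu_seqlens(lengths: List[int], device: str = "cpu") -> torch.Tensor:
    """Cumulative sequence lengths, e.g. [5, 3, 7] -> tensor([0, 5, 8, 15])."""
    out = torch.zeros(len(lengths) + 1, dtype=torch.int32, device=device)
    out[1:] = torch.cumsum(
        torch.tensor(lengths, dtype=torch.int32, device=device), dim=0
    )
    return out


# The warmup above runs a single sequence of `seqlen` tokens, i.e. [0, seqlen]:
print(cu_seqlens([16]))  # tensor([ 0, 16], dtype=torch.int32)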