Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 20:34:54 +00:00)

commit c6565e8259 ("format")
parent b5f1c9de06
@@ -79,12 +79,14 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
+    from text_generation_server.utils.flash_attn import (
+        HAS_FLASH_ATTN_V2_CUDA,
+        HAS_FLASH_ATTN_V2_ROCM,
+    )
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
     HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
 
 if FLASH_ATTENTION:
     __all__.append(FlashGPT2)
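
As background for the hunk above: the models package probes its flash-attention backends with an optional import, and the capability flags fall back to False when the probe fails. Below is a minimal, hedged sketch of that pattern; it mirrors the names shown in the diff but uses the standard-library logger instead of the repository's, so it is an illustration rather than the module's actual code.

import logging

logger = logging.getLogger(__name__)

FLASH_ATTENTION = True

try:
    # Probe the optional flash-attention capability flags.
    from text_generation_server.utils.flash_attn import (
        HAS_FLASH_ATTN_V2_CUDA,
        HAS_FLASH_ATTN_V2_ROCM,
    )
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    # Degrade gracefully: every capability flag is off.
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False

Defining the ROCm flag alongside the CUDA one in both branches means downstream checks never have to guard against an undefined name.
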
@@ -539,8 +541,10 @@ def get_model(
     if model_type == "mistral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMistral(
                 model_id,
                 revision,
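
The condition above (repeated for mixtral, starcoder2 and qwen2 in the hunks below) selects the flash path when the model has no effective sliding window and basic flash attention is available, or when a flash-attention v2 kernel is present on either CUDA or ROCm. A hedged sketch of the same predicate as a hypothetical standalone helper (supports_flash_path is not a function in the repository):

def supports_flash_path(
    sliding_window,
    flash_attention: bool,
    has_fa2_cuda: bool,
    has_fa2_rocm: bool,
) -> bool:
    # A sliding_window of None or -1 means "no window", so the base flash
    # path is enough; a real window needs flash-attn v2 (CUDA or ROCm).
    no_window = sliding_window is None or sliding_window == -1
    return (no_window and flash_attention) or has_fa2_cuda or has_fa2_rocm


# Example: a windowed model on a ROCm machine with flash-attn v2 available.
assert supports_flash_path(4096, True, False, True)
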
@@ -564,8 +568,10 @@ def get_model(
     if model_type == "mixtral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -589,8 +595,10 @@ def get_model(
     if model_type == "starcoder2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -615,8 +623,10 @@ def get_model(
     if model_type == "qwen2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashQwen2(
                 model_id,
                 revision,
@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
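
The two FlashCausalLM hunks above replace the input_lengths=None placeholder in the warmup forward call with a real tensor, since some models (starcoder2) index into input_lengths unconditionally and fail on None. A hedged sketch of how those warmup placeholders could be built, assuming only PyTorch (build_warmup_placeholders is a hypothetical helper, not part of the repository):

import torch


def build_warmup_placeholders(seqlen: int, device: torch.device):
    # One slot index per position of the dummy warmup sequence.
    slots = torch.arange(seqlen, dtype=torch.int64, device=device)
    # Dummy lengths: some models read `input_lengths` even during warmup,
    # so a tensor of ones stands in for `None`.
    input_lengths = torch.ones(seqlen, dtype=torch.int32, device=device)
    return slots, input_lengths


slots, input_lengths = build_warmup_placeholders(16, torch.device("cpu"))

The same substitution appears again for BaseFlashMistral in the hunks that follow.
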
@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,