Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
remove flag
This commit is contained in:
parent 266377b328
commit 32488c1a11
@@ -30,7 +30,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -368,7 +368,6 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
-    global USE_CUSTOM_MODELING
 
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -376,24 +375,11 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     transformers_causal_lm_class = CausalLM
-    if not USE_CUSTOM_MODELING:
-        logger.info(
-            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
-        )
 
-        transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-        if transformers_model_class._supports_flex_attn:
-            logger.info(
-                f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
-                "batch and sequence length). All TGI's batching/caching optimizations are enabled."
-            )
-            transformers_causal_lm_class = TransformersFlashCausalLM
-        else:
-            logger.info(
-                f"Transformers' {model_type} implementation does not supports ragged tensors format. Will use classic "
-                "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
-            )
-
+    # Fast transformers path
+    transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+    if transformers_model_class._supports_flex_attn:
+        transformers_causal_lm_class = TransformersFlashCausalLM
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
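
Note on the hunk above: the remaining fast-path check relies on `_supports_flex_attn`, a private class attribute that Transformers sets on `PreTrainedModel` subclasses, looked up through `MODEL_FOR_CAUSAL_LM_MAPPING_NAMES`. A minimal, self-contained sketch of that lookup (illustrative only; the helper name is made up and this is not TGI's code):

    import transformers
    from transformers.models.auto import modeling_auto

    def supports_ragged_fallback(model_type: str) -> bool:
        # MODEL_FOR_CAUSAL_LM_MAPPING_NAMES maps e.g. "llama" -> "LlamaForCausalLM".
        class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
        model_class = getattr(transformers, class_name)
        # _supports_flex_attn is private and may be absent on older Transformers
        # versions, so default to False rather than raising.
        return bool(getattr(model_class, "_supports_flex_attn", False))
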
@@ -613,7 +599,7 @@ def get_model(
         )
 
     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
@@ -678,7 +664,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -729,7 +715,7 @@ def get_model(
                 batch_class=CausalLMBatchKeysLast,
             )
     elif model_type == GPT2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -765,7 +751,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -801,7 +787,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -839,7 +825,7 @@ def get_model(
             )
 
     elif model_type == PHI:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -862,7 +848,7 @@ def get_model(
             )
 
     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -886,7 +872,7 @@ def get_model(
             )
 
     elif model_type == "phi-msft":
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -908,7 +894,7 @@ def get_model(
         or model_type == PHI3
         or model_type == GRANITE
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -934,7 +920,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     if model_type == GEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -960,7 +946,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -987,7 +973,7 @@ def get_model(
             )
 
     if model_type == COHERE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1012,7 +998,7 @@ def get_model(
             )
 
     if model_type == DBRX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1041,7 +1027,7 @@ def get_model(
 
     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+            if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1090,7 +1076,7 @@ def get_model(
             )
 
     if model_type == MISTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1115,7 +1101,7 @@ def get_model(
             )
 
     if model_type == MIXTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1140,7 +1126,7 @@ def get_model(
             )
 
     if model_type == STARCODER2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1167,7 +1153,7 @@ def get_model(
             )
 
     if model_type == QWEN2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1219,7 +1205,7 @@ def get_model(
                 },
             )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1243,7 +1229,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1259,7 +1245,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Idefics2ForConditionalGeneration,
@@ -1277,7 +1263,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1296,7 +1282,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,
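
All of the hunks above make the same mechanical change: every per-architecture branch in get_model() that was gated on `FLASH_ATTENTION and USE_CUSTOM_MODELING` is now gated on `FLASH_ATTENTION` alone. A condensed, runnable sketch of that branch shape (illustrative only; pick_model, flash_classes and fallback are placeholders, not TGI names):

    from typing import Callable, Dict

    def pick_model(
        model_type: str,
        flash_attention: bool,
        flash_classes: Dict[str, Callable[[], object]],
        fallback: Callable[[], object],
    ) -> object:
        flash_cls = flash_classes.get(model_type)
        # Before this commit the guard also required use_custom_modeling.
        if flash_attention and flash_cls is not None:
            return flash_cls()
        return fallback()

    # Dummy stand-ins for the real model classes, just to show the control flow.
    print(pick_model("gemma", True, {"gemma": lambda: "flash path"}, lambda: "fallback path"))
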
@@ -68,6 +68,3 @@ def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
 
-
-USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
-USE_CUSTOM_MODELING = USE_CUSTOM_MODELING.lower() == "true" or USE_CUSTOM_MODELING == "1"
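
For reference, the two lines deleted from text_generation_server.models.globals implemented a permissive boolean environment flag that defaulted to enabled. A standalone sketch of the equivalent parsing (the helper name is made up; the deleted code inlined this logic directly):

    import os

    def _env_flag(name: str, default: str = "true") -> bool:
        # Enabled for "true" (any casing) or the literal "1", disabled otherwise,
        # matching the two deleted lines.
        value = os.getenv(name, default)
        return value.lower() == "true" or value == "1"

    # The deleted assignment was equivalent to:
    # USE_CUSTOM_MODELING = _env_flag("USE_CUSTOM_MODELING")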