Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit 32488c1a11 (parent 266377b328)

    remove flag
server/text_generation_server/models/__init__.py

@@ -30,7 +30,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -368,7 +368,6 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
-    global USE_CUSTOM_MODELING
 
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
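As context for the hunks above and below: `PretrainedConfig.get_config_dict` is the public transformers API that get_model uses to read the raw config.json and derive `model_type`. A minimal, hedged example; the "gpt2" id is only illustrative and needs a network or cache hit:

from transformers import PretrainedConfig

# Fetch the raw config.json as a dict (the second return value holds unused kwargs).
config_dict, _ = PretrainedConfig.get_config_dict("gpt2")
model_type = config_dict.get("model_type", None)  # -> "gpt2"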
@@ -376,24 +375,11 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     transformers_causal_lm_class = CausalLM
-    if not USE_CUSTOM_MODELING:
-        logger.info(
-            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
-        )
-
     # Fast transformers path
     transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
 
     if transformers_model_class._supports_flex_attn:
-        logger.info(
-            f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
-            "batch and sequence length). All TGI's batching/caching optimizations are enabled."
-        )
         transformers_causal_lm_class = TransformersFlashCausalLM
-    else:
-        logger.info(
-            f"Transformers' {model_type} implementation does not supports ragged tensors format. Will use classic "
-            "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
-        )
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
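With the flag gate removed, the Transformers fallback selection above is driven purely by whether the architecture's generic implementation supports flex attention. A minimal sketch of that selection, assuming a recent transformers release (for the `_supports_flex_attn` class attribute) and taking the two TGI wrapper classes as opaque arguments rather than importing them:

import transformers
from transformers.models.auto import modeling_auto

def pick_causal_lm_class(model_type: str, padded_class, ragged_class):
    # Resolve the generic Transformers implementation registered for this
    # architecture, e.g. "llama" -> LlamaForCausalLM.
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
    transformers_model_class = getattr(transformers, class_name)
    # Implementations that support flex attention can consume ragged batches
    # (a single dimension for batch and sequence), so the flash-style wrapper applies.
    if getattr(transformers_model_class, "_supports_flex_attn", False):
        return ragged_class
    # Otherwise fall back to the classic padded path, which is slower.
    return padded_class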
@@ -613,7 +599,7 @@ def get_model(
         )
 
     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
@@ -678,7 +664,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -729,7 +715,7 @@ def get_model(
                 batch_class=CausalLMBatchKeysLast,
             )
     elif model_type == GPT2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -765,7 +751,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -801,7 +787,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -839,7 +825,7 @@ def get_model(
             )
 
     elif model_type == PHI:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -862,7 +848,7 @@ def get_model(
             )
 
     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -886,7 +872,7 @@ def get_model(
             )
 
     elif model_type == "phi-msft":
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -908,7 +894,7 @@ def get_model(
         or model_type == PHI3
         or model_type == GRANITE
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -934,7 +920,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     if model_type == GEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -960,7 +946,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -987,7 +973,7 @@ def get_model(
             )
 
     if model_type == COHERE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1012,7 +998,7 @@ def get_model(
             )
 
     if model_type == DBRX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1041,7 +1027,7 @@ def get_model(
 
     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+            if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1090,7 +1076,7 @@ def get_model(
             )
 
     if model_type == MISTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1115,7 +1101,7 @@ def get_model(
             )
 
     if model_type == MIXTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1140,7 +1126,7 @@ def get_model(
             )
 
     if model_type == STARCODER2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1167,7 +1153,7 @@ def get_model(
             )
 
     if model_type == QWEN2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1219,7 +1205,7 @@ def get_model(
                 },
             )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1243,7 +1229,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1259,7 +1245,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Idefics2ForConditionalGeneration,
@@ -1277,7 +1263,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1296,7 +1282,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,
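Every hunk above makes the same mechanical change: each per-model branch that used to require both FLASH_ATTENTION and USE_CUSTOM_MODELING now keys on FLASH_ATTENTION alone. A schematic sketch of that dispatch shape, not the upstream code; the class arguments are placeholders supplied by the caller:

def dispatch_causal_lm(model_type: str, flash_attention: bool, flash_class, fallback_class):
    # Schematic version of one get_model branch after the flag removal.
    if model_type == "mistral":
        # Previously: `if FLASH_ATTENTION and USE_CUSTOM_MODELING:`
        if flash_attention:
            return flash_class
        # Without flash attention, use the generic padded fallback path.
        return fallback_class
    raise NotImplementedError(f"Unsupported model_type: {model_type}")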
server/text_generation_server/models/globals.py

@@ -68,6 +68,3 @@ def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
 
-
-USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
-USE_CUSTOM_MODELING = USE_CUSTOM_MODELING.lower() == "true" or USE_CUSTOM_MODELING == "1"
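The deleted lines above were the entire definition of the flag: an environment variable that defaulted to enabled and accepted a case-insensitive "true" or the literal string "1". An equivalent standalone sketch of that parsing, kept only for reference since the variable no longer exists after this commit:

import os

# Reproduces the removed default-on boolean parsing: "true"/"True"/"TRUE" or "1" enable it.
raw = os.getenv("USE_CUSTOM_MODELING", "true")
use_custom_modeling = raw.lower() == "true" or raw == "1"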