remove flag

Cyril Vallez 2025-01-17 12:26:51 +00:00
parent 266377b328
commit 32488c1a11
2 changed files with 28 additions and 45 deletions

server/text_generation_server/models/__init__.py

@@ -30,7 +30,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -368,7 +368,6 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
-    global USE_CUSTOM_MODELING

     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -376,24 +375,11 @@ def get_model(
     model_type = config_dict.get("model_type", None)

     transformers_causal_lm_class = CausalLM
-    if not USE_CUSTOM_MODELING:
-        logger.info(
-            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
-        )
-        # Fast transformers path
     transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
     if transformers_model_class._supports_flex_attn:
-        logger.info(
-            f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
-            "batch and sequence length). All TGI's batching/caching optimizations are enabled."
-        )
         transformers_causal_lm_class = TransformersFlashCausalLM
-    else:
-        logger.info(
-            f"Transformers' {model_type} implementation does not supports ragged tensors format. Will use classic "
-            "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
-        )

     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
@@ -613,7 +599,7 @@ def get_model(
         )
     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
@@ -678,7 +664,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -729,7 +715,7 @@ def get_model(
             batch_class=CausalLMBatchKeysLast,
         )
     elif model_type == GPT2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -765,7 +751,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -801,7 +787,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -839,7 +825,7 @@ def get_model(
         )
     elif model_type == PHI:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -862,7 +848,7 @@ def get_model(
         )
     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -886,7 +872,7 @@ def get_model(
         )
     elif model_type == "phi-msft":
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -908,7 +894,7 @@ def get_model(
         or model_type == PHI3
         or model_type == GRANITE
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -934,7 +920,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     if model_type == GEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -960,7 +946,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -987,7 +973,7 @@ def get_model(
         )
     if model_type == COHERE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1012,7 +998,7 @@ def get_model(
         )
     if model_type == DBRX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1041,7 +1027,7 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+            if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1090,7 +1076,7 @@ def get_model(
         )
     if model_type == MISTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1115,7 +1101,7 @@ def get_model(
         )
     if model_type == MIXTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1140,7 +1126,7 @@ def get_model(
         )
     if model_type == STARCODER2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1167,7 +1153,7 @@ def get_model(
         )
     if model_type == QWEN2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1219,7 +1205,7 @@ def get_model(
             },
         )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1243,7 +1229,7 @@ def get_model(
             lora_adapter_ids=lora_adapter_ids,
         )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1259,7 +1245,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Idefics2ForConditionalGeneration,
@@ -1277,7 +1263,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1296,7 +1282,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,
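
With the flag gone, the custom flash models above are gated on FLASH_ATTENTION alone, and the Transformers fallback class is picked by probing the auto-mapped architecture for flex-attention support (first hunk). Below is a minimal, self-contained sketch of that selection step; the helper name `pick_transformers_fallback` and its string return values are illustrative only, not part of TGI, and it assumes a `transformers` release recent enough to expose the `_supports_flex_attn` class attribute:

```python
# Illustrative sketch only: mirrors the fallback-selection logic in get_model(),
# not TGI's actual API. Assumes a recent transformers release with _supports_flex_attn.
import transformers
from transformers.models.auto import modeling_auto


def pick_transformers_fallback(model_type: str) -> str:
    """Return which causal-LM wrapper the fallback path would use for a model type."""
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type)
    if class_name is None:
        return "CausalLM"  # unknown architecture: classic padded fallback

    transformers_model_class = getattr(transformers, class_name, None)
    # Architectures that support flex attention can run with ragged (unpadded)
    # batches, so the flash-style Transformers wrapper can be used instead.
    if getattr(transformers_model_class, "_supports_flex_attn", False):
        return "TransformersFlashCausalLM"
    return "CausalLM"


if __name__ == "__main__":
    print(pick_transformers_fallback("llama"))  # e.g. "TransformersFlashCausalLM"
```

The diff's version indexes the mapping directly and reads the attribute without guards; the `get`/`getattr` checks here only keep the sketch standalone.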

server/text_generation_server/models/globals.py

@@ -68,6 +68,3 @@ def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX

-
-USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
-USE_CUSTOM_MODELING = USE_CUSTOM_MODELING.lower() == "true" or USE_CUSTOM_MODELING == "1"
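
For reference, the two deleted lines in globals.py followed a common boolean environment-flag pattern (enabled by default, accepting "true" in any case or "1"). A standalone sketch of that pattern; the `env_flag` helper is hypothetical and not part of TGI:

```python
import os


def env_flag(name: str, default: str = "true") -> bool:
    """Parse a boolean environment variable, treating 'true' (any case) or '1' as enabled."""
    raw = os.getenv(name, default)
    return raw.lower() == "true" or raw == "1"


# The removed flag was parsed essentially this way before being dropped:
USE_CUSTOM_MODELING = env_flag("USE_CUSTOM_MODELING")
```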