remove flag

Cyril Vallez 2025-01-17 12:26:51 +00:00
parent 266377b328
commit 32488c1a11
2 changed files with 28 additions and 45 deletions

text_generation_server/models/__init__.py

@@ -30,7 +30,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -368,7 +368,6 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
-    global USE_CUSTOM_MODELING

     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -376,24 +375,11 @@ def get_model(
     model_type = config_dict.get("model_type", None)

     transformers_causal_lm_class = CausalLM
-    if not USE_CUSTOM_MODELING:
-        logger.info(
-            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
-        )
-        # Fast transformers path
-        transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-        if transformers_model_class._supports_flex_attn:
-            logger.info(
-                f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
-                "batch and sequence length). All TGI's batching/caching optimizations are enabled."
-            )
-            transformers_causal_lm_class = TransformersFlashCausalLM
-        else:
-            logger.info(
-                f"Transformers' {model_type} implementation does not supports ragged tensors format. Will use classic "
-                "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
-            )

     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
@@ -613,7 +599,7 @@ def get_model(
         )

     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
@@ -678,7 +664,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -729,7 +715,7 @@ def get_model(
                 batch_class=CausalLMBatchKeysLast,
             )
     elif model_type == GPT2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -765,7 +751,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -801,7 +787,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -839,7 +825,7 @@ def get_model(
         )

     elif model_type == PHI:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -862,7 +848,7 @@ def get_model(
         )

     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -886,7 +872,7 @@ def get_model(
         )

     elif model_type == "phi-msft":
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -908,7 +894,7 @@ def get_model(
         or model_type == PHI3
         or model_type == GRANITE
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -934,7 +920,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     if model_type == GEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -960,7 +946,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -987,7 +973,7 @@ def get_model(
         )

     if model_type == COHERE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1012,7 +998,7 @@ def get_model(
         )

     if model_type == DBRX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1041,7 +1027,7 @@ def get_model(

     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+            if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1090,7 +1076,7 @@ def get_model(
         )

     if model_type == MISTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1115,7 +1101,7 @@ def get_model(
         )

     if model_type == MIXTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1140,7 +1126,7 @@ def get_model(
         )

     if model_type == STARCODER2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1167,7 +1153,7 @@ def get_model(
         )

     if model_type == QWEN2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1219,7 +1205,7 @@ def get_model(
             },
         )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1243,7 +1229,7 @@ def get_model(
             lora_adapter_ids=lora_adapter_ids,
         )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1259,7 +1245,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Idefics2ForConditionalGeneration,
@@ -1277,7 +1263,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1296,7 +1282,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,

text_generation_server/models/globals.py

@@ -68,6 +68,3 @@ def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
-
-USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
-USE_CUSTOM_MODELING = USE_CUSTOM_MODELING.lower() == "true" or USE_CUSTOM_MODELING == "1"