simplify check

This commit is contained in:
Cyril Vallez 2025-01-15 18:05:55 +00:00
parent 44b367937b
commit 266377b328
No known key found for this signature in database

View File

@ -382,12 +382,8 @@ def get_model(
)
transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
# Ugly check but works in the meantime
model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py")
with open(model_path) as file:
has_fa2_class = f"FlashAttention2(" in file.read()
if transformers_model_class._supports_flash_attn_2 and not has_fa2_class:
if transformers_model_class._supports_flex_attn:
logger.info(
f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
"batch and sequence length). All TGI's batching/caching optimizations are enabled."