diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a1359212..c97b0006 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -382,12 +382,8 @@ def get_model( ) transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) - # Ugly check but works in the meantime - model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py") - with open(model_path) as file: - has_fa2_class = f"FlashAttention2(" in file.read() - if transformers_model_class._supports_flash_attn_2 and not has_fa2_class: + if transformers_model_class._supports_flex_attn: logger.info( f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for " "batch and sequence length). All TGI's batching/caching optimizations are enabled."