Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
simplify check
commit 266377b328
parent 44b367937b
@@ -382,12 +382,8 @@ def get_model(
         )
         transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-        # Ugly check but works in the meantime
-        model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py")
-        with open(model_path) as file:
-            has_fa2_class = f"FlashAttention2(" in file.read()

-        if transformers_model_class._supports_flash_attn_2 and not has_fa2_class:
+        if transformers_model_class._supports_flex_attn:
             logger.info(
                 f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
                 "batch and sequence length). All TGI's batching/caching optimizations are enabled."
             )
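For context, a minimal self-contained sketch of the check before and after this commit. The helper names are hypothetical; the private attributes (_supports_flash_attn_2, _supports_flex_attn) and the class-name mapping are the ones the hunk itself uses, so this assumes a transformers version that exposes them.

    # Hypothetical sketch of the old vs. new check from the hunk above.
    # Assumes a transformers version exposing the private class attributes
    # `_supports_flash_attn_2` and `_supports_flex_attn` referenced in the diff.
    import os

    import transformers
    from transformers.models.auto import modeling_auto


    def _causal_lm_class(model_type: str):
        # Resolve e.g. "llama" -> transformers.LlamaForCausalLM
        return getattr(
            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
        )


    def supports_ragged_tensors_old(model_type: str) -> bool:
        """Old check: scan the modeling source file for a dedicated
        FlashAttention2 class ("ugly check but works in the meantime")."""
        model_class = _causal_lm_class(model_type)
        model_path = os.path.join(
            os.path.dirname(transformers.__file__),
            "models",
            model_type,
            f"modeling_{model_type}.py",
        )
        with open(model_path) as file:
            has_fa2_class = "FlashAttention2(" in file.read()
        return model_class._supports_flash_attn_2 and not has_fa2_class

    def supports_ragged_tensors_new(model_type: str) -> bool:
        """New check: a single attribute lookup, no file I/O."""
        return _causal_lm_class(model_type)._supports_flex_attn

The simplification replaces the source-file scan with a direct class-attribute lookup, presumably because _supports_flex_attn already encodes the capability the old heuristic was trying to detect.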