device check

Cyril Vallez 2025-01-20 15:55:31 +01:00
parent 52afdcc281
commit 9af3ea4b70
2 changed files with 8 additions and 7 deletions


@@ -382,7 +382,10 @@ def get_model(
         transformers_model_class = getattr(
             transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
         )
-        if transformers_model_class._supports_flex_attn:
+        accelerator_available = torch.cuda.is_available() or (
+            hasattr(torch, "xpu") and torch.xpu.is_available()
+        )
+        if transformers_model_class._supports_flex_attn and accelerator_available:
             transformers_causal_lm_class = TransformersFlashCausalLM
             if (
                 not FLASH_ATTENTION
@@ -390,7 +393,7 @@ def get_model(
                 and len(lora_adapter_ids) > 0
             ):
                 raise ValueError(
-                    "Transformers backend AutoModel do not support `lora_adapter_ids`."
+                    "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
                 )
 
     quantization_config = config_dict.get("quantization_config", None)
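
The gist of this hunk: TransformersFlashCausalLM is now only selected when an accelerator is actually present. A minimal standalone sketch of the same probe, using only public torch APIs (the helper name is illustrative, not part of the patch):

import torch

def accelerator_available() -> bool:
    # CUDA reports available on NVIDIA GPUs (and on ROCm builds of torch).
    if torch.cuda.is_available():
        return True
    # Older torch builds may not ship a torch.xpu module at all,
    # hence the hasattr guard before probing Intel XPU support.
    return hasattr(torch, "xpu") and torch.xpu.is_available()

print(accelerator_available())  # False on a CPU-only machine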


@@ -111,11 +111,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device = torch.device("xpu")
             dtype = torch.float16 if dtype is None else dtype
         else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
-
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            raise ValueError(
+                "Flash `Transformers` modeling backend is not available on cpu."
+            )
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
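
Together with the first hunk, this removes the CPU fallback entirely: instead of silently running the flash backend on cpu with float32, initialization now fails fast. A rough sketch of the resulting device/dtype selection, assuming the surrounding __init__ shape (select_device_and_dtype is a hypothetical name for illustration):

import torch

def select_device_and_dtype(dtype: torch.dtype | None = None) -> tuple[torch.device, torch.dtype]:
    # Prefer CUDA, then Intel XPU; there is no cpu branch any more.
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16 if dtype is None else dtype
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu"), torch.float16 if dtype is None else dtype
    raise ValueError("Flash `Transformers` modeling backend is not available on cpu.")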