device check

Cyril Vallez 2025-01-20 15:55:31 +01:00
parent 52afdcc281
commit 9af3ea4b70
2 changed files with 8 additions and 7 deletions


@@ -382,7 +382,10 @@ def get_model(
     transformers_model_class = getattr(
         transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
     )
-    if transformers_model_class._supports_flex_attn:
+    accelerator_available = torch.cuda.is_available() or (
+        hasattr(torch, "xpu") and torch.xpu.is_available()
+    )
+    if transformers_model_class._supports_flex_attn and accelerator_available:
         transformers_causal_lm_class = TransformersFlashCausalLM
     if (
         not FLASH_ATTENTION
@@ -390,7 +393,7 @@ def get_model(
         and len(lora_adapter_ids) > 0
     ):
         raise ValueError(
-            "Transformers backend AutoModel do not support `lora_adapter_ids`."
+            "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
         )
 
     quantization_config = config_dict.get("quantization_config", None)
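For reference, a minimal standalone sketch of the accelerator check this hunk adds (the helper name is illustrative, not part of the commit):

import torch

def accelerator_available() -> bool:
    # True when a usable CUDA device or Intel XPU device is present; the
    # hasattr guard keeps torch builds without the xpu module from raising.
    return torch.cuda.is_available() or (
        hasattr(torch, "xpu") and torch.xpu.is_available()
    )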


@@ -111,11 +111,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device = torch.device("xpu")
             dtype = torch.float16 if dtype is None else dtype
         else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
-
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            raise ValueError(
+                "Flash `Transformers` modeling backend is not available on cpu."
+            )
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
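As a simplified sketch of the device selection this file now enforces (the dtype defaults and the bare "cuda" device string are assumptions for illustration; the real initializer also handles the device rank and quantization):

import torch

def resolve_device_and_dtype(dtype=None):
    # Prefer CUDA, then XPU; with no accelerator there is no CPU fallback,
    # so the flash backend refuses to start instead of running float32 on CPU.
    if torch.cuda.is_available():
        return torch.device("cuda"), (torch.float16 if dtype is None else dtype)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu"), (torch.float16 if dtype is None else dtype)
    raise ValueError(
        "Flash `Transformers` modeling backend is not available on cpu."
    )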