Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
device check
This commit is contained in:
parent 52afdcc281
commit 9af3ea4b70
@@ -382,7 +382,10 @@ def get_model(
     transformers_model_class = getattr(
         transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
     )
-    if transformers_model_class._supports_flex_attn:
+    accelerator_available = torch.cuda.is_available() or (
+        hasattr(torch, "xpu") and torch.xpu.is_available()
+    )
+    if transformers_model_class._supports_flex_attn and accelerator_available:
         transformers_causal_lm_class = TransformersFlashCausalLM
     if (
         not FLASH_ATTENTION
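
For context, a minimal standalone sketch of the device check this hunk introduces, assuming only that torch is importable. The boolean expression mirrors the `accelerator_available` condition in the diff; the helper function wrapper and the __main__ usage are illustrative additions, not TGI code.

import torch


def accelerator_available() -> bool:
    """Return True if a CUDA/ROCm device or an Intel XPU device is usable."""
    # The hasattr() guard keeps this safe on torch builds that ship without torch.xpu.
    return torch.cuda.is_available() or (
        hasattr(torch, "xpu") and torch.xpu.is_available()
    )


if __name__ == "__main__":
    print("accelerator available:", accelerator_available())
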
@@ -390,7 +393,7 @@ def get_model(
         and len(lora_adapter_ids) > 0
     ):
         raise ValueError(
-            "Transformers backend AutoModel do not support `lora_adapter_ids`."
+            "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
         )

     quantization_config = config_dict.get("quantization_config", None)
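
A hedged sketch of the guard whose error message this hunk rewords. `check_lora_support` and its parameters are illustrative names rather than TGI's API, and the hunk boundary elides one condition of the real `if`, so this only shows the general shape: requested LoRA adapters are rejected when the flash path is unavailable.

from typing import List


def check_lora_support(flash_attention: bool, lora_adapter_ids: List[str]) -> None:
    # No flash attention plus requested LoRA adapters is an unsupported combination.
    if not flash_attention and len(lora_adapter_ids) > 0:
        raise ValueError(
            "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
        )


if __name__ == "__main__":
    check_lora_support(flash_attention=True, lora_adapter_ids=["adapter-a"])
    try:
        check_lora_support(flash_attention=False, lora_adapter_ids=["adapter-a"])
    except ValueError as err:
        print(err)
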
@@ -111,11 +111,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device = torch.device("xpu")
             dtype = torch.float16 if dtype is None else dtype
         else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
-
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            raise ValueError(
+                "Flash `Transformers` modeling backend is not available on cpu."
+            )

         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
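
A simplified, hedged sketch of the device and dtype selection in TransformersFlashCausalLM.__init__ after this hunk: CUDA and XPU default to float16, and anything else now raises instead of falling back to CPU with float32. `select_device_and_dtype` and the `cuda:0` branch are illustrative; only the XPU branch and the CPU error appear verbatim in the diff.

from typing import Optional, Tuple

import torch


def select_device_and_dtype(
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.device, torch.dtype]:
    if torch.cuda.is_available():
        # Assumed branch: the CUDA case sits outside this hunk's context lines.
        device = torch.device("cuda:0")
        dtype = torch.float16 if dtype is None else dtype
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
        dtype = torch.float16 if dtype is None else dtype
    else:
        # After this commit the CPU fallback is gone; the backend refuses to start.
        raise ValueError(
            "Flash `Transformers` modeling backend is not available on cpu."
        )
    return device, dtype
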