diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 5069fff6..612ad8b3 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -382,7 +382,10 @@ def get_model(
         transformers_model_class = getattr(
             transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
         )
-        if transformers_model_class._supports_flex_attn:
+        accelerator_available = torch.cuda.is_available() or (
+            hasattr(torch, "xpu") and torch.xpu.is_available()
+        )
+        if transformers_model_class._supports_flex_attn and accelerator_available:
             transformers_causal_lm_class = TransformersFlashCausalLM
     if (
         not FLASH_ATTENTION
@@ -390,7 +393,7 @@
         and len(lora_adapter_ids) > 0
     ):
         raise ValueError(
-            "Transformers backend AutoModel do not support `lora_adapter_ids`."
+            "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
         )
 
     quantization_config = config_dict.get("quantization_config", None)
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 17f47e5e..647fabc2 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -111,11 +111,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device = torch.device("xpu")
             dtype = torch.float16 if dtype is None else dtype
         else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
-
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
+            raise ValueError(
+                "Flash `Transformers` modeling backend is not available on cpu."
+            )
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,