diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index d8a923f02..6efcab98d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -201,9 +201,8 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or (
-    hasattr(torch, "xpu") and torch.xpu.is_available()
-)
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
+
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 8fc4ad243..8c21ed58c 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -12,7 +12,7 @@ from text_generation_server.utils import initialize_torch_distributed
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
-
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -115,8 +115,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index 850addf11..280fa0bd3 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -14,6 +14,7 @@ from text_generation_server.layers.attention import paged_attention, attention,
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 import torch.nn.functional as F
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -174,8 +175,11 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
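
For reference, the device-selection order that both model classes gain from this patch can be sketched on its own as below. This is an illustrative sketch, not part of the patch: `SYSTEM` is hard-coded here to stand in for the value resolved by `text_generation_server.utils.import_utils` at import time, and `select_device` is a hypothetical helper that only mirrors the branching inside the two `__init__` methods.

# Illustrative sketch: CUDA first, then XPU when running under the IPEX backend,
# otherwise CPU. Not the actual server code; SYSTEM is normally detected at import time.
import torch

SYSTEM = "ipex"  # assumed value for illustration


def select_device(rank: int) -> torch.device:
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    elif SYSTEM == "ipex":
        # IPEX covers both Intel XPU devices and CPU-only installs, so probe
        # for an XPU and fall back to CPU instead of raising.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device(f"xpu:{rank}")
        return torch.device("cpu")
    raise ValueError("No supported accelerator found")


print(select_device(rank=0))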