IPEX CPU could also be supported in this function

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-04-13 20:49:35 -07:00
parent 50282e3cc1
commit 74ad8ed300
3 changed files with 14 additions and 8 deletions

View File

@@ -201,9 +201,8 @@ except ImportError as e:
if MAMBA_AVAILABLE:
__all__.append(Mamba)
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or (
hasattr(torch, "xpu") and torch.xpu.is_available()
)
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
try:
from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM,

View File

@@ -12,7 +12,7 @@ from text_generation_server.utils import initialize_torch_distributed
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
@@ -115,8 +115,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = default_dtype if dtype is None else dtype
else:
raise ValueError(

View File

@@ -14,6 +14,7 @@ from text_generation_server.layers.attention import paged_attention, attention,
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
import torch.nn.functional as F
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
@@ -174,8 +175,11 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = default_dtype if dtype is None else dtype
else:
raise ValueError(