Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 13:52:07 +00:00)
ipex on CPU could also be supported in these functions
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit 74ad8ed300
parent 50282e3cc1
@@ -201,9 +201,8 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or (
-    hasattr(torch, "xpu") and torch.xpu.is_available()
-)
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
 
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
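For context on the new condition: FLASH_TRANSFORMERS_BACKEND is now enabled whenever CUDA is available or SYSTEM resolves to "ipex", which covers IPEX on both XPU and CPU. The sketch below only illustrates that check; the detect_system helper is hypothetical, and the real SYSTEM value comes from text_generation_server.utils.import_utils.

import torch

# Hypothetical helper, only to illustrate the values SYSTEM can take.
# The real detection lives in text_generation_server.utils.import_utils.
def detect_system() -> str:
    if torch.cuda.is_available():
        return "cuda"
    try:
        import intel_extension_for_pytorch  # noqa: F401  # IPEX installed (XPU or CPU build)
        return "ipex"
    except ImportError:
        return "cpu"

SYSTEM = detect_system()

# Condition from the hunk above: the flash transformers backend is enabled
# on CUDA, or whenever IPEX is installed, even without an XPU device.
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"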
@@ -12,7 +12,7 @@ from text_generation_server.utils import initialize_torch_distributed
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -115,8 +115,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
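Pulled out of the class for readability, the sketch below restates the device and dtype selection introduced in this hunk. The names rank, dtype and default_dtype follow the diff; the standalone function itself is illustrative, not part of the repository.

import torch

def select_device_and_dtype(rank, system, dtype=None, default_dtype=torch.float16):
    # Sketch of the selection logic above: CUDA first, then IPEX with an
    # XPU device if one exists, otherwise IPEX running on the CPU.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        dtype = default_dtype if dtype is None else dtype
    elif system == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
        else:
            device = torch.device("cpu")
        dtype = default_dtype if dtype is None else dtype
    else:
        raise ValueError("Neither CUDA nor IPEX is available")
    return device, dtype

The net effect is that an IPEX install without an XPU no longer hits the error branch; it simply runs on the CPU device.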
@@ -14,6 +14,7 @@ from text_generation_server.layers.attention import paged_attention, attention,
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 import torch.nn.functional as F
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -174,8 +175,11 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif hasattr(torch, "xpu") and torch.xpu.is_available():
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
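The VLM hunk applies the same fallback as the causal-LM one. A quick, illustrative environment check (not part of the diff) shows which branch the new logic would take on a given machine:

import torch

# Mirrors the runtime checks used in both hunks above.
has_cuda = torch.cuda.is_available()
has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()

print("CUDA available:", has_cuda)
print("XPU available:", has_xpu)
print("Would fall back to CPU under SYSTEM == 'ipex':", not has_cuda and not has_xpu)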