Enable Qwen2 on XPU (with CPU fallback under IPEX)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-06-27 06:01:07 -07:00
parent 886bfab23d
commit 6982f9bcb1

View File

@@ -19,6 +19,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
@@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashQwen2 is only available on GPU")