mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
enable gemma/gemma2/phi in intel platform
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
af16320e66
commit
36077d8ff9
@ -14,6 +14,7 @@ def attention(
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
):
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return ipex.llm.functional.varlen_attention(
|
||||
@ -28,7 +29,7 @@ def attention(
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma is only available on GPU")
|
||||
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma2 is only available on GPU")
|
||||
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashPhi is only available on GPU")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user