mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
enable gemma/gemma2/phi in intel platform
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent af16320e66
commit 36077d8ff9
@@ -14,6 +14,7 @@ def attention(
     max_s,
     softmax_scale,
     window_size_left=-1,
+    causal=True,
 ):
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
@@ -28,7 +29,7 @@ def attention(
         0.0,
         softmax_scale,
         False,
-        True,
+        causal,
         False,
         None,
     )

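The hunks above drop the hard-coded True that was being passed as varlen_attention's is_causal argument and thread a causal keyword through the wrapper instead, defaulting to True so existing callers keep their behaviour. Below is a minimal, self-contained sketch of the same plumbing pattern; torch's scaled_dot_product_attention stands in for ipex.llm.functional.varlen_attention purely for illustration, and every name except softmax_scale and causal is a placeholder rather than code from this commit.

import torch
import torch.nn.functional as F


def attention(q, k, v, softmax_scale, causal=True):
    # Forward the caller's choice instead of a hard-coded True; torch SDPA
    # is only a stand-in for ipex.llm.functional.varlen_attention here.
    return F.scaled_dot_product_attention(
        q, k, v, scale=softmax_scale, is_causal=causal
    )


q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
masked = attention(q, k, v, softmax_scale=16**-0.5)                   # default: causal mask on
unmasked = attention(q, k, v, softmax_scale=16**-0.5, causal=False)   # mask disabled
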
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)

@@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma is only available on GPU")

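The seven added lines give the Gemma loader a second hardware path: when SYSTEM resolves to "ipex", it prefers an Intel XPU (defaulting to float16) if torch.xpu reports one, and otherwise falls back to CPU with bfloat16. Pulled out into a standalone helper it reads as follows; the helper name and the idea of extracting it from __init__ are mine, while the branching itself is exactly what the hunk adds.

from typing import Optional

import torch

from text_generation_server.utils.import_utils import SYSTEM


def select_device_and_dtype(rank: int, dtype: Optional[torch.dtype]):
    # Same fallback chain as the hunk above: CUDA first, then Intel XPU via
    # the ipex backend, then CPU; the default dtype depends on the device picked.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        dtype = torch.bfloat16 if dtype is None else dtype
    elif SYSTEM == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.bfloat16 if dtype is None else dtype
    else:
        raise NotImplementedError("FlashGemma is only available on GPU")
    return device, dtype

The same seven lines are repeated verbatim in the Gemma2 and Phi loaders below.
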
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)

@@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma2 is only available on GPU")

@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)

@@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashPhi is only available on GPU")

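Across the three model loaders the ipex branch is identical; the only per-model difference is the default dtype on CUDA. Summarised from the hunks above (the dictionary form is just for readability, the real code branches inline in each class):

import torch

# Default dtype used when the caller leaves dtype=None, per model class and device.
DEFAULT_DTYPES = {
    ("FlashGemma", "cuda"): torch.bfloat16,
    ("FlashGemma2", "cuda"): torch.bfloat16,
    ("FlashPhi", "cuda"): torch.float16,
    # ipex backend, shared by all three classes:
    ("any", "xpu"): torch.float16,   # Intel GPU
    ("any", "cpu"): torch.bfloat16,  # CPU fallback when no XPU is present
}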