enable gemma/gemma2/phi on the Intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2024-06-27 19:33:17 -07:00
parent af16320e66
commit 36077d8ff9
4 changed files with 26 additions and 1 deletion


@@ -14,6 +14,7 @@ def attention(
     max_s,
     softmax_scale,
     window_size_left=-1,
+    causal=True,
 ):
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
@@ -28,7 +29,7 @@ def attention(
         0.0,
         softmax_scale,
         False,
-        True,
+        causal,
         False,
         None,
     )
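This hunk adds a `causal` parameter to the ipex attention wrapper and forwards it into the ipex.llm.functional.varlen_attention call, replacing an argument that was previously hardcoded to True (presumably its is_causal flag). A minimal usage sketch follows; the leading parameter names (q, k, v, out, cu_seqlen) are assumptions borrowed from the other TGI attention backends and are not shown in this diff:

    # Hedged sketch: calling the patched wrapper. Names before max_s are
    # assumed from the CUDA/ROCm backends, not taken from this diff.
    attention(
        q,              # query tensor for the packed varlen batch
        k,              # key tensor
        v,              # value tensor
        out,            # preallocated output buffer
        cu_seqlen,      # cumulative sequence lengths
        max_s,          # longest sequence in the batch
        softmax_scale,  # usually head_dim ** -0.5
        window_size_left=-1,  # sliding window is not supported on this backend
        causal=False,         # now configurable instead of hardcoded True
    )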


@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
@@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma is only available on GPU")


@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
@@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma2 is only available on GPU")


@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
@@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashPhi is only available on GPU")