diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index bfab0119..7f086b68 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -14,6 +14,7 @@ def attention( max_s, softmax_scale, window_size_left=-1, + causal=True, ): # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return ipex.llm.functional.varlen_attention( @@ -28,7 +29,7 @@ def attention( 0.0, softmax_scale, False, - True, + causal, False, None, ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index aa1ae9ac..7e2b8780 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.bfloat16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGemma is only available on GPU") diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py index 9608113b..86cfc7e2 100644 --- a/server/text_generation_server/models/flash_gemma2.py +++ b/server/text_generation_server/models/flash_gemma2.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.bfloat16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGemma2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 7e108d05..a530d1c3 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -14,6 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashPhi is only available on GPU")