diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index aa43107f..7f8268a9 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -768,7 +768,10 @@ class FlashCausalLM(Model):
         empty_cache()

         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = BLOCK_SIZE // element_size
+        if SYSTEM == "ipex" and device.type == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size

         if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py
index 75c7203a..2d0f9fcc 100644
--- a/server/text_generation_server/models/flash_gpt2.py
+++ b/server/text_generation_server/models/flash_gpt2.py
@@ -37,9 +37,10 @@ class FlashGPT2(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")

diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 76c522e3..9366706f 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -37,9 +37,10 @@ class FlashLlama(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 78a09cf5..16778ada 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -41,9 +41,10 @@ class BaseFlashMistral(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 9c82bf52..87ae570c 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -36,9 +36,10 @@ class FlashNeoXSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index e8087f23..6ed1f6f7 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -37,9 +37,10 @@ class FlashRWSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 83a6b92c..ab1e4516 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -40,9 +40,10 @@ class FlashSantacoderSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
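
Taken together, the hunks above apply one rule across all flash models on IPEX: XPU devices keep float16 as the default dtype, the CPU fallback now defaults to bfloat16, and the KV-cache split factor becomes 1 on XPU. Below is a minimal standalone sketch of that logic, assuming SYSTEM and BLOCK_SIZE mirror the constants used in these modules; the helper names and the BLOCK_SIZE value are illustrative, not part of the repository.

import torch

SYSTEM = "ipex"   # assumption: normally resolved by the server's import utilities
BLOCK_SIZE = 16   # assumption: paged-attention block size referenced in flash_causal_lm.py


def ipex_device_and_dtype(rank: int, dtype: torch.dtype | None = None):
    # Mirrors the per-model __init__ change: float16 stays the default on XPU,
    # while the CPU fallback now defaults to bfloat16.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device(f"xpu:{rank}")
        dtype = torch.float16 if dtype is None else dtype
    else:
        device = torch.device("cpu")
        dtype = torch.bfloat16 if dtype is None else dtype
    return device, dtype


def kv_cache_x(device: torch.device, dtype: torch.dtype) -> int:
    # Mirrors the flash_causal_lm.py hunk: XPU uses a split factor of 1,
    # every other device keeps BLOCK_SIZE // element_size.
    element_size = torch.tensor([], dtype=dtype).element_size()
    if SYSTEM == "ipex" and device.type == "xpu":
        return 1
    return BLOCK_SIZE // element_size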