From 58848cb471efcc2b8622025158bca7c024a8cd93 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Mon, 14 Oct 2024 09:28:49 -0700 Subject: [PATCH] feat: enable pytorch xpu support for non-attention models (#2561) XPU backend is available natively (without IPEX) in pytorch starting from pytorch 2.4. This commit extends TGI to cover the case when user has XPU support thru pytorch 2.4, but does not have IPEX installed. Models which don't require attention can work. For attention required models more work is needed to provide attention implementation. Tested with the following models: * teknium/OpenHermes-2.5-Mistral-7B * bigscience/bloom-560m * google/gemma-7b * google/flan-t5-xxl Signed-off-by: Dmitry Rogozhkin --- .../models/causal_lm.py | 26 +++++++++++-------- .../models/seq2seq_lm.py | 25 +++++++++++------- .../utils/import_utils.py | 5 ++++ 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 28534d0f..de2c0651 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -517,14 +517,13 @@ class CausalLM(Model): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = default_dtype if dtype is None else dtype + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 if dtype is None else dtype @@ -593,8 +592,14 @@ class CausalLM(Model): if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") + device_count = 0 if torch.cuda.is_available(): device = torch.device("cuda") + device_count = torch.cuda.device_count() + dtype = torch.float16 if dtype is None else dtype + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + device_count = torch.xpu.device_count() dtype = torch.float16 if dtype is None else dtype else: if quantize: @@ -616,18 +621,17 @@ class CausalLM(Model): torch_dtype=dtype, device_map=( "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 + if device_count > 1 else None ), load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, ) if ( - torch.cuda.is_available() - and torch.cuda.device_count() == 1 + device_count == 1 and quantize != "bitsandbytes" ): - model = model.cuda() + model = model.to(device) if tokenizer.pad_token_id is None: if model.config.pad_token_id is not None: diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 04d4c28b..94f87d02 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -558,14 +558,13 @@ class Seq2SeqLM(Model): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = default_dtype if dtype is None else dtype + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 if dtype is None else dtype @@ -630,8 +629,14 @@ class Seq2SeqLM(Model): if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") + device_count = 0 if torch.cuda.is_available(): device = torch.device("cuda") + device_count = torch.cuda.device_count() + dtype = torch.float16 if dtype is None else dtype + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + device_count = torch.xpu.device_count() dtype = torch.float16 if dtype is None else dtype else: if quantize: @@ -646,14 +651,14 @@ class Seq2SeqLM(Model): torch_dtype=dtype, device_map=( "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 + if device_count > 1 else None ), load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() + if device_count == 1: + model = model.to(device) tokenizer = AutoTokenizer.from_pretrained( model_id, diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 782b4f15..b693258c 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -66,6 +66,11 @@ elif is_ipex_available(): empty_cache = noop synchronize = noop get_free_memory = get_cpu_free_memory +elif hasattr(torch, "xpu") and torch.xpu.is_available(): + SYSTEM = "xpu" + empty_cache = torch.xpu.empty_cache + synchronize = torch.xpu.synchronize + get_free_memory = get_xpu_free_memory else: SYSTEM = "cpu"