From d0463ce151d06dee1d3314faf4e018f57579271e Mon Sep 17 00:00:00 2001
From: zhangsibo1129
Date: Mon, 9 Oct 2023 15:56:29 +0800
Subject: [PATCH] add NPU support

---
 server/text_generation_server/cache.py          |  4 ++++
 server/text_generation_server/interceptor.py    |  3 +++
 server/text_generation_server/models/bloom.py   |  4 ++++
 .../text_generation_server/models/causal_lm.py  | 12 +++++++++++-
 .../text_generation_server/models/galactica.py  |  4 ++++
 .../text_generation_server/models/gpt_neox.py   |  4 ++++
 .../text_generation_server/models/idefics.py    |  4 ++++
 .../models/idefics_causal_lm.py                 | 12 +++++++++++-
 server/text_generation_server/models/mpt.py     |  4 ++++
 server/text_generation_server/models/opt.py     |  4 ++++
 server/text_generation_server/models/rw.py      |  6 ++++++
 .../models/santacoder.py                        |  4 ++++
 .../models/seq2seq_lm.py                        | 12 +++++++++++-
 server/text_generation_server/models/t5.py      |  4 ++++
 .../text_generation_server/utils/__init__.py    |  2 ++
 .../utils/import_utils.py                       | 18 ++++++++++++++++++
 16 files changed, 98 insertions(+), 3 deletions(-)
 create mode 100644 server/text_generation_server/utils/import_utils.py

diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py
index 4504733e..7b59d445 100644
--- a/server/text_generation_server/cache.py
+++ b/server/text_generation_server/cache.py
@@ -4,6 +4,8 @@ from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
 
+from text_generation_server.utils import is_torch_npu_available
+
 B = TypeVar("B", bound=Batch)
 
 
@@ -24,6 +26,8 @@ class Cache:
             del batch
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            if is_torch_npu_available():
+                torch.npu.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py
index 725105f3..e3314401 100644
--- a/server/text_generation_server/interceptor.py
+++ b/server/text_generation_server/interceptor.py
@@ -6,6 +6,7 @@ from grpc_status import rpc_status
 from grpc_interceptor.server import AsyncServerInterceptor
 from loguru import logger
 from typing import Callable, Any
+from text_generation_server.utils import is_torch_npu_available
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
@@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            if is_torch_npu_available():
+                torch.npu.empty_cache()
 
             await context.abort_with_status(
                 rpc_status.to_status(
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 8e8daad3..4cbe90fb 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    is_torch_npu_available,
 )
 
 
@@ -49,6 +50,9 @@ class BLOOMSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index fccfb0f8..c5fbebbf 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -16,7 +16,12 @@ from text_generation_server.models.types import (
     TopTokens,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import (
+    NextTokenChooser,
+    StoppingCriteria,
+    Sampling,
+    is_torch_npu_available,
+)
 
 tracer = trace.get_tracer(__name__)
 
@@ -487,6 +492,9 @@ class CausalLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -513,6 +521,8 @@ class CausalLM(Model):
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        elif is_torch_npu_available():
+            model = model.npu()
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index b296c96e..5d7f3964 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    is_torch_npu_available,
 )
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
@@ -172,6 +173,9 @@ class GalacticaSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index d4c64dfe..9dc4b050 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    is_torch_npu_available,
 )
 
 
@@ -31,6 +32,9 @@ class GPTNeoxSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index fa23d1f9..5455761d 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -22,6 +22,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    is_torch_npu_available,
 )
 
 
@@ -40,6 +41,9 @@ class IDEFICSSharded(IdeficsCausalLM):
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 2472caf6..a75fc7b5 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -25,7 +25,12 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import (
+    NextTokenChooser,
+    StoppingCriteria,
+    Sampling,
+    is_torch_npu_available,
+)
 
 import re
 
@@ -584,6 +589,9 @@ class IdeficsCausalLM(Model):
         if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -617,6 +625,8 @@ class IdeficsCausalLM(Model):
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        elif is_torch_npu_available():
+            model = model.npu()
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 19de497c..3bbaff9e 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -18,6 +18,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    is_torch_npu_available,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -50,6 +51,9 @@ class MPTSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index b2b87246..bd1834ac 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -13,6 +13,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    is_torch_npu_available,
 )
 
 
@@ -29,6 +30,9 @@ class OPTSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index 802a4aa6..84ee131b 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import List, Optional, Tuple
 
 from text_generation_server.models import CausalLM
+from text_generation_server.utils import is_torch_npu_available
 
 
 class RW(CausalLM):
@@ -18,6 +19,9 @@ class RW(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -44,6 +48,8 @@ class RW(CausalLM):
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        elif is_torch_npu_available():
+            model = model.npu()
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 7b269d8e..2e20ca2e 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -5,6 +5,7 @@ from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from text_generation_server.models import CausalLM
+from text_generation_server.utils import is_torch_npu_available
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
@@ -25,6 +26,9 @@ class SantaCoder(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index d4d3cd19..0ae987dc 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -15,7 +15,12 @@ from text_generation_server.models.types import (
     TopTokens,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import (
+    NextTokenChooser,
+    StoppingCriteria,
+    Sampling,
+    is_torch_npu_available,
+)
 
 tracer = trace.get_tracer(__name__)
 
@@ -536,6 +541,9 @@ class Seq2SeqLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -555,6 +563,8 @@ class Seq2SeqLM(Model):
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        elif is_torch_npu_available():
+            model = model.npu()
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 161e69ba..eb767292 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
+    is_torch_npu_available,
 )
 
 
@@ -32,6 +33,9 @@ class T5Sharded(Seq2SeqLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif is_torch_npu_available():
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py
index 08ba808d..2b5eb251 100644
--- a/server/text_generation_server/utils/__init__.py
+++ b/server/text_generation_server/utils/__init__.py
@@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
     Sampling,
     Greedy,
 )
+from text_generation_server.utils.import_utils import is_torch_npu_available
 
 __all__ = [
     "convert_file",
@@ -39,4 +40,5 @@ __all__ = [
     "StopSequenceCriteria",
     "FinishReason",
     "Weights",
+    "is_torch_npu_available",
 ]
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
new file mode 100644
index 00000000..db382d77
--- /dev/null
+++ b/server/text_generation_server/utils/import_utils.py
@@ -0,0 +1,18 @@
+import importlib.util
+
+def is_torch_npu_available(check_device=False):
+    "Checks if `torch_npu` is installed and potentially if an NPU is in the environment"
+    if importlib.util.find_spec("torch_npu") is None:
+        return False
+
+    import torch
+    import torch_npu
+
+    if check_device:
+        try:
+            # Will raise a RuntimeError if no NPU is found
+            _ = torch.npu.device_count()
+            return torch.npu.is_available()
+        except RuntimeError:
+            return False
+    return hasattr(torch, "npu") and torch.npu.is_available()
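
Note for reviewers (not part of the patch): the snippet below is a minimal sketch of the device/dtype selection pattern this diff repeats across the model classes, useful for exercising the new is_torch_npu_available helper in isolation. The pick_device_and_dtype name is hypothetical; it assumes the patched text_generation_server package and, for the NPU branch, torch_npu are importable.

from typing import Optional

import torch

from text_generation_server.utils import is_torch_npu_available


def pick_device_and_dtype(rank: int = 0, dtype: Optional[torch.dtype] = None):
    # Mirror the branching added throughout this patch: prefer CUDA,
    # then an Ascend NPU via torch_npu, and fall back to CPU in float32.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}"), torch.float16 if dtype is None else dtype
    if is_torch_npu_available():
        return torch.device(f"npu:{rank}"), torch.float16 if dtype is None else dtype
    return torch.device("cpu"), torch.float32 if dtype is None else dtype


if __name__ == "__main__":
    device, dtype = pick_device_and_dtype()
    print(f"selected device={device}, dtype={dtype}")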