add NPU support

This commit is contained in:
zhangsibo1129 2023-10-09 15:56:29 +08:00
parent 00b8f36fba
commit d0463ce151
16 changed files with 98 additions and 3 deletions

View File

@ -4,6 +4,8 @@ from typing import Dict, Optional, TypeVar
from text_generation_server.models.types import Batch from text_generation_server.models.types import Batch
from text_generation_server.utils import is_torch_npu_available
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
@ -24,6 +26,8 @@ class Cache:
del batch del batch
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
if is_torch_npu_available():
torch.npu.empty_cache()
def clear(self): def clear(self):
keys = list(self.cache.keys()) keys = list(self.cache.keys())

View File

@ -6,6 +6,7 @@ from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger from loguru import logger
from typing import Callable, Any from typing import Callable, Any
from text_generation_server.utils import is_torch_npu_available
class ExceptionInterceptor(AsyncServerInterceptor): class ExceptionInterceptor(AsyncServerInterceptor):
@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
if is_torch_npu_available():
torch.npu.empty_cache()
await context.abort_with_status( await context.abort_with_status(
rpc_status.to_status( rpc_status.to_status(

View File

@ -19,6 +19,7 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
is_torch_npu_available,
) )
@ -49,6 +50,9 @@ class BLOOMSharded(CausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype

View File

@ -16,7 +16,12 @@ from text_generation_server.models.types import (
TopTokens, TopTokens,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
is_torch_npu_available,
)
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -487,6 +492,9 @@ class CausalLM(Model):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
@ -513,6 +521,8 @@ class CausalLM(Model):
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1: if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda() model = model.cuda()
elif is_torch_npu_available():
model = model.npu()
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None: if model.config.pad_token_id is not None:

View File

@ -19,6 +19,7 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
is_torch_npu_available,
) )
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
@ -172,6 +173,9 @@ class GalacticaSharded(CausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype

View File

@ -15,6 +15,7 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
is_torch_npu_available,
) )
@ -31,6 +32,9 @@ class GPTNeoxSharded(CausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype

View File

@ -22,6 +22,7 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
is_torch_npu_available,
) )
@ -40,6 +41,9 @@ class IDEFICSSharded(IdeficsCausalLM):
# 9b seems to work correctly enough in float16, but 80b seems # 9b seems to work correctly enough in float16, but 80b seems
# to be really saturating for f16. # to be really saturating for f16.
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device(f"npu:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype

View File

@ -25,7 +25,12 @@ from text_generation_server.models.types import (
GeneratedText, GeneratedText,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
is_torch_npu_available,
)
import re import re
@ -584,6 +589,9 @@ class IdeficsCausalLM(Model):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
@ -617,6 +625,8 @@ class IdeficsCausalLM(Model):
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1: if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda() model = model.cuda()
elif is_torch_npu_available():
model = model.npu()
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None: if model.config.pad_token_id is not None:

View File

@ -18,6 +18,7 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
is_torch_npu_available,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -50,6 +51,9 @@ class MPTSharded(CausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype

View File

@ -13,6 +13,7 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
is_torch_npu_available,
) )
@ -29,6 +30,9 @@ class OPTSharded(CausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype

View File

@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.utils import is_torch_npu_available
class RW(CausalLM): class RW(CausalLM):
@ -18,6 +19,9 @@ class RW(CausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
@ -44,6 +48,8 @@ class RW(CausalLM):
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1: if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda() model = model.cuda()
elif is_torch_npu_available():
model = model.npu()
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None: if model.config.pad_token_id is not None:

View File

@ -5,6 +5,7 @@ from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.utils import is_torch_npu_available
FIM_PREFIX = "<fim-prefix>" FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>" FIM_MIDDLE = "<fim-middle>"
@ -25,6 +26,9 @@ class SantaCoder(CausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@ -15,7 +15,12 @@ from text_generation_server.models.types import (
TopTokens, TopTokens,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
is_torch_npu_available,
)
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -536,6 +541,9 @@ class Seq2SeqLM(Model):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
@ -555,6 +563,8 @@ class Seq2SeqLM(Model):
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1: if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda() model = model.cuda()
elif is_torch_npu_available():
model = model.npu()
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,

View File

@ -16,6 +16,7 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
is_torch_npu_available,
) )
@ -32,6 +33,9 @@ class T5Sharded(Seq2SeqLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif is_torch_npu_available():
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype

View File

@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
Sampling, Sampling,
Greedy, Greedy,
) )
from text_generation_server.utils.import_utils import is_torch_npu_available
__all__ = [ __all__ = [
"convert_file", "convert_file",
@ -39,4 +40,5 @@ __all__ = [
"StopSequenceCriteria", "StopSequenceCriteria",
"FinishReason", "FinishReason",
"Weights", "Weights",
"is_torch_npu_available",
] ]

View File

@ -0,0 +1,18 @@
import importlib.util
def is_torch_npu_available(check_device=False):
    """Check if `torch_npu` is installed and whether an NPU is available.

    Args:
        check_device: When true, also call `torch.npu.device_count()` and
            treat a `RuntimeError` from it as "no NPU available" instead of
            letting it propagate.

    Returns:
        False when the `torch_npu` package cannot be found or `torch` has no
        `npu` attribute after importing it; otherwise the result of
        `torch.npu.is_available()`.
    """
    # Cheap probe first: avoid importing torch at all when torch_npu is absent.
    if importlib.util.find_spec("torch_npu") is None:
        return False

    import torch
    # Imported for its side effects only; presumably this is what attaches
    # the `torch.npu` namespace -- confirm against the torch_npu docs.
    import torch_npu  # noqa: F401

    # Guard both code paths: without this the `check_device` branch would
    # raise AttributeError (not caught below) if `torch.npu` is missing.
    if not hasattr(torch, "npu"):
        return False

    if check_device:
        try:
            # Will raise a RuntimeError if no NPU is found
            _ = torch.npu.device_count()
            return torch.npu.is_available()
        except RuntimeError:
            return False
    return torch.npu.is_available()