mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
add NPU support
This commit is contained in:
parent
00b8f36fba
commit
d0463ce151
@ -4,6 +4,8 @@ from typing import Dict, Optional, TypeVar
|
||||
|
||||
from text_generation_server.models.types import Batch
|
||||
|
||||
from text_generation_server.utils import is_torch_npu_available
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
|
||||
@ -24,6 +26,8 @@ class Cache:
|
||||
del batch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
|
||||
def clear(self):
|
||||
keys = list(self.cache.keys())
|
||||
|
@ -6,6 +6,7 @@ from grpc_status import rpc_status
|
||||
from grpc_interceptor.server import AsyncServerInterceptor
|
||||
from loguru import logger
|
||||
from typing import Callable, Any
|
||||
from text_generation_server.utils import is_torch_npu_available
|
||||
|
||||
|
||||
class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
|
||||
await context.abort_with_status(
|
||||
rpc_status.to_status(
|
||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
|
||||
@ -49,6 +50,9 @@ class BLOOMSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -16,7 +16,12 @@ from text_generation_server.models.types import (
|
||||
TopTokens,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
Sampling,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -487,6 +492,9 @@ class CausalLM(Model):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -513,6 +521,8 @@ class CausalLM(Model):
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
elif is_torch_npu_available():
|
||||
model = model.npu()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
|
||||
@ -172,6 +173,9 @@ class GalacticaSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -15,6 +15,7 @@ from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
|
||||
@ -31,6 +32,9 @@ class GPTNeoxSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -22,6 +22,7 @@ from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
|
||||
@ -40,6 +41,9 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||
# 9b seems to work correctly enough in float16, but 80b seems
|
||||
# to be really saturating for f16.
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -25,7 +25,12 @@ from text_generation_server.models.types import (
|
||||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
Sampling,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
import re
|
||||
|
||||
@ -584,6 +589,9 @@ class IdeficsCausalLM(Model):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -617,6 +625,8 @@ class IdeficsCausalLM(Model):
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
elif is_torch_npu_available():
|
||||
model = model.npu()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
|
@ -18,6 +18,7 @@ from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
@ -50,6 +51,9 @@ class MPTSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -13,6 +13,7 @@ from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
|
||||
@ -29,6 +30,9 @@ class OPTSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.utils import is_torch_npu_available
|
||||
|
||||
|
||||
class RW(CausalLM):
|
||||
@ -18,6 +19,9 @@ class RW(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -44,6 +48,8 @@ class RW(CausalLM):
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
elif is_torch_npu_available():
|
||||
model = model.npu()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional, List
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.utils import is_torch_npu_available
|
||||
|
||||
FIM_PREFIX = "<fim-prefix>"
|
||||
FIM_MIDDLE = "<fim-middle>"
|
||||
@ -25,6 +26,9 @@ class SantaCoder(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
@ -15,7 +15,12 @@ from text_generation_server.models.types import (
|
||||
TopTokens,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
Sampling,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -536,6 +541,9 @@ class Seq2SeqLM(Model):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -555,6 +563,8 @@ class Seq2SeqLM(Model):
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
elif is_torch_npu_available():
|
||||
model = model.npu()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -16,6 +16,7 @@ from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
|
||||
@ -32,6 +33,9 @@ class T5Sharded(Seq2SeqLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
|
||||
Sampling,
|
||||
Greedy,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import is_torch_npu_available
|
||||
|
||||
__all__ = [
|
||||
"convert_file",
|
||||
@ -39,4 +40,5 @@ __all__ = [
|
||||
"StopSequenceCriteria",
|
||||
"FinishReason",
|
||||
"Weights",
|
||||
"is_torch_npu_available",
|
||||
]
|
||||
|
18
server/text_generation_server/utils/import_utils.py
Normal file
18
server/text_generation_server/utils/import_utils.py
Normal file
@ -0,0 +1,18 @@
|
||||
import importlib.util
|
||||
|
||||
def is_torch_npu_available(check_device=False):
|
||||
"Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
|
||||
if importlib.util.find_spec("torch_npu") is None:
|
||||
return False
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
if check_device:
|
||||
try:
|
||||
# Will raise a RuntimeError if no NPU is found
|
||||
_ = torch.npu.device_count()
|
||||
return torch.npu.is_available()
|
||||
except RuntimeError:
|
||||
return False
|
||||
return hasattr(torch, "npu") and torch.npu.is_available()
|
Loading…
Reference in New Issue
Block a user