mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
add NPU support
This commit is contained in:
parent
00b8f36fba
commit
d0463ce151
@ -4,6 +4,8 @@ from typing import Dict, Optional, TypeVar
|
|||||||
|
|
||||||
from text_generation_server.models.types import Batch
|
from text_generation_server.models.types import Batch
|
||||||
|
|
||||||
|
from text_generation_server.utils import is_torch_npu_available
|
||||||
|
|
||||||
B = TypeVar("B", bound=Batch)
|
B = TypeVar("B", bound=Batch)
|
||||||
|
|
||||||
|
|
||||||
@ -24,6 +26,8 @@ class Cache:
|
|||||||
del batch
|
del batch
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
if is_torch_npu_available():
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
keys = list(self.cache.keys())
|
keys = list(self.cache.keys())
|
||||||
|
@ -6,6 +6,7 @@ from grpc_status import rpc_status
|
|||||||
from grpc_interceptor.server import AsyncServerInterceptor
|
from grpc_interceptor.server import AsyncServerInterceptor
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Callable, Any
|
from typing import Callable, Any
|
||||||
|
from text_generation_server.utils import is_torch_npu_available
|
||||||
|
|
||||||
|
|
||||||
class ExceptionInterceptor(AsyncServerInterceptor):
|
class ExceptionInterceptor(AsyncServerInterceptor):
|
||||||
@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
|||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
if is_torch_npu_available():
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
await context.abort_with_status(
|
await context.abort_with_status(
|
||||||
rpc_status.to_status(
|
rpc_status.to_status(
|
||||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils import (
|
|||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
|
is_torch_npu_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -49,6 +50,9 @@ class BLOOMSharded(CausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device(f"npu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
@ -16,7 +16,12 @@ from text_generation_server.models.types import (
|
|||||||
TopTokens,
|
TopTokens,
|
||||||
)
|
)
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
from text_generation_server.utils import (
|
||||||
|
NextTokenChooser,
|
||||||
|
StoppingCriteria,
|
||||||
|
Sampling,
|
||||||
|
is_torch_npu_available,
|
||||||
|
)
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
@ -487,6 +492,9 @@ class CausalLM(Model):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device("npu")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
@ -513,6 +521,8 @@ class CausalLM(Model):
|
|||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
model = model.npu()
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
if model.config.pad_token_id is not None:
|
if model.config.pad_token_id is not None:
|
||||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils import (
|
|||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
|
is_torch_npu_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
|
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
|
||||||
@ -172,6 +173,9 @@ class GalacticaSharded(CausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device(f"npu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
@ -15,6 +15,7 @@ from text_generation_server.utils import (
|
|||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
|
is_torch_npu_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +32,9 @@ class GPTNeoxSharded(CausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device(f"npu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
@ -22,6 +22,7 @@ from text_generation_server.utils import (
|
|||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
|
is_torch_npu_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -40,6 +41,9 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||||||
# 9b seems to work correctly enough in float16, but 80b seems
|
# 9b seems to work correctly enough in float16, but 80b seems
|
||||||
# to be really saturating for f16.
|
# to be really saturating for f16.
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device(f"npu:{rank}")
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
@ -25,7 +25,12 @@ from text_generation_server.models.types import (
|
|||||||
GeneratedText,
|
GeneratedText,
|
||||||
)
|
)
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
from text_generation_server.utils import (
|
||||||
|
NextTokenChooser,
|
||||||
|
StoppingCriteria,
|
||||||
|
Sampling,
|
||||||
|
is_torch_npu_available,
|
||||||
|
)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -584,6 +589,9 @@ class IdeficsCausalLM(Model):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device("npu")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
@ -617,6 +625,8 @@ class IdeficsCausalLM(Model):
|
|||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
model = model.npu()
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
if model.config.pad_token_id is not None:
|
if model.config.pad_token_id is not None:
|
||||||
|
@ -18,6 +18,7 @@ from text_generation_server.utils import (
|
|||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
|
is_torch_npu_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
@ -50,6 +51,9 @@ class MPTSharded(CausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device(f"npu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
@ -13,6 +13,7 @@ from text_generation_server.utils import (
|
|||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
|
is_torch_npu_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -29,6 +30,9 @@ class OPTSharded(CausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device(f"npu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from text_generation_server.models import CausalLM
|
from text_generation_server.models import CausalLM
|
||||||
|
from text_generation_server.utils import is_torch_npu_available
|
||||||
|
|
||||||
|
|
||||||
class RW(CausalLM):
|
class RW(CausalLM):
|
||||||
@ -18,6 +19,9 @@ class RW(CausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device("npu")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
@ -44,6 +48,8 @@ class RW(CausalLM):
|
|||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
model = model.npu()
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
if model.config.pad_token_id is not None:
|
if model.config.pad_token_id is not None:
|
||||||
|
@ -5,6 +5,7 @@ from typing import Optional, List
|
|||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
|
||||||
from text_generation_server.models import CausalLM
|
from text_generation_server.models import CausalLM
|
||||||
|
from text_generation_server.utils import is_torch_npu_available
|
||||||
|
|
||||||
FIM_PREFIX = "<fim-prefix>"
|
FIM_PREFIX = "<fim-prefix>"
|
||||||
FIM_MIDDLE = "<fim-middle>"
|
FIM_MIDDLE = "<fim-middle>"
|
||||||
@ -25,6 +26,9 @@ class SantaCoder(CausalLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device("npu")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
@ -15,7 +15,12 @@ from text_generation_server.models.types import (
|
|||||||
TopTokens,
|
TopTokens,
|
||||||
)
|
)
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
from text_generation_server.utils import (
|
||||||
|
NextTokenChooser,
|
||||||
|
StoppingCriteria,
|
||||||
|
Sampling,
|
||||||
|
is_torch_npu_available,
|
||||||
|
)
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
@ -536,6 +541,9 @@ class Seq2SeqLM(Model):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device("npu")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
@ -555,6 +563,8 @@ class Seq2SeqLM(Model):
|
|||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
model = model.npu()
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -16,6 +16,7 @@ from text_generation_server.utils import (
|
|||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
|
is_torch_npu_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -32,6 +33,9 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
device = torch.device(f"npu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
|
|||||||
Sampling,
|
Sampling,
|
||||||
Greedy,
|
Greedy,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_file",
|
"convert_file",
|
||||||
@ -39,4 +40,5 @@ __all__ = [
|
|||||||
"StopSequenceCriteria",
|
"StopSequenceCriteria",
|
||||||
"FinishReason",
|
"FinishReason",
|
||||||
"Weights",
|
"Weights",
|
||||||
|
"is_torch_npu_available",
|
||||||
]
|
]
|
||||||
|
18
server/text_generation_server/utils/import_utils.py
Normal file
18
server/text_generation_server/utils/import_utils.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import importlib.util
|
||||||
|
|
||||||
|
def is_torch_npu_available(check_device=False):
|
||||||
|
"Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
|
||||||
|
if importlib.util.find_spec("torch_npu") is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
if check_device:
|
||||||
|
try:
|
||||||
|
# Will raise a RuntimeError if no NPU is found
|
||||||
|
_ = torch.npu.device_count()
|
||||||
|
return torch.npu.is_available()
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
return hasattr(torch, "npu") and torch.npu.is_available()
|
Loading…
Reference in New Issue
Block a user