diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs
index 9dbc83f7..c55f2a07 100644
--- a/launcher/src/env_runtime.rs
+++ b/launcher/src/env_runtime.rs
@@ -7,14 +7,17 @@ pub(crate) struct Env {
     git_sha: &'static str,
     docker_label: &'static str,
     nvidia_env: String,
+    npu_env: String,
 }
 
 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
+        let npu_env = npu_smi();
 
         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
+            npu_env: npu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -31,7 +34,8 @@ impl fmt::Display for Env {
         writeln!(f, "Cargo version: {}", self.cargo_version)?;
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
-        write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        write!(f, "npu-smi:\n{}", self.npu_env)?;
 
         Ok(())
     }
@@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
     let output = nvidia_smi.replace('\n', "\n   ");
     Some(output.trim().to_string())
 }
+
+fn npu_smi() -> Option<String> {
+    let output = Command::new("npu-smi").arg("info").output().ok()?;
+    let npu_smi = String::from_utf8(output.stdout).ok()?;
+    let output = npu_smi.replace('\n', "\n   ");
+    Some(output.trim().to_string())
+}
diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py
index 4504733e..6eb16096 100644
--- a/server/text_generation_server/cache.py
+++ b/server/text_generation_server/cache.py
@@ -3,6 +3,7 @@ import torch
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 B = TypeVar("B", bound=Batch)
 
@@ -24,6 +25,8 @@ class Cache:
             del batch
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif IS_NPU_SYSTEM:
+            torch.npu.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py
index 725105f3..1458b424 100644
--- a/server/text_generation_server/interceptor.py
+++ b/server/text_generation_server/interceptor.py
@@ -6,6 +6,7 @@ from grpc_status import rpc_status
 from grpc_interceptor.server import AsyncServerInterceptor
 from loguru import logger
 from typing import Callable, Any
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
@@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif IS_NPU_SYSTEM:
+                torch.npu.empty_cache()
 
             await context.abort_with_status(
                 rpc_status.to_status(
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 67129ec3..bd08fa67 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class BloomCausalLMBatch(CausalLMBatch):
@@ -50,6 +51,9 @@ class BLOOMSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
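Note (editorial addition, not part of the diff): the `empty_cache` branches added to cache.py and interceptor.py above, and repeated in cache_manager.py below, all follow the same dispatch. A minimal sketch of that dispatch as one hypothetical helper, assuming only the `IS_NPU_SYSTEM` flag defined in `import_utils.py` at the end of this diff:

```python
# Hypothetical helper (illustration only): centralizes the allocator-cache
# release that cache.py, interceptor.py, and cache_manager.py each repeat.
import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def empty_device_cache() -> None:
    """Release cached allocator memory on whichever accelerator is present."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif IS_NPU_SYSTEM:
        # torch.npu is provided by the torch_npu extension on Ascend systems.
        torch.npu.empty_cache()
    # On CPU-only systems there is nothing to release.
```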
diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py
index 2e6ae086..b0403635 100644
--- a/server/text_generation_server/models/cache_manager.py
+++ b/server/text_generation_server/models/cache_manager.py
@@ -2,6 +2,7 @@ import math
 import torch
 
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_NPU_SYSTEM
 
 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -119,7 +120,10 @@ def set_cache_manager(
     global CACHE_MANAGER
     if CACHE_MANAGER is not None:
         del CACHE_MANAGER
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_NPU_SYSTEM:
+            torch.npu.empty_cache()
 
     CACHE_MANAGER = CacheManager(
         num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 93ec6ba4..bdaf119c 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -16,6 +16,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -492,6 +493,9 @@ class CausalLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -506,15 +510,19 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+
+        if (
+            torch.cuda.is_available() and torch.cuda.device_count() > 1
+            or IS_NPU_SYSTEM and torch.npu.device_count() > 1
+        ):
+            device_map = "auto"
+        else:
+            device_map = None
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=(
-                "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None
-            ),
+            device_map=device_map,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -524,6 +532,12 @@ class CausalLM(Model):
             and quantize != "bitsandbytes"
         ):
             model = model.cuda()
+        if (
+            IS_NPU_SYSTEM
+            and torch.npu.device_count() == 1
+            and quantize != "bitsandbytes"
+        ):
+            model = model.npu()
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index a46f86be..d0f3cb5f 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
 
@@ -175,6 +176,9 @@ class GalacticaSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 3b8a70bc..13c90f38 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -1,6 +1,6 @@
 import torch
 import os
 
-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 1c4cfe7d..862b087d 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class GPTNeoxSharded(CausalLM):
@@ -32,6 +33,9 @@ class GPTNeoxSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index 30bf4aa6..295e0abb 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -23,6 +23,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class IDEFICSSharded(IdeficsCausalLM):
@@ -41,6 +42,9 @@ class IDEFICSSharded(IdeficsCausalLM):
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index c96e8152..10b704c1 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -20,6 +20,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 import re
 
@@ -580,6 +581,9 @@ class IdeficsCausalLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 2500d454..94684421 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -30,6 +30,7 @@ from text_generation_server.models.types import (
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 def new_inference_params(
@@ -422,6 +423,9 @@ class Mamba(Model):
             # differences while the server is under load.
             # This is detectable by the integration load test
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 6b3f29a6..0c1d9704 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -51,6 +52,9 @@ class MPTSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 703e5b58..72abb9a9 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class OPTSharded(CausalLM):
@@ -30,6 +31,9 @@ class OPTSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py
index cc4e2505..c92ba193 100644
--- a/server/text_generation_server/models/phi.py
+++ b/server/text_generation_server/models/phi.py
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class Phi(CausalLM):
@@ -30,6 +31,9 @@ class Phi(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index 92c93542..ef54801c 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import List, Optional, Tuple
 
 from text_generation_server.models import CausalLM
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class RW(CausalLM):
@@ -22,6 +23,9 @@ class RW(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -36,20 +40,25 @@ class RW(CausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        if (
+            torch.cuda.is_available() and torch.cuda.device_count() > 1
+            or IS_NPU_SYSTEM and torch.npu.device_count() > 1
+        ):
+            device_map = "auto"
+        else:
+            device_map = None
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=(
-                "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None
-            ),
+            device_map=device_map,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
+            model = model.npu()
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 73c21cce..cf6fea9e 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -5,6 +5,7 @@ from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from text_generation_server.models import CausalLM
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
@@ -26,6 +27,9 @@ class SantaCoder(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index e55a661c..ddf9028e 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -16,6 +16,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -542,6 +543,9 @@ class Seq2SeqLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -549,20 +553,25 @@ class Seq2SeqLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
 
+        if (
+            torch.cuda.is_available() and torch.cuda.device_count() > 1
+            or IS_NPU_SYSTEM and torch.npu.device_count() > 1
+        ):
+            device_map = "auto"
+        else:
+            device_map = None
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=(
-                "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None
-            ),
+            device_map=device_map,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
+            model = model.npu()
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
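Note (editorial addition, not part of the diff): causal_lm.py, rw.py, and seq2seq_lm.py above each compute the same `device_map` value before calling `from_pretrained`. A sketch of that selection as a standalone hypothetical function, under the diff's own assumptions (multi-device systems get `device_map="auto"`, single-device systems move the model explicitly afterwards):

```python
# Hypothetical helper (illustration only): mirrors the device_map selection the
# diff repeats in causal_lm.py, rw.py, and seq2seq_lm.py.
from typing import Optional

import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def resolve_device_map() -> Optional[str]:
    """Return "auto" when more than one accelerator is visible, else None."""
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        return "auto"
    if IS_NPU_SYSTEM and torch.npu.device_count() > 1:
        return "auto"
    return None
```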
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 3f3cb965..916bd9b5 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -17,6 +17,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class T5Sharded(Seq2SeqLM):
@@ -33,6 +34,9 @@ class T5Sharded(Seq2SeqLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index d02bfc5b..27e9e8fd 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -3,12 +3,14 @@ import torch
 
 from datetime import timedelta
 from loguru import logger
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
 
 # CUDA memory fraction
+# TODO: Do we need to rename CUDA_MEMORY_FRACTION to DEVICE_MEMORY_FRACTION?
 MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
 
 
@@ -56,6 +58,15 @@ def initialize_torch_distributed():
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
+    elif IS_NPU_SYSTEM:
+        assert WORLD_SIZE <= torch.npu.device_count(), "Each process is one npu"
+        device = RANK % torch.npu.device_count()
+        torch.npu.set_device(device)
+        torch.npu.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        backend = "hccl"
+        options = ProcessGroupNCCL.Options()
+        options.is_high_priority_stream = True
+        options._timeout = timedelta(seconds=60)
     else:
         backend = "gloo"
         options = None
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index 428c9f3e..ce09279f 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -1,4 +1,15 @@
 import torch
 
+
+def is_npu_available():
+    try:
+        import torch_npu  # noqa: F401
+    except ImportError:
+        return False
+
+    return hasattr(torch, "npu") and torch.npu.is_available()
+
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_NPU_SYSTEM = is_npu_available()
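Note (editorial addition, not part of the diff): `is_npu_available()` gates the rest of this change. When `torch_npu` imports cleanly, `torch.npu` exposes a CUDA-like API, so each model picks an `npu:{rank}` device with the same dtype defaults it would use for CUDA. A hedged sketch of that per-model selection as a single hypothetical helper (the function name is illustrative, not from the repository):

```python
# Hypothetical consolidation (illustration only): the device/dtype selection the
# diff adds to each sharded model, expressed once.
from typing import Optional, Tuple

import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def get_device_and_dtype(
    rank: int, dtype: Optional[torch.dtype] = None
) -> Tuple[torch.device, torch.dtype]:
    """Pick the accelerator for this rank and a matching default dtype."""
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16 if dtype is None else dtype
    elif IS_NPU_SYSTEM:
        # torch.npu is registered by importing torch_npu (see is_npu_available).
        device = torch.device(f"npu:{rank}")
        dtype = torch.float16 if dtype is None else dtype
    else:
        device = torch.device("cpu")
        dtype = torch.float32 if dtype is None else dtype
    return device, dtype
```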