Add Ascend NPU support for TGI

statelesshz 2024-04-14 16:11:10 +08:00
parent 7dbaf9e901
commit 6655717e19
21 changed files with 137 additions and 18 deletions

View File

@@ -7,14 +7,17 @@ pub(crate) struct Env {
git_sha: &'static str,
docker_label: &'static str,
nvidia_env: String,
npu_env: String,
}
impl Env {
pub fn new() -> Self {
let nvidia_env = nvidia_smi();
let npu_env = npu_smi();
Self {
nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
npu_env: npu_env.unwrap_or("N/A".to_string()),
cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
cargo_version: env!("VERGEN_RUSTC_SEMVER"),
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -31,7 +34,8 @@ impl fmt::Display for Env {
writeln!(f, "Cargo version: {}", self.cargo_version)?;
writeln!(f, "Commit sha: {}", self.git_sha)?;
writeln!(f, "Docker label: {}", self.docker_label)?;
write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
write!(f, "npu-smi:\n{}", self.npu_env)?;
Ok(())
}
@@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
let output = nvidia_smi.replace('\n', "\n ");
Some(output.trim().to_string())
}
fn npu_smi() -> Option<String> {
let output = Command::new("npu-smi info").output().ok()?;
let npu_smi = String::from_utf8(output.stdout).ok()?;
let output = npu_smi.replace('\n', "\n ");
Some(output.trim().to_string())
}

View File

@@ -3,6 +3,7 @@ import torch
from typing import Dict, Optional, TypeVar
from text_generation_server.models.types import Batch
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
B = TypeVar("B", bound=Batch)
@@ -24,6 +25,8 @@ class Cache:
del batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif IS_NPU_SYSTEM:
torch.npu.empty_cache()
def clear(self):
keys = list(self.cache.keys())

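The cuda/npu branching used here to release allocator cache recurs below in the exception interceptor and the cache manager. As a minimal sketch (not part of this commit; empty_device_cache is a hypothetical helper), the pattern could be factored out once:

import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def empty_device_cache() -> None:
    # Release cached allocator blocks on whichever accelerator is active.
    # torch_npu mirrors the torch.cuda allocator API, so the branches are
    # symmetric; on CPU-only hosts this is a no-op.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif IS_NPU_SYSTEM:
        torch.npu.empty_cache()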
View File

@@ -6,6 +6,7 @@ from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class ExceptionInterceptor(AsyncServerInterceptor):
@@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif IS_NPU_SYSTEM:
torch.npu.empty_cache()
await context.abort_with_status(
rpc_status.to_status(

View File

@@ -20,6 +20,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class BloomCausalLMBatch(CausalLMBatch):
@@ -50,6 +51,9 @@ class BLOOMSharded(CausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

View File

@@ -2,6 +2,7 @@ import math
import torch
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_NPU_SYSTEM
BLOCK_SIZE: int = 16
# Will be set in warmup
@@ -119,7 +120,10 @@ def set_cache_manager(
global CACHE_MANAGER
if CACHE_MANAGER is not None:
del CACHE_MANAGER
torch.cuda.empty_cache()
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.empty_cache()
elif IS_NPU_SYSTEM:
torch.npu.empty_cache()
CACHE_MANAGER = CacheManager(
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device

View File

@@ -16,6 +16,7 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
tracer = trace.get_tracer(__name__)
@@ -492,6 +493,9 @@ class CausalLM(Model):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
@@ -506,15 +510,19 @@
truncation_side="left",
trust_remote_code=trust_remote_code,
)
if (
torch.cuda.is_available() and torch.cuda.device_count() > 1
or IS_NPU_SYSTEM and torch.npu.device_count() > 1
):
device_map = "auto"
else:
device_map = None
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
device_map=device_map,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
@@ -524,6 +532,12 @@
and quantize != "bitsandbytes"
):
model = model.cuda()
if (
IS_NPU_SYSTEM
and torch.npu.device_count() == 1
and quantize != "bitsandbytes"
):
model = model.npu()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:

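The device_map selection added in this file reappears in the RW and Seq2SeqLM loaders later in this diff. A sketch of how it could be shared (infer_device_map is hypothetical; assumes torch_npu is installed on NPU hosts):

from typing import Optional

import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def infer_device_map() -> Optional[str]:
    # "auto" lets accelerate shard the weights across all visible devices;
    # None keeps the default single-device load path.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        return "auto"
    if IS_NPU_SYSTEM and torch.npu.device_count() > 1:
        return "auto"
    return None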
View File

@@ -20,6 +20,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
@@ -175,6 +176,9 @@ class GalacticaSharded(CausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

View File

@@ -1,6 +1,6 @@
import torch
import os
MEM_POOL = torch.cuda.graph_pool_handle()
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}

View File

@@ -16,6 +16,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class GPTNeoxSharded(CausalLM):
@@ -32,6 +33,9 @@ class GPTNeoxSharded(CausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

View File

@@ -23,6 +23,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class IDEFICSSharded(IdeficsCausalLM):
@@ -41,6 +42,9 @@ class IDEFICSSharded(IdeficsCausalLM):
# 9b seems to work correctly enough in float16, but 80b seems
# to be really saturating for f16.
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

View File

@@ -20,6 +20,7 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
import re
@@ -580,6 +581,9 @@ class IdeficsCausalLM(Model):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device("npu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@@ -30,6 +30,7 @@ from text_generation_server.models.types import (
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
def new_inference_params(
@@ -422,6 +423,9 @@ class Mamba(Model):
# differences while the server is under load.
# This is detectable by the integration load test
dtype = torch.bfloat16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device("npu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@@ -19,6 +19,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
tracer = trace.get_tracer(__name__)
@@ -51,6 +52,9 @@ class MPTSharded(CausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

View File

@@ -14,6 +14,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class OPTSharded(CausalLM):
@@ -30,6 +31,9 @@ class OPTSharded(CausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

View File

@@ -14,6 +14,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class Phi(CausalLM):
@@ -30,6 +31,9 @@ class Phi(CausalLM):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class RW(CausalLM):
@@ -22,6 +23,9 @@ class RW(CausalLM):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
@@ -36,20 +40,25 @@ class RW(CausalLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
if (
torch.cuda.is_available() and torch.cuda.device_count() > 1
or IS_NPU_SYSTEM and torch.npu.device_count() > 1
):
device_map = "auto"
else:
device_map = None
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
device_map=device_map,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
model = model.npu()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:

View File

@@ -5,6 +5,7 @@ from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
@@ -26,6 +27,9 @@ class SantaCoder(CausalLM):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@@ -16,6 +16,7 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
tracer = trace.get_tracer(__name__)
@@ -542,6 +543,9 @@ class Seq2SeqLM(Model):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device("npu")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
@@ -549,20 +553,25 @@ class Seq2SeqLM(Model):
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
if (
torch.cuda.is_available() and torch.cuda.device_count() > 1
or IS_NPU_SYSTEM and torch.npu.device_count() > 1
):
device_map = "auto"
else:
device_map = None
model = AutoModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
device_map=device_map,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
model = model.npu()
tokenizer = AutoTokenizer.from_pretrained(
model_id,

View File

@@ -17,6 +17,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class T5Sharded(Seq2SeqLM):
@@ -33,6 +34,9 @@ class T5Sharded(Seq2SeqLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_NPU_SYSTEM:
device = torch.device(f"npu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

View File

@@ -3,12 +3,14 @@ import torch
from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
# CUDA memory fraction
# TODO: Do we need to rename CUDA_MEMORY_FRACTION to DEVICE_MEMORY_FRACTION?
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
@@ -56,6 +58,15 @@ def initialize_torch_distributed():
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
elif IS_NPU_SYSTEM:
assert WORLD_SIZE <= torch.npu.device_count(), "Each process is one npu"
device = RANK % torch.npu.device_count()
torch.npu.set_device(device)
torch.npu.set_per_process_memory_fraction(MEMORY_FRACTION, device)
backend = "hccl"
# NCCL process-group options do not apply to the HCCL backend
options = None
else:
backend = "gloo"
options = None

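A minimal sketch of exercising the new HCCL branch on a multi-NPU host (a hypothetical standalone script, not part of this commit; assumes torch_npu is installed and registers the "hccl" backend, and that RANK/WORLD_SIZE/MASTER_* are provided by the launcher):

import os

import torch
import torch_npu  # noqa: F401  (import side effect exposes torch.npu)
import torch.distributed as dist

rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

torch.npu.set_device(rank % torch.npu.device_count())
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)

t = torch.ones(1).npu()
dist.all_reduce(t)  # default SUM: t holds world_size on every rank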
View File

@@ -1,4 +1,15 @@
import torch
def is_npu_available():
try:
import torch_npu # noqa: F401
except ImportError:
return False
return hasattr(torch, "npu") and torch.npu.is_available()
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_NPU_SYSTEM = is_npu_available()
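For reference, a short sketch of how these flags are consumed by the model loaders in this commit (pick_device is a hypothetical helper, not part of the diff):

import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def pick_device(rank: int = 0) -> torch.device:
    # Same priority order as the model classes above: CUDA/ROCm first,
    # then Ascend NPU, then CPU fallback.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    if IS_NPU_SYSTEM:
        return torch.device(f"npu:{rank}")
    return torch.device("cpu")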