mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
add ascend npu support for TGI
This commit is contained in:
parent
7dbaf9e901
commit
6655717e19
@ -7,14 +7,17 @@ pub(crate) struct Env {
|
||||
git_sha: &'static str,
|
||||
docker_label: &'static str,
|
||||
nvidia_env: String,
|
||||
npu_env: String,
|
||||
}
|
||||
|
||||
impl Env {
|
||||
pub fn new() -> Self {
|
||||
let nvidia_env = nvidia_smi();
|
||||
let npu_env = npu_smi();
|
||||
|
||||
Self {
|
||||
nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
|
||||
npu_env: npu_env.unwrap_or("N/A".to_string()),
|
||||
cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
|
||||
cargo_version: env!("VERGEN_RUSTC_SEMVER"),
|
||||
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
|
||||
@ -31,7 +34,8 @@ impl fmt::Display for Env {
|
||||
writeln!(f, "Cargo version: {}", self.cargo_version)?;
|
||||
writeln!(f, "Commit sha: {}", self.git_sha)?;
|
||||
writeln!(f, "Docker label: {}", self.docker_label)?;
|
||||
write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
|
||||
writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
|
||||
write!(f, "npu-smi:\n{}", self.npu_env)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
|
||||
let output = nvidia_smi.replace('\n', "\n ");
|
||||
Some(output.trim().to_string())
|
||||
}
|
||||
|
||||
fn npu_smi() -> Option<String> {
|
||||
let output = Command::new("npu-smi info").output().ok()?;
|
||||
let npu_smi = String::from_utf8(output.stdout).ok()?;
|
||||
let output = npu_smi.replace('\n', "\n ");
|
||||
Some(output.trim().to_string())
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ import torch
|
||||
from typing import Dict, Optional, TypeVar
|
||||
|
||||
from text_generation_server.models.types import Batch
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
@ -24,6 +25,8 @@ class Cache:
|
||||
del batch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif IS_NPU_SYSTEM:
|
||||
torch.npu.empty_cache()
|
||||
|
||||
def clear(self):
|
||||
keys = list(self.cache.keys())
|
||||
|
@ -6,6 +6,7 @@ from grpc_status import rpc_status
|
||||
from grpc_interceptor.server import AsyncServerInterceptor
|
||||
from loguru import logger
|
||||
from typing import Callable, Any
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif IS_NPU_SYSTEM:
|
||||
torch.npu.empty_cache()
|
||||
|
||||
await context.abort_with_status(
|
||||
rpc_status.to_status(
|
||||
|
@ -20,6 +20,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class BloomCausalLMBatch(CausalLMBatch):
|
||||
@ -50,6 +51,9 @@ class BLOOMSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -2,6 +2,7 @@ import math
|
||||
import torch
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_NPU_SYSTEM
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
# Will be set in warmup
|
||||
@ -119,7 +120,10 @@ def set_cache_manager(
|
||||
global CACHE_MANAGER
|
||||
if CACHE_MANAGER is not None:
|
||||
del CACHE_MANAGER
|
||||
torch.cuda.empty_cache()
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
torch.cuda.empty_cache()
|
||||
elif IS_NPU_SYSTEM:
|
||||
torch.npu.empty_cache()
|
||||
|
||||
CACHE_MANAGER = CacheManager(
|
||||
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
|
||||
|
@ -16,6 +16,7 @@ from text_generation_server.models.types import (
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -492,6 +493,9 @@ class CausalLM(Model):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -506,15 +510,19 @@ class CausalLM(Model):
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if (
|
||||
torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
or IS_NPU_SYSTEM and torch.npu.device_count() > 1
|
||||
):
|
||||
device_map = "auto"
|
||||
else:
|
||||
device_map = None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map=(
|
||||
"auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None
|
||||
),
|
||||
device_map=device_map,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -524,6 +532,12 @@ class CausalLM(Model):
|
||||
and quantize != "bitsandbytes"
|
||||
):
|
||||
model = model.cuda()
|
||||
if (
|
||||
IS_NPU_SYSTEM
|
||||
and torch.npu.device_count() == 1
|
||||
and quantize != "bitsandbytes"
|
||||
):
|
||||
model = model.npu()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
|
@ -20,6 +20,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
|
||||
|
||||
@ -175,6 +176,9 @@ class GalacticaSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import os
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle()
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
|
||||
|
@ -16,6 +16,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class GPTNeoxSharded(CausalLM):
|
||||
@ -32,6 +33,9 @@ class GPTNeoxSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -23,6 +23,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class IDEFICSSharded(IdeficsCausalLM):
|
||||
@ -41,6 +42,9 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||
# 9b seems to work correctly enough in float16, but 80b seems
|
||||
# to be really saturating for f16.
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -20,6 +20,7 @@ from text_generation_server.models.types import (
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
import re
|
||||
|
||||
@ -580,6 +581,9 @@ class IdeficsCausalLM(Model):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device("npu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
@ -30,6 +30,7 @@ from text_generation_server.models.types import (
|
||||
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
def new_inference_params(
|
||||
@ -422,6 +423,9 @@ class Mamba(Model):
|
||||
# differences while the server is under load.
|
||||
# This is detectable by the integration load test
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device("npu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -51,6 +52,9 @@ class MPTSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class OPTSharded(CausalLM):
|
||||
@ -30,6 +31,9 @@ class OPTSharded(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class Phi(CausalLM):
|
||||
@ -30,6 +31,9 @@ class Phi(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class RW(CausalLM):
|
||||
@ -22,6 +23,9 @@ class RW(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -36,20 +40,25 @@ class RW(CausalLM):
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if (
|
||||
torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
or IS_NPU_SYSTEM and torch.npu.device_count() > 1
|
||||
):
|
||||
device_map = "auto"
|
||||
else:
|
||||
device_map = None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map=(
|
||||
"auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None
|
||||
),
|
||||
device_map=device_map,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
|
||||
model = model.npu()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional, List
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
FIM_PREFIX = "<fim-prefix>"
|
||||
FIM_MIDDLE = "<fim-middle>"
|
||||
@ -26,6 +27,9 @@ class SantaCoder(CausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
@ -16,6 +16,7 @@ from text_generation_server.models.types import (
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -542,6 +543,9 @@ class Seq2SeqLM(Model):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device("npu")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
@ -549,20 +553,25 @@ class Seq2SeqLM(Model):
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
if (
|
||||
torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
or IS_NPU_SYSTEM and torch.npu.is_available() > 1
|
||||
):
|
||||
device_map = "auto"
|
||||
else:
|
||||
device_map = None
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map=(
|
||||
"auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None
|
||||
),
|
||||
device_map=device_map,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
|
||||
model = model.npu()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
|
@ -17,6 +17,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
|
||||
class T5Sharded(Seq2SeqLM):
|
||||
@ -33,6 +34,9 @@ class T5Sharded(Seq2SeqLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_NPU_SYSTEM:
|
||||
device = torch.device(f"npu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
@ -3,12 +3,14 @@ import torch
|
||||
|
||||
from datetime import timedelta
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
|
||||
|
||||
# Tensor Parallelism settings
|
||||
RANK = int(os.getenv("RANK", "0"))
|
||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
# CUDA memory fraction
|
||||
# TODO: Do we need to rename CUDA_MEMORY_FRACTION to DEVICE_MEMORY_FRACTION?
|
||||
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
|
||||
|
||||
|
||||
@ -56,6 +58,15 @@ def initialize_torch_distributed():
|
||||
options = ProcessGroupNCCL.Options()
|
||||
options.is_high_priority_stream = True
|
||||
options._timeout = timedelta(seconds=60)
|
||||
elif IS_NPU_SYSTEM:
|
||||
assert WORLD_SIZE <= torch.npu.device_count(), "Each process is one npu"
|
||||
device = RANK % torch.npu.device_count()
|
||||
torch.npu.set_device(device)
|
||||
torch.npu.set_per_process_memory_fraction(MEMORY_FRACTION, device)
|
||||
backend = "hccl"
|
||||
options = ProcessGroupNCCL.Options()
|
||||
options.is_high_priority_stream = True
|
||||
options._timeout = timedelta(seconds=60)
|
||||
else:
|
||||
backend = "gloo"
|
||||
options = None
|
||||
|
@ -1,4 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
def is_npu_available():
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
return hasattr(torch, "npu") and torch.npu.is_available()
|
||||
|
||||
|
||||
IS_ROCM_SYSTEM = torch.version.hip is not None
|
||||
IS_CUDA_SYSTEM = torch.version.cuda is not None
|
||||
IS_NPU_SYSTEM = is_npu_available()
|
||||
|
Loading…
Reference in New Issue
Block a user