Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

add ascend npu support for TGI

Parent: 7dbaf9e901
Commit: 6655717e19
@@ -7,14 +7,17 @@ pub(crate) struct Env {
     git_sha: &'static str,
     docker_label: &'static str,
     nvidia_env: String,
+    npu_env: String,
 }
 
 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
+        let npu_env = npu_smi();
 
         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
+            npu_env: npu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -31,7 +34,8 @@ impl fmt::Display for Env {
         writeln!(f, "Cargo version: {}", self.cargo_version)?;
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
-        write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        write!(f, "npu-smi:\n{}", self.npu_env)?;
 
         Ok(())
     }
@@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
     let output = nvidia_smi.replace('\n', "\n   ");
     Some(output.trim().to_string())
 }
+
+fn npu_smi() -> Option<String> {
+    let output = Command::new("npu-smi").arg("info").output().ok()?;
+    let npu_smi = String::from_utf8(output.stdout).ok()?;
+    let output = npu_smi.replace('\n', "\n   ");
+    Some(output.trim().to_string())
+}
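The new npu_smi() helper mirrors nvidia_smi(): it shells out to Ascend's `npu-smi info`, captures stdout, and indents continuation lines before the output is rendered by `impl fmt::Display for Env`. A minimal Python sketch of the same idea, for readers who want to reproduce that report outside the launcher (the function name is illustrative and assumes `npu-smi` is on PATH):

import subprocess

def npu_smi_report() -> str:
    # Illustrative helper, not part of the commit: capture `npu-smi info`
    # and indent continuation lines, as the Rust npu_smi() helper does above.
    try:
        out = subprocess.run(
            ["npu-smi", "info"], capture_output=True, text=True, check=True
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return "N/A"
    return out.strip().replace("\n", "\n   ")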
@@ -3,6 +3,7 @@ import torch
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 B = TypeVar("B", bound=Batch)
 
@@ -24,6 +25,8 @@ class Cache:
             del batch
         if torch.cuda.is_available():
            torch.cuda.empty_cache()
+        elif IS_NPU_SYSTEM:
+            torch.npu.empty_cache()
 
    def clear(self):
        keys = list(self.cache.keys())

@@ -6,6 +6,7 @@ from grpc_status import rpc_status
 from grpc_interceptor.server import AsyncServerInterceptor
 from loguru import logger
 from typing import Callable, Any
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
@@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif IS_NPU_SYSTEM:
+                torch.npu.empty_cache()
 
             await context.abort_with_status(
                 rpc_status.to_status(

@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class BloomCausalLMBatch(CausalLMBatch):
@@ -50,6 +51,9 @@ class BLOOMSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype

@@ -2,6 +2,7 @@ import math
 import torch
 
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_NPU_SYSTEM
 
 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -119,7 +120,10 @@ def set_cache_manager(
     global CACHE_MANAGER
     if CACHE_MANAGER is not None:
         del CACHE_MANAGER
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_NPU_SYSTEM:
+            torch.npu.empty_cache()
 
     CACHE_MANAGER = CacheManager(
         num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
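The cache-clearing guard above is repeated in three places (the Cache class, the ExceptionInterceptor, and the set_cache_manager helper): release cached device memory only for the backend that is actually present. An illustrative consolidation of that guard (the helper name is not part of the commit):

import torch
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM

def empty_device_cache() -> None:
    # Illustrative sketch of the repeated guard above.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif IS_NPU_SYSTEM:
        # torch.npu is registered by importing torch_npu (see import_utils.py below).
        torch.npu.empty_cache()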
@@ -16,6 +16,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -492,6 +493,9 @@ class CausalLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -506,15 +510,19 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+
+        if (
+            torch.cuda.is_available() and torch.cuda.device_count() > 1
+            or IS_NPU_SYSTEM and torch.npu.device_count() > 1
+        ):
+            device_map = "auto"
+        else:
+            device_map = None
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=(
-                "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None
-            ),
+            device_map=device_map,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -524,6 +532,12 @@ class CausalLM(Model):
             and quantize != "bitsandbytes"
         ):
             model = model.cuda()
+        if (
+            IS_NPU_SYSTEM
+            and torch.npu.device_count() == 1
+            and quantize != "bitsandbytes"
+        ):
+            model = model.npu()
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
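Every model class touched by this commit makes the same choice: prefer CUDA, then the Ascend NPU, then CPU; keep float16/bfloat16 on accelerators and float32 on CPU; and pass device_map="auto" to from_pretrained only when more than one card is visible. A condensed, illustrative sketch of that decision (the helper names are not part of the commit):

from typing import Optional

import torch
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM

def pick_device_and_dtype(rank: int = 0, dtype: Optional[torch.dtype] = None):
    # Condensed restatement of the per-model branches above; the real classes inline this.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}"), torch.float16 if dtype is None else dtype
    if IS_NPU_SYSTEM:
        return torch.device(f"npu:{rank}"), torch.float16 if dtype is None else dtype
    return torch.device("cpu"), torch.float32 if dtype is None else dtype

def pick_device_map() -> Optional[str]:
    # "auto" only when several cards are visible on either backend.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        return "auto"
    if IS_NPU_SYSTEM and torch.npu.device_count() > 1:
        return "auto"
    return None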
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
 
@@ -175,6 +176,9 @@ class GalacticaSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype

@@ -1,6 +1,6 @@
 import torch
 import os
 
-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}

@@ -16,6 +16,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class GPTNeoxSharded(CausalLM):
@@ -32,6 +33,9 @@ class GPTNeoxSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype

@@ -23,6 +23,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class IDEFICSSharded(IdeficsCausalLM):
@@ -41,6 +42,9 @@ class IDEFICSSharded(IdeficsCausalLM):
             # 9b seems to work correctly enough in float16, but 80b seems
             # to be really saturating for f16.
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype

@@ -20,6 +20,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 import re
 
@@ -580,6 +581,9 @@ class IdeficsCausalLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")

@@ -30,6 +30,7 @@ from text_generation_server.models.types import (
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 def new_inference_params(
@@ -422,6 +423,9 @@ class Mamba(Model):
             # differences while the server is under load.
             # This is detectable by the integration load test
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")

@@ -19,6 +19,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -51,6 +52,9 @@ class MPTSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype

@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class OPTSharded(CausalLM):
@@ -30,6 +31,9 @@ class OPTSharded(CausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype

@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class Phi(CausalLM):
@@ -30,6 +31,9 @@ class Phi(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import List, Optional, Tuple
 
 from text_generation_server.models import CausalLM
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class RW(CausalLM):
@@ -22,6 +23,9 @@ class RW(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -36,20 +40,25 @@ class RW(CausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        if (
+            torch.cuda.is_available() and torch.cuda.device_count() > 1
+            or IS_NPU_SYSTEM and torch.npu.device_count() > 1
+        ):
+            device_map = "auto"
+        else:
+            device_map = None
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=(
-                "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None
-            ),
+            device_map=device_map,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
+            model = model.npu()
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:

@@ -5,6 +5,7 @@ from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from text_generation_server.models import CausalLM
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
@@ -26,6 +27,9 @@ class SantaCoder(CausalLM):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -16,6 +16,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -542,6 +543,9 @@ class Seq2SeqLM(Model):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device("npu")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -549,20 +553,25 @@ class Seq2SeqLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
 
+        if (
+            torch.cuda.is_available() and torch.cuda.device_count() > 1
+            or IS_NPU_SYSTEM and torch.npu.device_count() > 1
+        ):
+            device_map = "auto"
+        else:
+            device_map = None
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=(
-                "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None
-            ),
+            device_map=device_map,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
+        if IS_NPU_SYSTEM and torch.npu.device_count() == 1:
+            model = model.npu()
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -17,6 +17,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class T5Sharded(Seq2SeqLM):
@@ -33,6 +34,9 @@ class T5Sharded(Seq2SeqLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_NPU_SYSTEM:
+            device = torch.device(f"npu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype

@@ -3,12 +3,14 @@ import torch
 
 from datetime import timedelta
 from loguru import logger
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
 
 # CUDA memory fraction
+# TODO: Do we need to rename CUDA_MEMORY_FRACTION to DEVICE_MEMORY_FRACTION?
 MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
 
 
@@ -56,6 +58,15 @@ def initialize_torch_distributed():
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
+    elif IS_NPU_SYSTEM:
+        assert WORLD_SIZE <= torch.npu.device_count(), "Each process is one npu"
+        device = RANK % torch.npu.device_count()
+        torch.npu.set_device(device)
+        torch.npu.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        backend = "hccl"
+        options = ProcessGroupNCCL.Options()
+        options.is_high_priority_stream = True
+        options._timeout = timedelta(seconds=60)
     else:
         backend = "gloo"
         options = None
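On the distributed side, each Ascend rank is pinned to one NPU and the process group uses the `hccl` backend, mirroring the existing `nccl` path. A hedged sketch of the equivalent call sequence in a standalone worker, assuming `torch_npu` is installed and `MASTER_ADDR`/`MASTER_PORT` are provided by the launcher:

import torch
import torch.distributed as dist

def init_npu_process_group(rank: int, world_size: int) -> torch.device:
    # Illustrative sketch of the hccl path added above, not a drop-in for dist.py.
    import torch_npu  # noqa: F401  # registers the torch.npu namespace

    device_id = rank % torch.npu.device_count()
    torch.npu.set_device(device_id)
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
    return torch.device(f"npu:{device_id}")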
@@ -1,4 +1,15 @@
 import torch
 
+
+def is_npu_available():
+    try:
+        import torch_npu  # noqa: F401
+    except ImportError:
+        return False
+
+    return hasattr(torch, "npu") and torch.npu.is_available()
+
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_NPU_SYSTEM = is_npu_available()
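A quick way to sanity-check the new detection flag inside the server's Python environment, assuming the Ascend CANN toolkit and `torch_npu` are installed:

import torch
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_NPU_SYSTEM

print("cuda:", IS_CUDA_SYSTEM, "npu:", IS_NPU_SYSTEM)
if IS_NPU_SYSTEM:
    # importing import_utils already pulled in torch_npu, so "npu" tensors work here
    x = torch.ones(2, 2, device="npu")
    print(x.device, x.sum().item())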