Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
feat: add cuda memory fraction
commit 1b59f8da73
parent 1da642bd0e
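
Summary of the change: the launcher gains a --cuda-memory-fraction argument (default 1.0) and forwards it to every shard through a CUDA_MEMORY_FRACTION environment variable. On the Python side, the server reads that variable in text_generation_server.utils.dist, applies it with torch.cuda.set_per_process_memory_fraction when initializing distributed state, and subtracts the reserved share of GPU memory when sizing the flash-attention KV cache.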
@@ -245,6 +245,11 @@ struct Args {
     #[clap(long, env)]
     disable_custom_kernels: bool,
 
+    /// Limit the CUDA available memory.
+    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
+    #[clap(default_value = "1.0", long, env)]
+    cuda_memory_fraction: f32,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
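
The flag defaults to 1.0, i.e. all visible memory stays available. Because the field also carries clap's env attribute, it can presumably be set through an environment variable as well (CUDA_MEMORY_FRACTION under clap's default naming); take that variable name as an assumption, the diff only shows the attribute.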
@@ -299,6 +304,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    cuda_memory_fraction: f32,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -368,6 +374,12 @@ fn shard_manager(
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
     envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // CUDA memory fraction
+    envs.push((
+        "CUDA_MEMORY_FRACTION".into(),
+        cuda_memory_fraction.to_string().into(),
+    ));
+
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
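
Each shard process spawned by shard_manager therefore receives CUDA_MEMORY_FRACTION in its environment, next to the existing MASTER_PORT, NCCL and safetensors settings; the Python server picks the value up in the text_generation_server.utils.dist module, shown further down.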
@@ -771,6 +783,7 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
+        let cuda_memory_fraction = args.cuda_memory_fraction;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -788,6 +801,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                cuda_memory_fraction,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -19,6 +19,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
@@ -738,7 +739,12 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_gpu_memory = (
+            torch.cuda.get_device_properties(self.device).total_memory
+        )
+
+        free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
 
         num_blocks = (
             int(free_memory // total_cache_size)
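
To make the new sizing rule concrete, here is a small standalone sketch of the arithmetic above with made-up numbers (an 80 GiB GPU with 30 GiB currently free, a fraction of 0.9, and roughly Llama-7B-sized fp16 cache dimensions); none of these values come from the diff:

GiB = 2**30

# Assumed device state and configuration (illustrative only).
memory_fraction = 0.9            # CUDA_MEMORY_FRACTION
total_gpu_memory = 80 * GiB      # get_device_properties(device).total_memory
total_free_memory = 30 * GiB     # torch.cuda.mem_get_info(device)[0]

# Assumed KV-cache dimensions (roughly Llama-7B in fp16).
BLOCK_SIZE = 16
num_kv_heads, head_size, num_layers, dtype_size = 32, 128, 32, 2

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size           # 65_536 elements
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # 8 MiB per block

# Usable memory: what is currently free minus the slice of the card that
# the fraction reserves for other processes.
free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
num_blocks = int(free_memory // total_cache_size)

print(free_memory / GiB)   # ~22.0
print(num_blocks)          # 2816

With the old code, the same GPU would have offered the full 30 GiB to the cache even when part of the card was meant to stay free for another process.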
@@ -21,9 +21,6 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            torch.cuda.set_per_process_memory_fraction(1.0)
-
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
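
The unconditional torch.cuda.set_per_process_memory_fraction(1.0) in Model.__init__ is removed; the cap is now applied once, with the configured fraction, inside initialize_torch_distributed (next hunks).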
@@ -4,6 +4,13 @@ import torch
 from datetime import timedelta
 from loguru import logger
 
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+
+# CUDA memory fraction
+MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
+
 
 class FakeBarrier:
     def wait(self):
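
RANK and WORLD_SIZE, previously parsed inside initialize_torch_distributed, become module-level constants, and MEMORY_FRACTION sits next to them so that other modules (for example flash_causal_lm above) can import the value directly.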
@@ -37,16 +44,14 @@ class FakeGroup:
 
 
 def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
         # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
+        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
+        device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
+        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
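
For context, torch.cuda.set_per_process_memory_fraction(fraction, device) caps how much of the device the PyTorch caching allocator may claim; allocations beyond the cap fail with an out-of-memory error even when the GPU still has free memory. A minimal standalone sketch of that behaviour, not repository code (assumes a CUDA device and a recent PyTorch):

import os

import torch

# Mirror the server's behaviour: read the fraction from the environment.
memory_fraction = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))

if torch.cuda.is_available():
    device = 0
    torch.cuda.set_device(device)
    torch.cuda.set_per_process_memory_fraction(memory_fraction, device)

    total = torch.cuda.get_device_properties(device).total_memory
    budget = memory_fraction * total
    print(f"allocator capped at ~{budget / 2**30:.1f} GiB of {total / 2**30:.1f} GiB")

    try:
        # One allocation larger than the cap: rejected by the allocator even
        # if the card physically has that much memory free.
        torch.empty(int(budget) + 2**30, dtype=torch.uint8, device=device)
    except torch.cuda.OutOfMemoryError:
        print("allocation above the configured fraction was rejected")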
@@ -55,22 +60,22 @@ def initialize_torch_distributed():
         backend = "gloo"
         options = None
 
-    if world_size == 1:
-        return FakeGroup(rank, world_size), rank, world_size
+    if WORLD_SIZE == 1:
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
     else:
         if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(rank, world_size), rank, world_size
+            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
 
         if not torch.distributed.is_initialized():
             # Call the init process.
             torch.distributed.init_process_group(
                 backend=backend,
-                world_size=world_size,
-                rank=rank,
+                world_size=WORLD_SIZE,
+                rank=RANK,
                 timeout=timedelta(seconds=60),
                 pg_options=options,
             )
         else:
             logger.warning("torch.distributed is already initialized.")
 
-    return torch.distributed.group.WORLD, rank, world_size
+    return torch.distributed.group.WORLD, RANK, WORLD_SIZE
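
End to end: launching with, say, --cuda-memory-fraction 0.5 exports CUDA_MEMORY_FRACTION=0.5 to every shard, each shard caps its CUDA caching allocator at half of the device via set_per_process_memory_fraction, and FlashCausalLM subtracts the reserved half of the card when deciding how many KV-cache blocks fit.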