feat: add cuda memory fraction

Author: OlivierDehaene
Date:   2023-07-20 11:29:48 +02:00
commit 1b59f8da73
parent 1da642bd0e

4 changed files with 37 additions and 15 deletions

launcher/src/main.rs

@@ -245,6 +245,11 @@ struct Args {
     #[clap(long, env)]
     disable_custom_kernels: bool,
 
+    /// Limit the CUDA available memory.
+    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
+    #[clap(default_value = "1.0", long, env)]
+    cuda_memory_fraction: f32,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
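
A quick note on the semantics documented above: the cap is computed from the GPU's total visible memory, not from whatever happens to be free at startup. A minimal arithmetic sketch, where the 80 GiB card and the 0.5 fraction are assumptions picked only for illustration:

# Illustration only: assumed 80 GiB GPU and --cuda-memory-fraction 0.5.
total_visible_memory = 80 * 1024**3        # total GPU memory in bytes (assumed)
cuda_memory_fraction = 0.5                 # value of the launcher flag (assumed)

allowed_memory = total_visible_memory * cuda_memory_fraction
print(f"shard may use roughly {allowed_memory / 1024**3:.0f} GiB")  # -> 40 GiB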
@@ -299,6 +304,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    cuda_memory_fraction: f32,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -368,6 +374,12 @@ fn shard_manager(
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
     envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // CUDA memory fraction
+    envs.push((
+        "CUDA_MEMORY_FRACTION".into(),
+        cuda_memory_fraction.to_string().into(),
+    ));
+
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
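
For context, the launcher forwards the flag to each shard through the CUDA_MEMORY_FRACTION environment variable, and the Python server reads it back with a default of 1.0 (see dist.py below). A small round-trip sketch, assuming a python interpreter on PATH; this is not the launcher's actual spawning code:

# Sketch of the env-var handoff; the subprocess stands in for a shard process.
import os
import subprocess

cuda_memory_fraction = 0.5  # e.g. --cuda-memory-fraction 0.5 (assumed value)
env = dict(os.environ, CUDA_MEMORY_FRACTION=str(cuda_memory_fraction))

# The shard side mirrors dist.py: float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
subprocess.run(
    ["python", "-c", "import os; print(float(os.getenv('CUDA_MEMORY_FRACTION', '1.0')))"],
    env=env,
    check=True,
)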
@@ -771,6 +783,7 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
+        let cuda_memory_fraction = args.cuda_memory_fraction;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -788,6 +801,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                cuda_memory_fraction,
                 otlp_endpoint,
                 status_sender,
                 shutdown,

server/text_generation_server/models/flash_causal_lm.py

@@ -19,6 +19,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
@@ -738,7 +739,12 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_gpu_memory = (
+            torch.cuda.get_device_properties(self.device).total_memory
+        )
+        free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
 
         num_blocks = (
             int(free_memory // total_cache_size)
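
To see how the fraction shapes the KV-cache budget: memory above MEMORY_FRACTION of the card is treated as unavailable even if it is currently free, and the number of cache blocks is derived from what remains. A self-contained sketch of the same arithmetic; BLOCK_SIZE and the model shapes are made-up values, not taken from a real model:

# Standalone sketch of the sizing logic above (assumed shapes, fp16).
import torch

MEMORY_FRACTION = 0.8                      # would come from CUDA_MEMORY_FRACTION
BLOCK_SIZE = 16                            # tokens per KV-cache block (assumed)
num_layers, num_kv_heads, head_size = 32, 32, 128
dtype_size = 2                             # bytes per element for float16

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # keys + values

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    total_free_memory, _ = torch.cuda.mem_get_info(device)
    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
    # Anything above the fraction is off-limits, even if currently free.
    free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
    num_blocks = int(free_memory // total_cache_size)
    print(f"{num_blocks} KV-cache blocks fit under the cap")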

server/text_generation_server/models/model.py

@@ -21,9 +21,6 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            torch.cuda.set_per_process_memory_fraction(1.0)
-
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)

server/text_generation_server/utils/dist.py

@@ -4,6 +4,13 @@ import torch
 from datetime import timedelta
 from loguru import logger
 
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+
+# CUDA memory fraction
+MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
+
 
 class FakeBarrier:
     def wait(self):
@@ -37,16 +44,14 @@ class FakeGroup:
 
 def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
         # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
+        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
+        device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
+        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
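
The actual enforcement comes from torch.cuda.set_per_process_memory_fraction, which caps the caching allocator at fraction * total device memory and makes larger allocations fail with an out-of-memory error; this is also presumably why the hard-coded set_per_process_memory_fraction(1.0) in Model.__init__ is removed above, since re-running it would lift the cap back to the full card. A small sketch of the cap in isolation, with an assumed 0.5 fraction and a deliberately oversized allocation:

# Sketch: demonstrate the per-process cap outside the server (assumed sizes).
import torch

if torch.cuda.is_available():
    device = 0
    torch.cuda.set_per_process_memory_fraction(0.5, device)
    total = torch.cuda.get_device_properties(device).total_memory
    try:
        # ~60% of the card is above the 50% cap, so the allocator should refuse.
        torch.empty(int(total * 0.6), dtype=torch.uint8, device=device)
    except torch.cuda.OutOfMemoryError:
        print("allocation above the per-process memory fraction was rejected")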
@@ -55,22 +60,22 @@ def initialize_torch_distributed():
         backend = "gloo"
         options = None
 
-    if world_size == 1:
-        return FakeGroup(rank, world_size), rank, world_size
+    if WORLD_SIZE == 1:
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
     else:
         if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(rank, world_size), rank, world_size
+            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
 
         if not torch.distributed.is_initialized():
             # Call the init process.
             torch.distributed.init_process_group(
                 backend=backend,
-                world_size=world_size,
-                rank=rank,
+                world_size=WORLD_SIZE,
+                rank=RANK,
                 timeout=timedelta(seconds=60),
                 pg_options=options,
             )
         else:
             logger.warning("torch.distributed is already initialized.")
 
-        return torch.distributed.group.WORLD, rank, world_size
+        return torch.distributed.group.WORLD, RANK, WORLD_SIZE