diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 53de36b2..2ad788a4 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -245,6 +245,11 @@ struct Args {
     #[clap(long, env)]
     disable_custom_kernels: bool,
 
+    /// Limit the CUDA available memory.
+    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
+    #[clap(default_value = "1.0", long, env)]
+    cuda_memory_fraction: f32,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
@@ -299,6 +304,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    cuda_memory_fraction: f32,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -368,6 +374,12 @@ fn shard_manager(
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
     envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // CUDA memory fraction
+    envs.push((
+        "CUDA_MEMORY_FRACTION".into(),
+        cuda_memory_fraction.to_string().into(),
+    ));
+
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
@@ -771,6 +783,7 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
+        let cuda_memory_fraction = args.cuda_memory_fraction;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -788,6 +801,7 @@
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                cuda_memory_fraction,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 547678a8..1f0e6ea3 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -19,6 +19,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
@@ -738,7 +739,12 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_gpu_memory = (
+            torch.cuda.get_device_properties(self.device).total_memory
+        )
+
+        free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
 
         num_blocks = (
             int(free_memory // total_cache_size)
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 89e6e99b..447371f9 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -21,9 +21,6 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            torch.cuda.set_per_process_memory_fraction(1.0)
-
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 41a8e01a..d02bfc5b 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -4,6 +4,13 @@ import torch
 from datetime import timedelta
 from loguru import logger
 
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+
+# CUDA memory fraction
+MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
+
 
 class FakeBarrier:
     def wait(self):
@@ -37,16 +44,14 @@ class FakeGroup:
 
 
 def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
         # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
+        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
+        device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
+        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
@@ -55,22 +60,22 @@
         backend = "gloo"
         options = None
 
-    if world_size == 1:
-        return FakeGroup(rank, world_size), rank, world_size
+    if WORLD_SIZE == 1:
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
     else:
         if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(rank, world_size), rank, world_size
+            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
 
         if not torch.distributed.is_initialized():
             # Call the init process.
             torch.distributed.init_process_group(
                 backend=backend,
-                world_size=world_size,
-                rank=rank,
+                world_size=WORLD_SIZE,
+                rank=RANK,
                 timeout=timedelta(seconds=60),
                 pg_options=options,
             )
         else:
             logger.warning("torch.distributed is already initialized.")
 
-        return torch.distributed.group.WORLD, rank, world_size
+        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
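
Note on the budget math in flash_causal_lm.py: the shard's allocator is capped by
torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) in dist.py, and the
warmup then sizes the KV cache against free_memory = max(0, total_free_memory -
(1 - MEMORY_FRACTION) * total_gpu_memory), i.e. the reserved share is taken from the whole
device, not from the currently free bytes. The snippet below is a minimal, standalone sketch
of that arithmetic only; the function name and the GPU sizes are illustrative (the patch
reads the real values from torch.cuda at runtime).

# Standalone sketch of the KV-cache budget math (hypothetical numbers, no GPU required).
def effective_free_memory(total_free: int, total_gpu: int, fraction: float) -> int:
    """Bytes left for the KV cache once (1 - fraction) of the whole device is reserved."""
    return max(0, int(total_free - (1 - fraction) * total_gpu))


if __name__ == "__main__":
    total_gpu = 80 * 1024**3   # illustrative 80 GiB card
    total_free = 72 * 1024**3  # illustrative free bytes before the cache is allocated
    for fraction in (1.0, 0.5):
        budget = effective_free_memory(total_free, total_gpu, fraction)
        print(f"CUDA_MEMORY_FRACTION={fraction}: {budget / 1024**3:.1f} GiB for the KV cache")

With the default fraction of 1.0 nothing changes (72 GiB in this example); with 0.5 the budget
drops to 32 GiB, because half of the whole card (40 GiB) is treated as off limits regardless of
how much of it happens to be free.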