mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
feat: add cuda memory fraction
This commit is contained in:
parent 1da642bd0e
commit 1b59f8da73
launcher/src/main.rs

@@ -245,6 +245,11 @@ struct Args {
     #[clap(long, env)]
     disable_custom_kernels: bool,
 
+    /// Limit the CUDA available memory.
+    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
+    #[clap(default_value = "1.0", long, env)]
+    cuda_memory_fraction: f32,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
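The doc comment above states the contract: a shard may use at most the total visible memory multiplied by cuda-memory-fraction. With clap's derive defaults this field should surface as the --cuda-memory-fraction flag and a CUDA_MEMORY_FRACTION environment variable (an inference from the attribute, not shown in this diff). A quick worked example with made-up numbers:

    # Illustration only: "allowed = total visible memory * cuda-memory-fraction"
    # for a hypothetical 24 GB device launched with --cuda-memory-fraction 0.8.
    total_visible_memory_gb = 24.0
    cuda_memory_fraction = 0.8
    allowed_gb = total_visible_memory_gb * cuda_memory_fraction
    print(f"allowed: {allowed_gb:.1f} GB")  # allowed: 19.2 GB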
@@ -299,6 +304,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    cuda_memory_fraction: f32,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -368,6 +374,12 @@ fn shard_manager(
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
     envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // CUDA memory fraction
+    envs.push((
+        "CUDA_MEMORY_FRACTION".into(),
+        cuda_memory_fraction.to_string().into(),
+    ));
+
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
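The launcher hands the fraction to each shard purely through the process environment, alongside the other variables above. A minimal Python sketch of that handoff (the launcher itself is Rust; the child command below is made up for illustration and simply parses the variable the way the server's utils/dist.py does):

    import os
    import subprocess
    import sys

    # Parent side: put CUDA_MEMORY_FRACTION into the child's environment,
    # as shard_manager does for each shard process.
    child_env = dict(os.environ, CUDA_MEMORY_FRACTION="0.8")

    # Child side: read it back with the same default the server uses.
    subprocess.run(
        [sys.executable, "-c",
         'import os; print(float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")))'],
        env=child_env,
        check=True,
    )  # prints 0.8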
@@ -771,6 +783,7 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
+        let cuda_memory_fraction = args.cuda_memory_fraction;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -788,6 +801,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                cuda_memory_fraction,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
server/text_generation_server/models/flash_causal_lm.py

@@ -19,6 +19,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
@@ -738,7 +739,12 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_gpu_memory = (
+            torch.cuda.get_device_properties(self.device).total_memory
+        )
+
+        free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
 
         num_blocks = (
             int(free_memory // total_cache_size)
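The new sizing rule treats (1 - MEMORY_FRACTION) of the whole card as off limits and only counts what remains of the currently free memory when sizing the KV cache. A worked example with made-up numbers (none of these figures come from the commit):

    # Hypothetical 80 GB GPU, 70 GB currently free, CUDA_MEMORY_FRACTION=0.5.
    MEMORY_FRACTION = 0.5
    total_gpu_memory = 80 * 1024**3    # get_device_properties(device).total_memory
    total_free_memory = 70 * 1024**3   # first value from torch.cuda.mem_get_info(device)

    # Half the card (40 GB) is reserved for other processes, so only 30 GB of the
    # 70 GB that is physically free may be counted for cache blocks.
    free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
    print(free_memory / 1024**3)  # 30.0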
server/text_generation_server/models/model.py

@@ -21,9 +21,6 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            torch.cuda.set_per_process_memory_fraction(1.0)
-
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
server/text_generation_server/utils/dist.py

@@ -4,6 +4,13 @@ import torch
 from datetime import timedelta
 from loguru import logger
 
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+
+# CUDA memory fraction
+MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
+
 
 class FakeBarrier:
     def wait(self):
@@ -37,16 +44,14 @@ class FakeGroup:
 
 
 def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
         # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
+        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
+        device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
+        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
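Once set_per_process_memory_fraction is applied per device, allocations that would push this process past MEMORY_FRACTION * total_memory fail in the caching allocator even if the card still has physical headroom. A minimal standalone sketch of that behaviour (the 0.5 fraction and 60% allocation are arbitrary choices, not from the commit):

    import torch

    if torch.cuda.is_available():
        device = 0
        torch.cuda.set_per_process_memory_fraction(0.5, device)
        total = torch.cuda.get_device_properties(device).total_memory

        try:
            # ~60% of the card in one tensor: above the 0.5 cap, so this is
            # expected to raise instead of consuming the whole GPU.
            torch.empty(int(total * 0.6), dtype=torch.uint8, device=device)
        except RuntimeError as err:  # torch.cuda.OutOfMemoryError subclasses RuntimeError
            print(f"capped as expected: {err}")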
@@ -55,22 +60,22 @@ def initialize_torch_distributed():
         backend = "gloo"
         options = None
 
-    if world_size == 1:
-        return FakeGroup(rank, world_size), rank, world_size
+    if WORLD_SIZE == 1:
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
     else:
         if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(rank, world_size), rank, world_size
+            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
 
         if not torch.distributed.is_initialized():
             # Call the init process.
             torch.distributed.init_process_group(
                 backend=backend,
-                world_size=world_size,
-                rank=rank,
+                world_size=WORLD_SIZE,
+                rank=RANK,
                 timeout=timedelta(seconds=60),
                 pg_options=options,
             )
         else:
             logger.warning("torch.distributed is already initialized.")
 
-    return torch.distributed.group.WORLD, rank, world_size
+    return torch.distributed.group.WORLD, RANK, WORLD_SIZE