Move to CUDA graphs by default (with the possibility to choose graph sizes).

Nicolas Patry 2024-04-04 12:46:28 +00:00
parent 4ee0a0c401
commit edcbc0890c
6 changed files with 42 additions and 22 deletions

View File

@@ -206,12 +206,13 @@ Options:
       [env: MAX_BATCH_SIZE=]
 
 ```
-## ENABLE_CUDA_GRAPHS
+## CUDA_GRAPHS
 ```shell
-      --enable-cuda-graphs
-          Enable experimental support for cuda graphs
+      --cuda-graphs <CUDA_GRAPHS>
+          Specify the batch sizes to compute cuda graphs for
 
-          [env: ENABLE_CUDA_GRAPHS=]
+          [env: CUDA_GRAPHS=]
+          [default: 1,2,4,8,16,32,64,96,128]
 
 ```
 ## HOSTNAME
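
Both spellings documented above reach the same setting. A minimal usage sketch, assuming the usual `text-generation-launcher` binary and omitting other required flags such as `--model-id`:

```python
import os
import subprocess

# Restrict CUDA graph capture to batch sizes 1, 2 and 4 via the new flag
# (sketch only; required flags such as --model-id are omitted).
subprocess.run(["text-generation-launcher", "--cuda-graphs", "1,2,4"])

# Equivalent spelling through the environment variable.
env = dict(os.environ, CUDA_GRAPHS="1,2,4")
subprocess.run(["text-generation-launcher"], env=env)
```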

View File

@@ -383,7 +383,6 @@ def launcher(event_loop):
         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
-            "ENABLE_CUDA_GRAPHS": "true",
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"

View File

@@ -284,9 +284,14 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
-    /// Enable experimental support for cuda graphs
-    #[clap(long, env)]
-    enable_cuda_graphs: bool,
+    /// Specify the batch sizes to compute cuda graphs for
+    #[clap(
+        long,
+        env,
+        value_delimiter = ',',
+        default_value = "1,2,4,8,16,32,64,96,128"
+    )]
+    cuda_graphs: Vec<usize>,
 
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
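
The multi-line `#[clap(...)]` attribute above turns a comma-separated string into a `Vec<usize>` with a default list. A rough Python `argparse` analogue, purely illustrative and not part of the commit:

```python
import argparse

# Illustrative analogue of the clap setup above: a comma-delimited flag
# with the same default batch-size list.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cuda-graphs",
    type=lambda s: [int(v) for v in s.split(",")],
    default=[1, 2, 4, 8, 16, 32, 64, 96, 128],
)

print(parser.parse_args([]).cuda_graphs)                          # the default list
print(parser.parse_args(["--cuda-graphs", "1,2,4"]).cuda_graphs)  # [1, 2, 4]
```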
@@ -416,7 +421,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
-    enable_cuda_graphs: bool,
+    cuda_graphs: Vec<usize>,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
@@ -549,8 +554,16 @@
     };
 
     // Enable experimental support for cuda graphs
-    if enable_cuda_graphs {
-        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    if !cuda_graphs.is_empty() {
+        envs.push((
+            "CUDA_GRAPHS".into(),
+            cuda_graphs
+                .into_iter()
+                .map(|c| c.to_string())
+                .collect::<Vec<_>>()
+                .join(","),
+                .into(),
+        ));
     }
 
     // If disable_custom_kernels is true, pass it to the shard as an env var
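
A shard only sees flat strings in its environment, so the launcher re-serializes the list before export. The round trip, sketched in Python:

```python
# What the Rust branch above exports when the list is non-empty.
cuda_graphs = [1, 2, 4, 8, 16, 32, 64, 96, 128]
env_value = ",".join(str(c) for c in cuda_graphs)
assert env_value == "1,2,4,8,16,32,64,96,128"  # becomes CUDA_GRAPHS for the shard
```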
@@ -941,7 +954,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
-    let enable_cuda_graphs = args.enable_cuda_graphs;
+    let cuda_graphs = args.cuda_graphs.clone();
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
@@ -963,7 +976,7 @@ fn spawn_shards(
             disable_custom_kernels,
             watermark_gamma,
             watermark_delta,
-            enable_cuda_graphs,
+            cuda_graphs,
             cuda_memory_fraction,
             rope_scaling,
             rope_factor,

View File

@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
+from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -798,11 +798,11 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             try:
-                logger.info("Experimental support for Cuda Graphs is enabled")
+                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
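
The `self.speculate + 1 <= bs` guard skips graph capture for batch sizes smaller than `speculate + 1`, so speculative decoding never lands on a graph that is too small. A standalone sketch of which sizes survive the filter:

```python
# Standalone sketch of the warmup gating above.
CUDA_GRAPHS = [1, 2, 4, 8, 16, 32, 64, 96, 128]

def eligible_sizes(cuda_graphs, speculate):
    # Mirrors: if self.speculate is None or self.speculate + 1 <= bs
    return [bs for bs in cuda_graphs if speculate is None or speculate + 1 <= bs]

print(eligible_sizes(CUDA_GRAPHS, None))  # all nine sizes
print(eligible_sizes(CUDA_GRAPHS, 3))     # [4, 8, 16, 32, 64, 96, 128]
```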

View File

@@ -3,4 +3,11 @@ import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
-ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
+cuda_graphs = os.getenv("CUDA_GRAPHS", "1,2,4,8,16,32,64,96,128")
+try:
+    cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
+except Exception as e:
+    raise RuntimeError(
+        f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
+    )
+CUDA_GRAPHS = cuda_graphs
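
The parsing above is deliberately strict: a malformed `CUDA_GRAPHS` value aborts at import time instead of silently falling back. A standalone sketch of both paths:

```python
def parse_cuda_graphs(value):
    # Standalone copy of the parsing above, for illustration.
    try:
        return [int(item) for item in value.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {value}, expected comma separated "
            f"list for batch sizes to run on: {e}"
        )

print(parse_cuda_graphs("1,2,4"))  # [1, 2, 4]
try:
    parse_cuda_graphs("1,two,4")
except RuntimeError as e:
    print(e)  # Could not parse cuda graphs 1,two,4, expected comma separated list ...
```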

View File

@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
+from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaModel,
@@ -465,12 +465,12 @@ class Mamba(Model):
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
-                    logger.info("Experimental support for Cuda Graphs is enabled")
+                    logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                     # Warmup cuda graphs
-                    for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                    for bs in CUDA_GRAPHS:
                         self.cuda_graph_warmup(bs)
                 except Exception:
                     logger.exception(f"Decode cuda graph warmup failed")
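
Note the difference from FlashCausalLM: Mamba captures graphs only when speculation is off entirely (`speculate is None or speculate == 0`), rather than filtering individual batch sizes. The two guards side by side, as a sketch:

```python
# Side-by-side sketch of the two warmup guards in this commit.
def flash_captures(bs, speculate):
    # FlashCausalLM: filter per batch size.
    return speculate is None or speculate + 1 <= bs

def mamba_captures(speculate):
    # Mamba: all-or-nothing, only without speculation.
    return speculate is None or speculate == 0

print(flash_captures(2, 3))  # False: batch size 2 < 3 + 1
print(mamba_captures(0))     # True
print(mamba_captures(2))     # False
```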