Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00

Move to cuda graphs by default (with possibility to choose graph sizes).

This commit is contained in:
parent 4ee0a0c401
commit edcbc0890c
@@ -206,12 +206,13 @@ Options:
           [env: MAX_BATCH_SIZE=]
 
 ```
-## ENABLE_CUDA_GRAPHS
+## CUDA_GRAPHS
 ```shell
---enable-cuda-graphs
-          Enable experimental support for cuda graphs
+--cuda-graphs <CUDA_GRAPHS>
+          Specify the batch sizes to compute cuda graphs for
 
-          [env: ENABLE_CUDA_GRAPHS=]
+          [env: CUDA_GRAPHS=]
+          [default: 1,2,4,8,16,32,64,96,128]
 
 ```
 ## HOSTNAME
@@ -383,7 +383,6 @@ def launcher(event_loop):
 
         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
-            "ENABLE_CUDA_GRAPHS": "true",
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
@@ -284,9 +284,14 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
-    /// Enable experimental support for cuda graphs
-    #[clap(long, env)]
-    enable_cuda_graphs: bool,
+    /// Specify the batch sizes to compute cuda graphs for
+    #[clap(
+        long,
+        env,
+        value_delimiter = ',',
+        default_value = "1,2,4,8,16,32,64,96,128"
+    )]
+    cuda_graphs: Vec<usize>,
 
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -416,7 +421,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
-    enable_cuda_graphs: bool,
+    cuda_graphs: Vec<usize>,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -549,8 +554,16 @@ fn shard_manager(
     };
 
     // Enable experimental support for cuda graphs
-    if enable_cuda_graphs {
-        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    if !cuda_graphs.is_empty() {
+        envs.push((
+            "CUDA_GRAPHS".into(),
+            cuda_graphs
+                .into_iter()
+                .map(|c| c.to_string())
+                .collect::<Vec<_>>()
+                .join(",")
+                .into(),
+        ));
     }
 
     // If disable_custom_kernels is true, pass it to the shard as an env var
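
A minimal Python sketch of the launcher-to-shard contract this hunk establishes (standalone illustration with a hypothetical helper name, not the launcher's actual Rust code): the requested batch sizes are serialized into a comma-separated CUDA_GRAPHS environment variable, and only when the list is non-empty.

```python
# Sketch of the env-var contract: batch sizes become a comma-separated CUDA_GRAPHS value.
def serialize_cuda_graphs(cuda_graphs: list[int]) -> dict[str, str]:
    envs: dict[str, str] = {}
    if cuda_graphs:
        envs["CUDA_GRAPHS"] = ",".join(str(c) for c in cuda_graphs)
    return envs

assert serialize_cuda_graphs([1, 2, 4, 8]) == {"CUDA_GRAPHS": "1,2,4,8"}
assert serialize_cuda_graphs([]) == {}
```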
@@ -941,7 +954,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
-    let enable_cuda_graphs = args.enable_cuda_graphs;
+    let cuda_graphs = args.cuda_graphs.clone();
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
@@ -963,7 +976,7 @@ fn spawn_shards(
             disable_custom_kernels,
             watermark_gamma,
             watermark_delta,
-            enable_cuda_graphs,
+            cuda_graphs,
             cuda_memory_fraction,
             rope_scaling,
             rope_factor,
@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
+from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -798,11 +798,11 @@ class FlashCausalLM(Model):
                 self.device,
             )
 
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             try:
-                logger.info("Experimental support for Cuda Graphs is enabled")
+                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
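
The `speculate` guard in this loop is easy to miss; here is a small standalone sketch (hypothetical helper, not part of the codebase) of which requested graph sizes actually get captured in FlashCausalLM's warmup.

```python
# With speculative decoding enabled, a graph is only captured when speculate + 1 <= bs.
from typing import Optional

def graph_sizes_to_capture(cuda_graphs: list[int], speculate: Optional[int]) -> list[int]:
    return [bs for bs in cuda_graphs if speculate is None or speculate + 1 <= bs]

assert graph_sizes_to_capture([1, 2, 4, 8, 16], None) == [1, 2, 4, 8, 16]
assert graph_sizes_to_capture([1, 2, 4, 8, 16], speculate=3) == [4, 8, 16]
```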
@@ -3,4 +3,11 @@ import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
-ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
+cuda_graphs = os.getenv("CUDA_GRAPHS", "1,2,4,8,16,32,64,96,128")
+try:
+    cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
+except Exception as e:
+    raise RuntimeError(
+        f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
+    )
+CUDA_GRAPHS = cuda_graphs
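
For reference, a minimal sketch of the parsing behavior this hunk adds (standalone function with a hypothetical name): the CUDA_GRAPHS value is split on commas and converted to ints, with the default covering batch sizes 1 through 128.

```python
# Minimal sketch of the CUDA_GRAPHS parsing added above (hypothetical helper name).
import os

def parse_cuda_graphs(raw: str) -> list[int]:
    try:
        return [int(item) for item in raw.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {raw}, expected comma separated "
            f"list for batch sizes to run on: {e}"
        )

print(parse_cuda_graphs(os.getenv("CUDA_GRAPHS", "1,2,4,8,16,32,64,96,128")))
# With the variable unset: [1, 2, 4, 8, 16, 32, 64, 96, 128]
```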
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
+from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaModel,
@@ -465,12 +465,12 @@ class Mamba(Model):
 
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
-                    logger.info("Experimental support for Cuda Graphs is enabled")
+                    logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                     # Warmup cuda graphs
-                    for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                    for bs in CUDA_GRAPHS:
                         self.cuda_graph_warmup(bs)
                 except Exception:
                     logger.exception(f"Decode cuda graph warmup failed")
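
Unlike FlashCausalLM, the Mamba warmup keeps a coarser guard: it captures graphs for every requested size, but only when speculative decoding is disabled. A small standalone sketch of that guard (hypothetical helper name, under that assumption):

```python
# Standalone sketch of the Mamba warmup guard: all requested sizes, or none at all.
from typing import Optional

def mamba_graph_sizes(cuda_graphs: list[int], speculate: Optional[int]) -> list[int]:
    if speculate is None or speculate == 0:
        return list(cuda_graphs)
    return []

assert mamba_graph_sizes([1, 2, 4, 8], None) == [1, 2, 4, 8]
assert mamba_graph_sizes([1, 2, 4, 8], speculate=2) == []
```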