mirror of https://github.com/huggingface/text-generation-inference.git
Move to cuda graphs by default (with possibility to choose graph sizes).
This commit is contained in:
parent 4ee0a0c401
commit edcbc0890c
@@ -206,12 +206,13 @@ Options:
       [env: MAX_BATCH_SIZE=]
 
 ```
-## ENABLE_CUDA_GRAPHS
+## CUDA_GRAPHS
 ```shell
-  --enable-cuda-graphs
-      Enable experimental support for cuda graphs
+  --cuda-graphs <CUDA_GRAPHS>
+      Specify the batch sizes to compute cuda graphs for
 
-      [env: ENABLE_CUDA_GRAPHS=]
+      [env: CUDA_GRAPHS=]
+      [default: 1,2,4,8,16,32,64,96,128]
 
 ```
 ## HOSTNAME
@@ -383,7 +383,6 @@ def launcher(event_loop):
 
         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
-            "ENABLE_CUDA_GRAPHS": "true",
         }
         if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
@@ -284,9 +284,14 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
-    /// Enable experimental support for cuda graphs
-    #[clap(long, env)]
-    enable_cuda_graphs: bool,
+    /// Specify the batch sizes to compute cuda graphs for
+    #[clap(
+        long,
+        env,
+        value_delimiter = ',',
+        default_value = "1,2,4,8,16,32,64,96,128"
+    )]
+    cuda_graphs: Vec<usize>,
 
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
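
For intuition, the new clap attribute splits the flag value on commas and falls back to the default list, yielding a `Vec<usize>`. A rough Python analogue of that parsing behaviour, using `argparse` purely as a stand-in for clap (the helper name below is made up for the sketch):

```python
import argparse

def comma_separated_sizes(value: str) -> list[int]:
    # Rough analogue of clap's value_delimiter = ',' on a Vec<usize>
    return [int(item) for item in value.split(",")]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cuda-graphs",
    type=comma_separated_sizes,
    default=comma_separated_sizes("1,2,4,8,16,32,64,96,128"),
)

print(parser.parse_args(["--cuda-graphs", "1,2,4"]).cuda_graphs)  # [1, 2, 4]
print(parser.parse_args([]).cuda_graphs)  # the full default list
```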
@@ -416,7 +421,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
-    enable_cuda_graphs: bool,
+    cuda_graphs: Vec<usize>,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -549,8 +554,16 @@ fn shard_manager(
     };
 
     // Enable experimental support for cuda graphs
-    if enable_cuda_graphs {
-        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    if !cuda_graphs.is_empty() {
+        envs.push((
+            "CUDA_GRAPHS".into(),
+            cuda_graphs
+                .into_iter()
+                .map(|c| c.to_string())
+                .collect::<Vec<_>>()
+                .join(",")
+                .into(),
+        ));
     }
 
     // If disable_custom_kernels is true, pass it to the shard as an env var
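
The launcher now serializes the chosen sizes back into the same comma-separated string that the server reads from `CUDA_GRAPHS`. A minimal Python sketch of that round trip (illustrative only; the launcher side is really the Rust code above):

```python
# Launcher side (actually Rust in shard_manager): join the sizes into the env value.
cuda_graphs = [1, 2, 4, 8]
env_value = ",".join(str(c) for c in cuda_graphs)  # "1,2,4,8"

# Server side (mirrors the parsing added in the server's globals module): split it back.
parsed = [int(item) for item in env_value.split(",")]
assert parsed == cuda_graphs
```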
@@ -941,7 +954,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
-    let enable_cuda_graphs = args.enable_cuda_graphs;
+    let cuda_graphs = args.cuda_graphs.clone();
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
@@ -963,7 +976,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
-                enable_cuda_graphs,
+                cuda_graphs,
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
+from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -798,11 +798,11 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             try:
-                logger.info("Experimental support for Cuda Graphs is enabled")
+                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
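
As a point of comparison, the warmup loop previously iterated over a hard-coded size list, while `CUDA_GRAPHS` now drives it and the existing speculation guard still filters the sizes. A small sketch that expands both lists, with an assumed `speculate` value of 2 for illustration:

```python
# Sizes the old warmup loop hard-coded vs. the new CUDA_GRAPHS default.
old_sizes = [1, 2, 4] + [8 * i for i in range(1, 9)]
print(old_sizes)  # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]

CUDA_GRAPHS = [1, 2, 4, 8, 16, 32, 64, 96, 128]  # new default
speculate = 2  # illustrative value, not taken from the diff

# Mirrors the guard `self.speculate is None or self.speculate + 1 <= bs`:
# with speculation enabled, only large enough batch sizes get a graph captured.
captured = [bs for bs in CUDA_GRAPHS if speculate is None or speculate + 1 <= bs]
print(captured)  # [4, 8, 16, 32, 64, 96, 128]
```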
@@ -3,4 +3,11 @@ import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
-ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
+cuda_graphs = os.getenv("CUDA_GRAPHS", "1,2,4,8,16,32,64,96,128")
+try:
+    cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
+except Exception as e:
+    raise RuntimeError(
+        f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
+    )
+CUDA_GRAPHS = cuda_graphs
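
Because this parsing runs at module import, an unparseable `CUDA_GRAPHS` value fails fast with the message above. A standalone sketch of just that parsing step (it prints instead of raising, unlike the actual module):

```python
import os

os.environ["CUDA_GRAPHS"] = "1,2,four"  # deliberately malformed, for illustration

cuda_graphs = os.getenv("CUDA_GRAPHS", "1,2,4,8,16,32,64,96,128")
try:
    cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
except Exception as e:
    # The real globals module raises a RuntimeError with this message instead.
    print(
        f"Could not parse cuda graphs {cuda_graphs}, "
        f"expected comma separated list for batch sizes to run on: {e}"
    )
```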
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
+from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaModel,
@@ -465,12 +465,12 @@ class Mamba(Model):
 
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
-                    logger.info("Experimental support for Cuda Graphs is enabled")
+                    logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                     # Warmup cuda graphs
-                    for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                    for bs in CUDA_GRAPHS:
                         self.cuda_graph_warmup(bs)
                 except Exception:
                     logger.exception(f"Decode cuda graph warmup failed")