feat: experimental support for cuda graphs

This commit is contained in:
OlivierDehaene 2024-01-10 16:34:39 +01:00 committed by Nicolas Patry
parent 1d929a243a
commit 15fdd40587
6 changed files with 201 additions and 34 deletions

View File

@ -317,7 +317,10 @@ def launcher(event_loop):
gpu_count = num_shard if num_shard is not None else 1
env = {"LOG_LEVEL": "info,text_generation_router=debug"}
env = {
"LOG_LEVEL": "info,text_generation_router=debug",
"ENABLE_CUDA_GRAPHS": "True",
}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"

View File

@ -284,6 +284,10 @@ struct Args {
#[clap(long, env)]
max_batch_size: Option<usize>,
/// Enable experimental support for cuda graphs
#[clap(long, env)]
enable_cuda_graphs: bool,
/// The IP address to listen on
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
@ -407,6 +411,7 @@ fn shard_manager(
disable_custom_kernels: bool,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
enable_cuda_graphs: bool,
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
@ -488,8 +493,7 @@ fn shard_manager(
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
envs.push(("MASTER_ADDR".into(), master_addr.into()));
envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()))
envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
// CUDA memory fraction
envs.push((
@ -539,6 +543,11 @@ fn shard_manager(
));
};
// Enable experimental support for cuda graphs
if enable_cuda_graphs {
envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
}
// If disable_custom_kernels is true, pass it to the shard as an env var
if disable_custom_kernels {
envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
@ -927,6 +936,7 @@ fn spawn_shards(
let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let enable_cuda_graphs = args.enable_cuda_graphs;
let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
@ -948,6 +958,7 @@ fn spawn_shards(
disable_custom_kernels,
watermark_gamma,
watermark_delta,
enable_cuda_graphs,
cuda_memory_fraction,
rope_scaling,
rope_factor,

View File

@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
weights=weights,
)
self.max_past = config.sliding_window
self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
if self.max_past is not None
else None
)
def forward(
self,
@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
input_lengths = torch.clamp(input_lengths, max=self.max_past)
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
hidden_states = self.model(
input_ids,

View File

@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
weights=weights,
)
self.max_past = config.sliding_window
self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
if self.max_past is not None
else None
)
def forward(
self,
@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
input_lengths = torch.clamp(input_lengths, max=self.max_past)
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
hidden_states = self.model(
input_ids,

View File

@ -1,4 +1,5 @@
import math
import os
import time
import itertools
import torch
@ -6,6 +7,7 @@ import torch.distributed
import numpy as np
from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__)
MEM_POOL = torch.cuda.graph_pool_handle()
@dataclass
class FlashCausalLMBatch(Batch):
@ -663,6 +667,8 @@ class FlashCausalLM(Model):
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.cuda_graphs = {}
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
@ -678,7 +684,44 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)
.repeat(bs)
.reshape((bs, max_bt))
)
kv_cache = get_cache_manager().kv_cache
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths,
}
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
lm_head_indices=None,
)
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
torch.cuda.empty_cache()
try:
cache_manager = set_cache_manager(
@ -690,6 +733,8 @@ class FlashCausalLM(Model):
self.dtype,
self.device,
)
max_s = batch.max_seqlen
max_bt = batch.max_blocks
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
@ -713,7 +758,8 @@ class FlashCausalLM(Model):
)
num_blocks = (
int(free_memory // total_cache_size)
# Leave 1% for some wiggle room
int((free_memory * 0.99) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ cache_manager.num_blocks
)
@ -731,6 +777,14 @@ class FlashCausalLM(Model):
self.device,
)
if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
try:
# Warmup cuda graphs for all power of twos until 64
for i in range(6):
self.cuda_graph_warmup(2**i, max_s, max_bt)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
return int(num_blocks * BLOCK_SIZE)
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
@ -785,6 +839,13 @@ class FlashCausalLM(Model):
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
bs = batch.input_ids.shape[0]
# Ceil next power of two for batch size
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
if batch.cu_seqlen_prefill is not None or cuda_graph is None:
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -797,6 +858,22 @@ class FlashCausalLM(Model):
lm_head_indices=lm_head_indices,
)
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
return cuda_graph["logits"][:bs]
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch

View File

@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
MEM_POOL = torch.cuda.graph_pool_handle()
# Adds windowing logic to FlashCausalLMBatch
@dataclass
@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):
model = model_cls(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model=model,
@ -350,6 +354,43 @@ class BaseFlashMistral(FlashCausalLM):
def batch_type(self) -> Type[FlashMistralBatch]:
return FlashMistralBatch
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)
.repeat(bs)
.reshape((bs, max_bt))
)
kv_cache = get_cache_manager().kv_cache
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths,
}
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
@ -401,6 +442,17 @@ class BaseFlashMistral(FlashCausalLM):
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if self.model.max_past is not None:
max_s = min(self.model.max_past, max_s)
bs = batch.input_ids.shape[0]
# Ceil next power of two for batch size
bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
if batch.cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -417,6 +469,22 @@ class BaseFlashMistral(FlashCausalLM):
batch.prefill_cache_indices = None
return logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
return cuda_graph["logits"][:bs]
class FlashMistral(BaseFlashMistral):
def __init__(