feat: experimental support for cuda graphs

Authored by OlivierDehaene on 2024-01-10 16:34:39 +01:00; committed by Nicolas Patry
parent 1d929a243a
commit 15fdd40587
6 changed files with 201 additions and 34 deletions
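Background on the pattern this change relies on: a torch.cuda.CUDAGraph is captured once over static input/output tensors, optionally sharing one memory pool across graphs (torch.cuda.graph_pool_handle(), the MEM_POOL global added below), and later decode steps only copy fresh values into those static tensors and call replay(), skipping Python and kernel-launch overhead. A minimal sketch of that capture/replay flow, not taken from the commit (names like static_x are illustrative, and it needs a CUDA device):

import torch

# Minimal capture/replay sketch (illustrative, not the commit's code).
assert torch.cuda.is_available()
device = torch.device("cuda")

# Shared memory pool, analogous to the MEM_POOL module global added below.
MEM_POOL = torch.cuda.graph_pool_handle()

weight = torch.randn(64, 64, device=device)
static_x = torch.zeros(8, 64, device=device)  # static input buffer

# Warm the kernel up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_x @ weight
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=MEM_POOL):
    static_y = static_x @ weight  # static output buffer, allocated from the pool

# Later "decode steps": refresh the static input in place and replay the graph.
static_x.copy_(torch.randn(8, 64, device=device))
graph.replay()
print(static_y[:2])  # output now reflects the refreshed static_x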

View File

@@ -317,7 +317,10 @@ def launcher(event_loop):
         gpu_count = num_shard if num_shard is not None else 1
 
-        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        env = {
+            "LOG_LEVEL": "info,text_generation_router=debug",
+            "ENABLE_CUDA_GRAPHS": "True",
+        }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"

View File

@@ -284,6 +284,10 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
+    /// Enable experimental support for cuda graphs
+    #[clap(long, env)]
+    enable_cuda_graphs: bool,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -407,6 +411,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    enable_cuda_graphs: bool,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -488,8 +493,7 @@ fn shard_manager(
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
     envs.push(("MASTER_ADDR".into(), master_addr.into()));
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
-    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
-    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
+    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
 
     // CUDA memory fraction
     envs.push((
@@ -539,6 +543,11 @@ fn shard_manager(
     ));
     };
 
+    // Enable experimental support for cuda graphs
+    if enable_cuda_graphs {
+        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    }
+
     // If disable_custom_kernels is true, pass it to the shard as an env var
     if disable_custom_kernels {
         envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
@@ -927,6 +936,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
+    let enable_cuda_graphs = args.enable_cuda_graphs;
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
@@ -948,6 +958,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                enable_cuda_graphs,
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
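The new flag only forwards an environment variable to the shard processes: passing --enable-cuda-graphs (or setting ENABLE_CUDA_GRAPHS, via clap's env attribute) makes the launcher export ENABLE_CUDA_GRAPHS=True to each shard, and the Python server gates graph capture on that value during warmup, as shown further down. A one-line sketch of the server-side check it pairs with (illustrative variable name):

import os

# Default is off; graphs are only captured when the launcher exported the flag.
cuda_graphs_enabled = os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True"
print(cuda_graphs_enabled)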

View File

@@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
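The sliding-window clamp now uses a pre-allocated device tensor (max_past_tensor), and the host-side max_s = min(...) moves out to the caller (see flash_mistral.py below). This matters for graph capture: replaying a CUDA graph re-runs only the recorded kernels, not Python code, so anything that must change per decode step has to live in device tensors that can be refreshed in place. A standalone sketch of that behaviour, with illustrative values (requires a CUDA device):

import torch

# Sketch: replaying a captured graph re-runs only the recorded CUDA kernels;
# Python logic such as `min(self.max_past, max_s)` is not re-executed, so
# per-step values must live in device tensors refreshed in place.
assert torch.cuda.is_available()
device = torch.device("cuda")

input_lengths = torch.tensor([3, 50, 7], device=device)  # static buffer
max_past_tensor = torch.tensor(16, device=device)        # constant, kept on device

torch.clamp(input_lengths, max=max_past_tensor)  # warm the kernel up once

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    clamped = torch.clamp(input_lengths, max=max_past_tensor)

input_lengths.copy_(torch.tensor([40, 2, 9], device=device))  # next decode step
graph.replay()
print(clamped)  # tensor([16, 2, 9]): the clamp was re-applied to the new values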

View File

@@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,

View File

@@ -1,4 +1,5 @@
 import math
+import os
 import time
 import itertools
 import torch
@@ -6,6 +7,7 @@ import torch.distributed
 import numpy as np
 
+from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
@@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
+MEM_POOL = torch.cuda.graph_pool_handle()
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -663,6 +667,8 @@ class FlashCausalLM(Model):
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size
 
+        self.cuda_graphs = {}
+
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
@@ -678,7 +684,44 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                lm_head_indices=None,
+            )
+
     def warmup(self, batch: FlashCausalLMBatch):
+        # The warmup batch is the biggest batch we could ever receive
         torch.cuda.empty_cache()
         try:
             cache_manager = set_cache_manager(
@@ -690,6 +733,8 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
+            max_s = batch.max_seqlen
+            max_bt = batch.max_blocks
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -713,7 +758,8 @@ class FlashCausalLM(Model):
         )
 
         num_blocks = (
-            int(free_memory // total_cache_size)
+            # Leave 1% for some wiggle room
+            int((free_memory * 0.99) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )
@@ -731,6 +777,14 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+            try:
+                # Warmup cuda graphs for all power of twos until 64
+                for i in range(6):
+                    self.cuda_graph_warmup(2**i, max_s, max_bt)
+            except Exception:
+                logger.exception(f"Decode cuda graph warmup failed")
+
         return int(num_blocks * BLOCK_SIZE)
 
     def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
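With the flag on, warmup pre-captures one decode graph per power-of-two batch size below 64; any capture failure is only logged, so the model silently falls back to the regular eager decode path. The bucket sizes this loop produces, for reference (snippet added here for illustration, not part of the commit):

# Decode batch-size buckets captured by the warmup loop above.
print([2**i for i in range(6)])  # [1, 2, 4, 8, 16, 32]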
@@ -785,17 +839,40 @@ class FlashCausalLM(Model):
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
-        return self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            lm_head_indices=lm_head_indices,
-        )
+        bs = batch.input_ids.shape[0]
+        # Ceil next power of two for batch size
+        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+
+        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+            return self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                lm_head_indices=lm_head_indices,
+            )
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
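Decode-time routing in the new forward: the live batch size is rounded up to the next power of two, that bucket's graph (if one was captured) is fetched, the real inputs are copied into the front of its padded static buffers, the graph is replayed, and the logits are sliced back to the true batch size; prefill and oversized batches keep the eager path. A small illustration of just the bucket lookup, with made-up names, mirroring the arithmetic above:

import math

# Graphs captured at warmup are keyed by padded batch size (1, 2, 4, 8, 16, 32).
cuda_graphs = {2**i: f"<graph bs={2**i}>" for i in range(6)}

for bs in (1, 3, 5, 12, 33):
    padded_bs = 2 ** math.ceil(math.log2(bs))
    cuda_graph = cuda_graphs.get(padded_bs, None)
    # A prefill batch, or one too large for any captured graph, stays on the eager path.
    print(bs, "->", padded_bs, "replay" if cuda_graph is not None else "eager forward")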

View File

@@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
 
+MEM_POOL = torch.cuda.graph_pool_handle()
+
 
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
@@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):
         model = model_cls(config, weights)
 
+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
             model=model,
@@ -350,6 +354,43 @@ class BaseFlashMistral(FlashCausalLM):
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
 
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
@@ -401,21 +442,48 @@ class BaseFlashMistral(FlashCausalLM):
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
-        logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=lm_head_indices,
-        )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
-        return logits
+
+        if self.model.max_past is not None:
+            max_s = min(self.model.max_past, max_s)
+
+        bs = batch.input_ids.shape[0]
+        # Ceil next power of two for batch size
+        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+
+        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+            logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
+                lm_head_indices=lm_head_indices,
+            )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            return logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
 
 
 class FlashMistral(BaseFlashMistral):