Mirror of https://github.com/huggingface/text-generation-inference.git
feat: experimental support for cuda graphs
commit 15fdd40587
parent 1d929a243a
@@ -317,7 +317,10 @@ def launcher(event_loop):
         gpu_count = num_shard if num_shard is not None else 1

-        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        env = {
+            "LOG_LEVEL": "info,text_generation_router=debug",
+            "ENABLE_CUDA_GRAPHS": "True",
+        }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
@@ -284,6 +284,10 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,

+    /// Enable experimental support for cuda graphs
+    #[clap(long, env)]
+    enable_cuda_graphs: bool,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -407,6 +411,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    enable_cuda_graphs: bool,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -488,8 +493,7 @@ fn shard_manager(
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
     envs.push(("MASTER_ADDR".into(), master_addr.into()));
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
-    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
-    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()))
+    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

     // CUDA memory fraction
     envs.push((
@@ -539,6 +543,11 @@ fn shard_manager(
         ));
     };

+    // Enable experimental support for cuda graphs
+    if enable_cuda_graphs {
+        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    }
+
    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
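Note: the launcher only forwards a switch here — the new `--enable-cuda-graphs` CLI flag (or `ENABLE_CUDA_GRAPHS` in the launcher's own environment, via clap's `env` attribute) is translated into `ENABLE_CUDA_GRAPHS=True` in each shard's environment, and the Python server decides what to do with it during warmup. A minimal sketch of that hand-off in Python, with a hypothetical `spawn_shard` helper standing in for `shard_manager`:

import os
import subprocess
from typing import List


def spawn_shard(cmd: List[str], enable_cuda_graphs: bool) -> subprocess.Popen:
    # Mirror the launcher behaviour sketched above: the CLI flag reaches the
    # shard process purely as an environment variable.
    env = os.environ.copy()
    if enable_cuda_graphs:
        env["ENABLE_CUDA_GRAPHS"] = "True"
    return subprocess.Popen(cmd, env=env)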
@@ -927,6 +936,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
+    let enable_cuda_graphs = args.enable_cuda_graphs;
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
@@ -948,6 +958,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                enable_cuda_graphs,
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
@@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )

     def forward(
         self,
@@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
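Note on the clamp change above (and its Mixtral twin below): the decode-time clamp now uses `self.max_past_tensor`, a device tensor precomputed at init, instead of the Python int `self.max_past`, and the `max_s = min(...)` adjustment moves out of the model into `BaseFlashMistral.forward` (later hunk). The diff does not state the motivation; a plausible reading is that keeping the bound on-device keeps the captured decode path free of host-side scalars. Both forms of `torch.clamp` produce the same values — a small sketch, with a made-up `sliding_window` and a CPU fallback when no GPU is present:

import torch

sliding_window = 4096  # stands in for config.sliding_window
device = "cuda" if torch.cuda.is_available() else "cpu"

input_lengths = torch.tensor([100, 5000, 8192], dtype=torch.int32, device=device)
max_past_tensor = torch.tensor(sliding_window, dtype=torch.int32, device=device)

old = torch.clamp(input_lengths, max=sliding_window)    # Python int bound (removed line)
new = torch.clamp(input_lengths, max=max_past_tensor)   # tensor bound (added line)
assert bool((old == new).all())                         # both give [100, 4096, 4096]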
@@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )

     def forward(
         self,
@@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
@@ -1,4 +1,5 @@
 import math
+import os
 import time
 import itertools
 import torch
@@ -6,6 +7,7 @@ import torch.distributed

 import numpy as np

+from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
@@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION

 tracer = trace.get_tracer(__name__)

+MEM_POOL = torch.cuda.graph_pool_handle()
+

 @dataclass
 class FlashCausalLMBatch(Batch):
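Note: `MEM_POOL` is created once per module and passed as `pool=` to every capture below, so the graphs recorded for the different batch-size buckets share one private memory pool rather than each reserving their own. A minimal standalone sketch of that pattern (a toy matmul instead of the model forward, all sizes made up):

import torch

assert torch.cuda.is_available()
pool = torch.cuda.graph_pool_handle()        # shared handle, like MEM_POOL

static_x = torch.zeros(8, 8, device="cuda")
torch.matmul(static_x, static_x.T)           # warmup so lazy cuBLAS init is not captured
torch.cuda.synchronize()

graphs = {}
for bs in (2, 4, 8):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):     # every capture reuses the same pool
        out = static_x[:bs] @ static_x[:bs].T
    graphs[bs] = (g, out)

static_x.copy_(torch.randn(8, 8, device="cuda"))
graphs[4][0].replay()                        # recomputes the bs=4 output in place
print(graphs[4][1])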
@@ -663,6 +667,8 @@ class FlashCausalLM(Model):
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size

+        self.cuda_graphs = {}
+
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
@@ -678,7 +684,44 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch

+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                lm_head_indices=None,
+            )
+
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
         torch.cuda.empty_cache()
         try:
             cache_manager = set_cache_manager(
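Note: `cuda_graph_warmup` follows the standard PyTorch capture recipe — allocate static input tensors sized for the bucket's batch size, run one decode forward under `torch.cuda.graph(...)`, and keep both the `CUDAGraph` object and its output tensor; at replay time only the contents of the static tensors change. A self-contained sketch of the same recipe on a toy decode step (a single Linear standing in for the model, all sizes made up):

import torch

assert torch.cuda.is_available()
device = torch.device("cuda")
model = torch.nn.Linear(16, 16).to(device)

bs = 8                                              # bucket batch size
static_input = torch.zeros(bs, 16, device=device)   # analogous to input_ids/slots/...

with torch.no_grad():                               # warmup outside of capture
    model(static_input)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):                       # the real code also passes pool=MEM_POOL
    with torch.no_grad():
        static_logits = model(static_input)         # analogous to cuda_graphs[bs]["logits"]

static_input.copy_(torch.randn(bs, 16, device=device))
graph.replay()                                      # static_logits now holds fresh results
print(static_logits.shape)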
@@ -690,6 +733,8 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
+            max_s = batch.max_seqlen
+            max_bt = batch.max_blocks
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -713,7 +758,8 @@ class FlashCausalLM(Model):
         )

         num_blocks = (
-            int(free_memory // total_cache_size)
+            # Leave 1% for some wiggle room
+            int((free_memory * 0.99) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + cache_manager.num_blocks
         )
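Note: the only functional change in this hunk is the 1% safety margin on the measured free memory. A worked example of the formula with made-up numbers (free memory, per-block cache size, and warmup block count are all assumptions):

free_memory = 20 * 1024**3        # 20 GiB reported free after the warmup forward
total_cache_size = 2 * 1024**2    # bytes of KV cache per block (hypothetical)
warmup_blocks = 512               # cache_manager.num_blocks allocated for warmup

old_num_blocks = int(free_memory // total_cache_size) + warmup_blocks
new_num_blocks = int((free_memory * 0.99) // total_cache_size) + warmup_blocks
print(old_num_blocks, new_num_blocks)   # 10752 vs 10649: ~1% of VRAM kept as headroom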
@@ -731,6 +777,14 @@ class FlashCausalLM(Model):
             self.device,
         )

+        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+            try:
+                # Warmup cuda graphs for all power of twos until 64
+                for i in range(6):
+                    self.cuda_graph_warmup(2**i, max_s, max_bt)
+            except Exception:
+                logger.exception(f"Decode cuda graph warmup failed")
+
         return int(num_blocks * BLOCK_SIZE)

     def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
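Note: with `ENABLE_CUDA_GRAPHS=True` in the shard environment, warmup additionally captures one decode graph per power-of-two batch size; `range(6)` gives bs = 1, 2, 4, 8, 16 and 32 (the "until 64" in the comment is exclusive). Capture failures are logged via `logger.exception` and the server simply keeps the eager decode path. The gate and the buckets, spelled out:

import os

enabled = os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True"
captured_batch_sizes = [2**i for i in range(6)]
print(enabled, captured_batch_sizes)   # e.g. False [1, 2, 4, 8, 16, 32]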
@@ -785,6 +839,13 @@ class FlashCausalLM(Model):
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices

+        bs = batch.input_ids.shape[0]
+        # Ceil next power of two for batch size
+        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+
+        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
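Note: at decode time the batch size is rounded up to the nearest captured power-of-two bucket; prefill batches (`cu_seqlen_prefill is not None`) and batch sizes with no captured graph (anything rounding above 32) fall back to the regular eager forward. A quick sketch of the lookup, with placeholder strings standing in for the captured graph dicts:

import math

cuda_graphs = {2**i: f"graph_bs{2**i}" for i in range(6)}  # stand-ins for the real entries

def lookup(bs: int):
    bucket = 2 ** math.ceil(math.log2(bs))
    return cuda_graphs.get(bucket, None)

print(lookup(3))    # graph_bs4  -> a batch of 3 replays the bs=4 graph, padded
print(lookup(32))   # graph_bs32
print(lookup(48))   # None       -> rounds to 64, not captured, eager path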
@@ -797,6 +858,22 @@ class FlashCausalLM(Model):
                 lm_head_indices=lm_head_indices,
             )

+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
+
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
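Note: the decode path above copies the real batch into the front of the (possibly larger) static buffers, replays the bucket's graph, and returns only the first `bs` rows of the recorded logits; the padded tail rows are computed from whatever is left in the static buffers and are discarded. Extending the earlier toy Linear sketch with a batch smaller than its bucket (all names and sizes are illustrative):

import torch

assert torch.cuda.is_available()
device = torch.device("cuda")
model = torch.nn.Linear(16, 4).to(device)

bucket_bs = 4                                           # graph captured for batch size 4
static_input = torch.zeros(bucket_bs, 16, device=device)
with torch.no_grad():
    model(static_input)                                 # warmup outside of capture
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    with torch.no_grad():
        static_logits = model(static_input)

bs = 3                                                  # real decode batch is smaller
static_input[:bs] = torch.randn(bs, 16, device=device)  # pad: row 3 keeps stale zeros
graph.replay()
logits = static_logits[:bs]                             # slice away the padded row
print(logits.shape)                                     # torch.Size([3, 4])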
@@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None

+MEM_POOL = torch.cuda.graph_pool_handle()
+

 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
@@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):

         model = model_cls(config, weights)

+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
             model=model,
@@ -350,6 +354,43 @@ class BaseFlashMistral(FlashCausalLM):
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch

+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
@@ -401,6 +442,17 @@ class BaseFlashMistral(FlashCausalLM):
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices

+        if self.model.max_past is not None:
+            max_s = min(self.model.max_past, max_s)
+
+        bs = batch.input_ids.shape[0]
+        # Ceil next power of two for batch size
+        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+
+        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -417,6 +469,22 @@ class BaseFlashMistral(FlashCausalLM):
                 batch.prefill_cache_indices = None
             return logits

+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
+

 class FlashMistral(BaseFlashMistral):
     def __init__(