From 15fdd40587510ddb8f317eceac3fd7924c2b23fe Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 10 Jan 2024 16:34:39 +0100
Subject: [PATCH] feat: experimental support for cuda graphs

---
 integration-tests/conftest.py                 |   5 +-
 launcher/src/main.rs                          |  15 ++-
 .../custom_modeling/flash_mistral_modeling.py |   8 +-
 .../custom_modeling/flash_mixtral_modeling.py |   8 +-
 .../models/flash_causal_lm.py                 | 101 +++++++++++++++---
 .../models/flash_mistral.py                   |  98 ++++++++++++++---
 6 files changed, 201 insertions(+), 34 deletions(-)

diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 4cb4ca59..69b3fe67 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -317,7 +317,10 @@ def launcher(event_loop):
 
         gpu_count = num_shard if num_shard is not None else 1
 
-        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        env = {
+            "LOG_LEVEL": "info,text_generation_router=debug",
+            "ENABLE_CUDA_GRAPHS": "True",
+        }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
 
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index cdbeb6bc..8367ef81 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -284,6 +284,10 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
+    /// Enable experimental support for cuda graphs
+    #[clap(long, env)]
+    enable_cuda_graphs: bool,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -407,6 +411,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    enable_cuda_graphs: bool,
    cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -488,8 +493,7 @@ fn shard_manager(
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
     envs.push(("MASTER_ADDR".into(), master_addr.into()));
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
-    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
-    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()))
+    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
 
     // CUDA memory fraction
     envs.push((
@@ -539,6 +543,11 @@
         ));
     };
 
+    // Enable experimental support for cuda graphs
+    if enable_cuda_graphs {
+        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    }
+
     // If disable_custom_kernels is true, pass it to the shard as an env var
     if disable_custom_kernels {
         envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
@@ -927,6 +936,7 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
+        let enable_cuda_graphs = args.enable_cuda_graphs;
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
@@ -948,6 +958,7 @@
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                enable_cuda_graphs,
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
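Note on wiring: the new `--enable-cuda-graphs` flag (or, equivalently, the `ENABLE_CUDA_GRAPHS` environment variable, since the clap attribute includes `env`) only makes the launcher export `ENABLE_CUDA_GRAPHS=True` to every shard; the Python server is what actually decides to capture graphs during warmup. A minimal sketch of that gate as the shard sees it, illustrative only and mirroring the check added in flash_causal_lm.py further down:

    import os

    # Opt in only when the launcher (or the user) exported the flag.
    if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
        ...  # capture decode-only CUDA graphs during warmup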
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 0fc4e1b3..7b45be57 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 61488ec4..c91b2224 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
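Note on the clamp change above (identical in the Mistral and Mixtral models): the decode-time clamp now takes its bound from `self.max_past_tensor`, a tensor pre-allocated on the weights' device, instead of the Python int `self.max_past`, and the `max_s = min(...)` adjustment moves out of the model forward into `flash_mistral.py` later in this patch. Keeping the clamp fully on-device is convenient when this decode forward is later captured into a CUDA graph. A standalone illustration of `torch.clamp` with a tensor bound; the values are made up and this is not TGI code:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Pre-allocate the bound once, as a 0-d tensor on the same device as the inputs.
    max_past_tensor = torch.tensor(4096, dtype=torch.int32, device=device)

    input_lengths = torch.tensor([100, 8000, 4096], dtype=torch.int32, device=device)
    clamped = torch.clamp(input_lengths, max=max_past_tensor)
    print(clamped)  # tensor([ 100, 4096, 4096], ...)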
self.cuda_graphs[bs]["graph"] = graph + + with torch.cuda.graph(graph, pool=MEM_POOL): + self.cuda_graphs[bs]["logits"] = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=None, + ) + def warmup(self, batch: FlashCausalLMBatch): + # The warmup batch is the biggest batch we could ever receive torch.cuda.empty_cache() try: cache_manager = set_cache_manager( @@ -690,6 +733,8 @@ class FlashCausalLM(Model): self.dtype, self.device, ) + max_s = batch.max_seqlen + max_bt = batch.max_blocks _, batch, _ = self.generate_token(batch) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( @@ -713,7 +758,8 @@ class FlashCausalLM(Model): ) num_blocks = ( - int(free_memory // total_cache_size) + # Leave 1% for some wiggle room + int((free_memory * 0.99) // total_cache_size) # Add batch.blocks as we allocated it above, so it is included in the peak memory. + cache_manager.num_blocks ) @@ -731,6 +777,14 @@ class FlashCausalLM(Model): self.device, ) + if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True": + try: + # Warmup cuda graphs for all power of twos until 64 + for i in range(6): + self.cuda_graph_warmup(2**i, max_s, max_bt) + except Exception: + logger.exception(f"Decode cuda graph warmup failed") + return int(num_blocks * BLOCK_SIZE) def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: @@ -785,17 +839,40 @@ class FlashCausalLM(Model): max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices - return self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - lm_head_indices=lm_head_indices, - ) + bs = batch.input_ids.shape[0] + # Ceil next power of two for batch size + bs_next_power_of_two = 2 ** math.ceil(math.log2(bs)) + # Try to find an associated cuda graph + cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None) + + if batch.cu_seqlen_prefill is not None or cuda_graph is None: + return self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=lm_head_indices, + ) + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + return cuda_graph["logits"][:bs] @tracer.start_as_current_span("generate_token") def generate_token( diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 8c6cb025..dee272a0 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__) SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None +MEM_POOL = 
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8c6cb025..dee272a0 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
 
+MEM_POOL = torch.cuda.graph_pool_handle()
+
 
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
@@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):
 
         model = model_cls(config, weights)
 
+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
             model=model,
@@ -350,6 +354,43 @@ class BaseFlashMistral(FlashCausalLM):
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
 
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
@@ -401,21 +442,48 @@ class BaseFlashMistral(FlashCausalLM):
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
-        logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=lm_head_indices,
-        )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
-        return logits
+
+        if self.model.max_past is not None:
+            max_s = min(self.model.max_past, max_s)
+
+        bs = batch.input_ids.shape[0]
+        # Ceil next power of two for batch size
+        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+
+        if batch.cu_seqlen_prefill is not None or cuda_graph is None:
+            logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
+                lm_head_indices=lm_head_indices,
+            )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            return logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
 
 
 class FlashMistral(BaseFlashMistral):