From 4fd6e62655aa4f18e716fbf258dbf1b2d624b16b Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 15 Jan 2024 18:24:22 +0100
Subject: [PATCH] fix

---
 .../models/flash_causal_lm.py | 45 ++++++++++++++-----
 .../models/flash_mistral.py   | 31 +++++++++++--
 2 files changed, 62 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index cb777010..670ee1b5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -66,7 +66,7 @@ class FlashCausalLMBatch(Batch):
     # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
     block_tables: Optional[List[List[int]]]
-    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: Optional[torch.Tensor]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: Optional[torch.Tensor]
@@ -707,6 +707,21 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
         with torch.cuda.graph(graph, pool=MEM_POOL):
             self.cuda_graphs[bs]["logits"] = self.model.forward(
                 input_ids=input_ids,
@@ -719,6 +734,7 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+        torch.cuda.synchronize()
 
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
@@ -733,8 +749,8 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
-            max_s = batch.max_seqlen
             max_bt = batch.max_blocks
+            max_s = max_bt * get_cache_manager().block_size
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -758,8 +774,8 @@ class FlashCausalLM(Model):
         )
 
         num_blocks = (
-            # Leave 1% for some wiggle room
-            int((free_memory * 0.99) // total_cache_size)
+            # Leave 5% for some wiggle room
+            int((free_memory * 0.95) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + cache_manager.num_blocks
         )
@@ -780,9 +796,8 @@ class FlashCausalLM(Model):
         if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
-                # Warmup cuda graphs for all power of twos until 64
-                for i in range(6):
-                    bs = 2**i
+                # Warmup cuda graphs
+                for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
@@ -790,7 +805,7 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -843,10 +858,16 @@ class FlashCausalLM(Model):
             lm_head_indices = batch.prefill_head_indices
 
         bs = input_ids.shape[0]
-        # Ceil next power of two for batch size
-        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
@@ -868,7 +889,9 @@ class FlashCausalLM(Model):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
 
         # Replay the graph
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8d7e2a2b..34a50194 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -377,6 +377,22 @@ class BaseFlashMistral(FlashCausalLM):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
         with torch.cuda.graph(graph, pool=MEM_POOL):
             self.cuda_graphs[bs]["logits"] = self.model.forward(
                 input_ids=input_ids,
@@ -390,6 +406,7 @@ class BaseFlashMistral(FlashCausalLM):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+        torch.cuda.synchronize()
 
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
@@ -447,10 +464,16 @@ class BaseFlashMistral(FlashCausalLM):
             max_s = min(self.model.max_past, max_s)
 
         bs = input_ids.shape[0]
-        # Ceil next power of two for batch size
-        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
@@ -476,7 +499,9 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
 
         # Replay the graph
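
For context: two details in this patch are easy to miss. First, each graph is now captured only after one eager forward pass bracketed by torch.cuda.synchronize(), because lazy initialization and kernel autotuning must happen outside graph capture. Second, since a batch of size bs now replays the graph captured for padded_bs, the static slots and input_lengths buffers are reset with fill_(-1) / zero_() so that the padding rows see harmless values instead of whatever an earlier, larger batch left behind. Below is a minimal, self-contained sketch of both ideas; padded_batch_size and capture_decode_graph are illustrative names, not functions from this repository, and the sketch omits the shared pool=MEM_POOL argument and real model inputs.

import torch

def padded_batch_size(bs: int) -> int:
    # Mirror of the bucketing in the diff: map a runtime batch size onto one
    # of the sizes captured at warmup (1, 2, 4, 8, 16, ..., 64). The branch
    # order means bs=4 pads up to 8, exactly as in the patch.
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8  # round up to the next multiple of 8
    return bs  # bs in {1, 2}

def capture_decode_graph(f, static_input: torch.Tensor):
    # Sketch of the capture sequence the patch adds (requires a CUDA device;
    # f stands in for model.forward): one eager warmup run, synchronize,
    # capture into the graph, synchronize again.
    graph = torch.cuda.CUDAGraph()
    torch.cuda.synchronize()
    f(static_input)  # warmup run outside the graph
    torch.cuda.synchronize()
    with torch.cuda.graph(graph):
        static_output = f(static_input)  # kernels are recorded into the graph
    torch.cuda.synchronize()
    # Later: static_input.copy_(new_data); graph.replay(); read static_output.
    return graph, static_output

if __name__ == "__main__":
    captured = [1, 2, 4] + [8 * i for i in range(1, 9)]
    assert all(padded_batch_size(bs) in captured for bs in range(1, 65))
    print([padded_batch_size(b) for b in (1, 3, 5, 9, 57)])  # [1, 4, 8, 16, 64]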