From 4f1b1a277caa8b137560c687d8666f914f835d24 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 25 Jun 2024 12:20:50 +0000
Subject: [PATCH] Rebased.

---
 router/src/infer/v2/scheduler.rs              | 12 +---------
 router/src/infer/v3/scheduler.rs              |  8 ++++++-
 .../models/cache_manager.py                   |  3 +--
 .../models/flash_causal_lm.py                 | 23 +++++++++++++++++--
 .../text_generation_server/models/globals.py  |  7 +++---
 5 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs
index 926de0fa..e4c3de26 100644
--- a/router/src/infer/v2/scheduler.rs
+++ b/router/src/infer/v2/scheduler.rs
@@ -39,10 +39,6 @@ impl SchedulerV2 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-<<<<<<< HEAD:router/src/infer/v2/scheduler.rs
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
-        let batching_task_notifier = Arc::new(Notify::new());
-=======
         // Infer shared state
         let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
             matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
         } else {
             false
         };
         let block_size = if flashdecoding { 256 } else { 16 };
-        let block_size = std::env::var("BLOCK_SIZE")
-            .map(|b| b.parse().unwrap_or(block_size))
-            .unwrap_or(block_size);
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
-        let shared = Arc::new(Shared {
-            batching_task: Notify::new(),
-        });
->>>>>>> Using flash decoding:router/src/infer.rs
+        let batching_task_notifier = Arc::new(Notify::new());
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index ad03dd83..543ce89f 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -39,9 +39,15 @@ impl SchedulerV3 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
+        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
+            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding { 256 } else { 16 };
         let queue = Queue::new(
             requires_padding,
-            16,
+            block_size,
             window_size,
             speculate,
             max_batch_total_tokens,
diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py
index df6b1ade..518f9abb 100644
--- a/server/text_generation_server/models/cache_manager.py
+++ b/server/text_generation_server/models/cache_manager.py
@@ -3,9 +3,8 @@ import torch
 
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 
-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index a3687d95..fa78ee22 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -30,6 +30,7 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
+    FLASH_DECODING,
     CUDA_GRAPHS,
     get_adapter_to_index,
     MODEL_ID,
@@ -46,7 +47,9 @@ from text_generation_server.utils.import_utils import (
 
 tracer = trace.get_tracer(__name__)
 
-BLOCK_SIZE: int = 256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
+BLOCK_SIZE: int = (
+    256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
+)
 
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
@@ -856,7 +859,23 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size
 
-        if SYSTEM == "ipex" and device == torch.device("cpu"):
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 11693436..06035ccd 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -5,10 +5,13 @@ from typing import Dict
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
-cuda_graphs = os.getenv("CUDA_GRAPHS")
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
     logger.info("Using FLASH_DECODING")
+
+
+cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
@@ -18,8 +21,6 @@ if cuda_graphs is not None:
     )
 else:
     cuda_graphs = None
-
-
 # sorting the cuda graphs in descending order helps reduce the
 # memory impact and results in less memory usage
 if cuda_graphs is not None:
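
For reference, below is a minimal standalone sketch of the behavior this patch wires up: BLOCK_SIZE is derived from the FLASH_DECODING environment variable (256 when enabled, 16 otherwise), and with flash decoding enabled each layer's KV cache is allocated with shape (num_blocks, BLOCK_SIZE, num_heads, head_size). The helper names resolve_block_size and allocate_kv_cache are illustrative assumptions for this sketch, not TGI APIs.

# Standalone sketch; not part of the patch. Helper names are hypothetical.
import os
from typing import List, Tuple

import torch


def resolve_block_size() -> int:
    # Same rule as globals.py and the Rust schedulers: 256 with flash decoding, else 16.
    flash_decoding = os.getenv("FLASH_DECODING", "").lower() in {"1", "true"}
    return 256 if flash_decoding else 16


def allocate_kv_cache(
    num_layers: int,
    num_blocks: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # One (key, value) tensor pair per layer, laid out the way the
    # FLASH_DECODING branch in FlashCausalLM allocates it in this patch.
    block_size = resolve_block_size()
    return [
        (
            torch.empty(
                (num_blocks, block_size, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),
            torch.empty(
                (num_blocks, block_size, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),
        )
        for _ in range(num_layers)
    ]


if __name__ == "__main__":
    os.environ["FLASH_DECODING"] = "1"
    kv_cache = allocate_kv_cache(2, 8, 4, 64, torch.float16, "cpu")
    # Prints: 256 torch.Size([8, 256, 4, 64])
    print(resolve_block_size(), kv_cache[0][0].shape)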