Nicolas Patry 2024-06-25 12:20:50 +00:00
parent 988aa34f3d
commit 4f1b1a277c
5 changed files with 34 additions and 19 deletions

View File

@@ -39,10 +39,6 @@ impl SchedulerV2 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-<<<<<<< HEAD:router/src/infer/v2/scheduler.rs
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
-        let batching_task_notifier = Arc::new(Notify::new());
-=======
         // Infer shared state
         let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
             matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
@@ -50,14 +46,8 @@ impl SchedulerV2 {
             false
         };
         let block_size = if flashdecoding { 256 } else { 16 };
-        let block_size = std::env::var("BLOCK_SIZE")
-            .map(|b| b.parse().unwrap_or(block_size))
-            .unwrap_or(block_size);
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
-        let shared = Arc::new(Shared {
-            batching_task: Notify::new(),
-        });
->>>>>>> Using flash decoding:router/src/infer.rs
+        let batching_task_notifier = Arc::new(Notify::new());
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(

View File

@@ -39,9 +39,15 @@ impl SchedulerV3 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
+        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
+            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding { 256 } else { 16 };
         let queue = Queue::new(
             requires_padding,
-            16,
+            block_size,
             window_size,
             speculate,
             max_batch_total_tokens,

View File

@@ -3,9 +3,8 @@ import torch
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
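With this change, BLOCK_SIZE is defined once in text_generation_server.models.globals and imported here alongside FLASH_DECODING instead of being recomputed per module. A minimal sketch of the intended consumer pattern; the blocks_needed helper is illustrative and not part of the diff:

```python
from text_generation_server.models.globals import BLOCK_SIZE, FLASH_DECODING


def blocks_needed(num_tokens: int) -> int:
    # Illustrative helper (not in the diff): how many KV-cache blocks a
    # sequence of `num_tokens` tokens occupies at the configured block
    # size (256 when FLASH_DECODING is enabled, 16 otherwise).
    return (num_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE
```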

View File

@@ -30,6 +30,7 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
+    FLASH_DECODING,
     CUDA_GRAPHS,
     get_adapter_to_index,
     MODEL_ID,
@@ -46,7 +47,9 @@ from text_generation_server.utils.import_utils import (
 tracer = trace.get_tracer(__name__)
-BLOCK_SIZE: int = 256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
+BLOCK_SIZE: int = (
+    256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
+)
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
@@ -856,7 +859,23 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size
-        if SYSTEM == "ipex" and device == torch.device("cpu"):
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
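The new FLASH_DECODING branch above allocates the KV cache as one (key, value) pair per layer, each tensor shaped (num_blocks, BLOCK_SIZE, num_heads, head_size). A standalone sketch of that allocation; the function name and the dtype/device defaults are illustrative, not part of the diff:

```python
from typing import List, Tuple

import torch


def allocate_flash_decoding_kv_cache(
    num_layers: int,
    num_blocks: int,
    block_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype = torch.float16,
    device: torch.device = torch.device("cuda"),
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # One (key, value) pair per transformer layer; both tensors use the
    # (num_blocks, block_size, num_heads, head_size) layout from the diff.
    shape = (num_blocks, block_size, num_heads, head_size)
    return [
        (
            torch.empty(shape, dtype=dtype, device=device),
            torch.empty(shape, dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]
```

Under FLASH_DECODING the block size is 256 instead of 16, so each block covers 16x more tokens.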

View File

@@ -5,10 +5,13 @@ from typing import Dict
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
-cuda_graphs = os.getenv("CUDA_GRAPHS")
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
     logger.info("Using FLASH_DECODING")
+cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
@@ -18,8 +21,6 @@ if cuda_graphs is not None:
         )
 else:
     cuda_graphs = None
 # sorting the cuda graphs in descending order helps reduce the
 # memory impact and results in less memory usage
 if cuda_graphs is not None:
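Because FLASH_DECODING, BLOCK_SIZE, and cuda_graphs are all evaluated at import time of text_generation_server.models.globals, the environment has to be set before that module is imported. A small usage sketch under that assumption; the example values are illustrative:

```python
import os

# Enable flash decoding before the globals module is imported; these
# variables are read once at import time.
os.environ["FLASH_DECODING"] = "1"     # any of "1", "true", "True" enables it
os.environ["CUDA_GRAPHS"] = "1,2,4,8"  # comma-separated ints; normally set by the cli

from text_generation_server.models.globals import BLOCK_SIZE, FLASH_DECODING

assert FLASH_DECODING is True
assert BLOCK_SIZE == 256  # falls back to 16 when FLASH_DECODING is unset
```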