Nicolas Patry 2024-06-25 12:20:50 +00:00
parent 988aa34f3d
commit 4f1b1a277c
5 changed files with 34 additions and 19 deletions


@@ -39,10 +39,6 @@ impl SchedulerV2 {
        speculate: u32,
        generation_health: Arc<AtomicBool>,
    ) -> Self {
<<<<<<< HEAD:router/src/infer/v2/scheduler.rs
        let queue = Queue::new(requires_padding, 16, window_size, speculate);
        let batching_task_notifier = Arc::new(Notify::new());
=======
        // Infer shared state
        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
@@ -50,14 +46,8 @@ impl SchedulerV2 {
            false
        };
        let block_size = if flashdecoding { 256 } else { 16 };
        let block_size = std::env::var("BLOCK_SIZE")
            .map(|b| b.parse().unwrap_or(block_size))
            .unwrap_or(block_size);
        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });
>>>>>>> Using flash decoding:router/src/infer.rs
        let batching_task_notifier = Arc::new(Notify::new());
        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(

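The hunk above is from the v2 scheduler (the conflict marker names router/src/infer/v2/scheduler.rs): it removes a leftover merge-conflict block, keeping the fixed block size of 16 and the batching_task_notifier. The discarded branch also let a BLOCK_SIZE environment variable override the flash-decoding default, falling back silently when the variable was missing or unparsable. A rough Python sketch of that override pattern (names are illustrative, not TGI's API):

import os

def resolve_block_size(flash_decoding: bool) -> int:
    # Hypothetical helper mirroring the discarded Rust branch: start from the
    # flash-decoding default (256 vs 16), then let BLOCK_SIZE override it,
    # keeping the default if the variable is missing or not an integer.
    block_size = 256 if flash_decoding else 16
    raw = os.getenv("BLOCK_SIZE")
    if raw is not None:
        try:
            block_size = int(raw)
        except ValueError:
            pass  # analogous to parse().unwrap_or(block_size) in the Rust code
    return block_size

print(resolve_block_size(flash_decoding=False))  # 16 unless BLOCK_SIZE is set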

@@ -39,9 +39,15 @@ impl SchedulerV3 {
        speculate: u32,
        generation_health: Arc<AtomicBool>,
    ) -> Self {
        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
        } else {
            false
        };
        let block_size = if flashdecoding { 256 } else { 16 };
        let queue = Queue::new(
            requires_padding,
            16,
            block_size,
            window_size,
            speculate,
            max_batch_total_tokens,

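In the v3 scheduler the block size is no longer hard-coded: Queue::new now receives a block_size of 256 when FLASH_DECODING is set to "1" or "true" (case-insensitive) and 16 otherwise. The block size sets the granularity at which KV-cache slots are reserved per request; a rough sketch of the resulting arithmetic, assuming simple ceil-division budgeting (the helper below is hypothetical, not TGI's API):

import math

def blocks_needed(prompt_tokens: int, max_new_tokens: int, block_size: int) -> int:
    # Hypothetical helper: how many fixed-size KV-cache blocks a request
    # reserves when its tokens are grouped into blocks of block_size slots.
    total_tokens = prompt_tokens + max_new_tokens
    return math.ceil(total_tokens / block_size)

# Default paged layout (block_size=16) vs. FLASH_DECODING (block_size=256):
print(blocks_needed(1000, 200, 16))   # 75 blocks of 16 tokens each
print(blocks_needed(1000, 200, 256))  # 5 blocks of 256 tokens each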

@@ -3,9 +3,8 @@ import torch
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None

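This hunk, which appears to be the server's cache-manager module, stops deriving its own BLOCK_SIZE and imports it from text_generation_server.models.globals together with FLASH_DECODING, so the flash-decoding decision is made in exactly one place. A minimal sketch of the pattern (module layout is illustrative):

import os

# Minimal sketch of the single-source-of-truth pattern: the constant is derived
# from the environment once, at import time, in one module.
FLASH_DECODING: bool = os.getenv("FLASH_DECODING", "").lower() in {"1", "true"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16

# Consumers then import the constant instead of repeating the expression:
#     from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
# so every module in the process agrees on the same block size.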

@@ -30,6 +30,7 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
    MEM_POOL,
    FLASH_DECODING,
    CUDA_GRAPHS,
    get_adapter_to_index,
    MODEL_ID,
@@ -46,7 +47,9 @@ from text_generation_server.utils.import_utils import (
tracer = trace.get_tracer(__name__)
BLOCK_SIZE: int = 256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
BLOCK_SIZE: int = (
    256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
@@ -856,7 +859,23 @@ class FlashCausalLM(Model):
        else:
            x = BLOCK_SIZE // element_size
        if SYSTEM == "ipex" and device == torch.device("cpu"):
        if FLASH_DECODING:
            self.kv_cache = [
                (
                    torch.empty(
                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                    torch.empty(
                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                )
                for _ in range(num_layers)
            ]
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            self.kv_cache = [
                (
                    torch.empty(

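With FLASH_DECODING enabled, FlashCausalLM allocates the KV cache as one (key, value) pair of tensors per layer, each shaped (num_blocks, BLOCK_SIZE, num_heads, head_size), rather than the element-packed layout the default path derives from x = BLOCK_SIZE // element_size. A small standalone sketch of the flash-decoding allocation, with illustrative sizes:

import torch

# Sketch of the FLASH_DECODING cache layout shown above: one (key, value) pair
# per layer, each of shape (num_blocks, BLOCK_SIZE, num_heads, head_size).
# All sizes below are illustrative.
num_layers, num_blocks, block_size, num_heads, head_size = 2, 8, 256, 4, 64
dtype, device = torch.float16, torch.device("cpu")

kv_cache = [
    (
        torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
        torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
    )
    for _ in range(num_layers)
]

# Each block holds block_size token positions, so this cache can serve
# num_blocks * block_size = 2048 token slots per layer.
print(kv_cache[0][0].shape)  # torch.Size([8, 256, 4, 64])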

@@ -5,10 +5,13 @@ from typing import Dict
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
    logger.info("Using FLASH_DECODING")
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
@@ -18,8 +21,6 @@ if cuda_graphs is not None:
        )
else:
    cuda_graphs = None
# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None:
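
The globals module now owns both the FLASH_DECODING flag and the derived BLOCK_SIZE, logs once when flash decoding is active, and keeps the existing CUDA_GRAPHS parsing. A sketch of how a comma-separated CUDA_GRAPHS value becomes the list of captured batch sizes (the fallback value below is only an example, not a documented default):

import os

# Sketch of the CUDA_GRAPHS parsing shown above: a comma-separated list of
# batch sizes, parsed into integers, or None if parsing fails.
raw = os.getenv("CUDA_GRAPHS", "1,2,4,8")
try:
    cuda_graph_batch_sizes = [int(item) for item in raw.split(",")]
except ValueError:
    cuda_graph_batch_sizes = None

# As the comment in the hunk notes, sorting in descending order helps reduce
# the memory impact of capturing the graphs.
if cuda_graph_batch_sizes is not None:
    cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes, reverse=True)

print(cuda_graph_batch_sizes)  # e.g. [8, 4, 2, 1]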