mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Rebased.
This commit is contained in:
parent
988aa34f3d
commit
4f1b1a277c
@ -39,10 +39,6 @@ impl SchedulerV2 {
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
<<<<<<< HEAD:router/src/infer/v2/scheduler.rs
|
||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
=======
|
||||
// Infer shared state
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
@ -50,14 +46,8 @@ impl SchedulerV2 {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let block_size = std::env::var("BLOCK_SIZE")
|
||||
.map(|b| b.parse().unwrap_or(block_size))
|
||||
.unwrap_or(block_size);
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
>>>>>>> Using flash decoding:router/src/infer.rs
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
|
@ -39,9 +39,15 @@ impl SchedulerV3 {
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
@ -3,9 +3,8 @@ import torch
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
|
||||
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
# Will be set in warmup
|
||||
CACHE_MANAGER: Optional["CacheManager"] = None
|
||||
|
||||
|
@ -30,6 +30,7 @@ from text_generation_server.models.types import (
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.globals import (
|
||||
MEM_POOL,
|
||||
FLASH_DECODING,
|
||||
CUDA_GRAPHS,
|
||||
get_adapter_to_index,
|
||||
MODEL_ID,
|
||||
@ -46,7 +47,9 @@ from text_generation_server.utils.import_utils import (
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
BLOCK_SIZE: int = 256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
|
||||
BLOCK_SIZE: int = (
|
||||
256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
|
||||
)
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
@ -856,7 +859,23 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
if FLASH_DECODING:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
|
@ -5,10 +5,13 @@ from typing import Dict
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
if FLASH_DECODING:
|
||||
logger.info("Using FLASH_DECODING")
|
||||
|
||||
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None:
|
||||
try:
|
||||
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
|
||||
@ -18,8 +21,6 @@ if cuda_graphs is not None:
|
||||
)
|
||||
else:
|
||||
cuda_graphs = None
|
||||
|
||||
|
||||
# sorting the cuda graphs in descending order helps reduce the
|
||||
# memory impact and results in less memory usage
|
||||
if cuda_graphs is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user