mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

Rebased.

parent 988aa34f3d
commit 4f1b1a277c
@@ -39,10 +39,6 @@ impl SchedulerV2 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-<<<<<<< HEAD:router/src/infer/v2/scheduler.rs
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
-        let batching_task_notifier = Arc::new(Notify::new());
-=======
         // Infer shared state
         let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
             matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")

@@ -50,14 +46,8 @@ impl SchedulerV2 {
             false
         };
         let block_size = if flashdecoding { 256 } else { 16 };
-        let block_size = std::env::var("BLOCK_SIZE")
-            .map(|b| b.parse().unwrap_or(block_size))
-            .unwrap_or(block_size);
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
-        let shared = Arc::new(Shared {
-            batching_task: Notify::new(),
-        });
->>>>>>> Using flash decoding:router/src/infer.rs
+        let batching_task_notifier = Arc::new(Notify::new());

         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
@@ -39,9 +39,15 @@ impl SchedulerV3 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
+        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
+            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding { 256 } else { 16 };
         let queue = Queue::new(
             requires_padding,
-            16,
+            block_size,
             window_size,
             speculate,
             max_batch_total_tokens,
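Note (illustrative, not part of the diff): the router hunks above replace the hard-coded block size of 16 with a value derived from the FLASH_DECODING environment variable (256 when enabled). A minimal Python sketch of the resulting block accounting; blocks_needed is a hypothetical helper, not TGI code:

# Illustrative arithmetic only, not TGI code: how the block size picked by the
# scheduler changes the number of fixed-size KV-cache blocks a sequence occupies.
import math

def blocks_needed(seq_len: int, block_size: int) -> int:
    """Number of fixed-size KV-cache blocks required to hold seq_len tokens."""
    return math.ceil(seq_len / block_size)

if __name__ == "__main__":
    for seq_len in (128, 1000, 4096):
        print(
            f"{seq_len:>5} tokens -> "
            f"{blocks_needed(seq_len, 16):>4} blocks at block_size=16, "
            f"{blocks_needed(seq_len, 256):>3} blocks at block_size=256"
        )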
@@ -3,9 +3,8 @@ import torch

 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE

-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
@@ -30,6 +30,7 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
+    FLASH_DECODING,
     CUDA_GRAPHS,
     get_adapter_to_index,
     MODEL_ID,
@@ -46,7 +47,9 @@ from text_generation_server.utils.import_utils import (

 tracer = trace.get_tracer(__name__)

-BLOCK_SIZE: int = 256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
+BLOCK_SIZE: int = (
+    256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
+)

 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
@@ -856,7 +859,23 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size

-        if SYSTEM == "ipex" and device == torch.device("cpu"):
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
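Note (illustrative, not part of the diff): with FLASH_DECODING set, the hunk above allocates the KV cache as one (num_blocks, BLOCK_SIZE, num_heads, head_size) key tensor and one value tensor per layer. A self-contained sketch of that layout; allocate_kv_cache is a hypothetical helper for illustration, not TGI's API:

# Minimal sketch of the flash-decoding KV-cache layout from the hunk above.
from typing import List, Tuple
import torch

def allocate_kv_cache(
    num_layers: int,
    num_blocks: int,
    block_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype = torch.float16,
    device: torch.device = torch.device("cpu"),
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # One (key, value) pair per layer; each tensor holds num_blocks pages of
    # block_size tokens, i.e. shape (num_blocks, BLOCK_SIZE, num_heads, head_size).
    return [
        (
            torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
            torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]

if __name__ == "__main__":
    cache = allocate_kv_cache(num_layers=2, num_blocks=8, block_size=256, num_heads=4, head_size=64)
    print(len(cache), cache[0][0].shape)  # 2 torch.Size([8, 256, 4, 64])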
@@ -5,10 +5,13 @@ from typing import Dict

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
-cuda_graphs = os.getenv("CUDA_GRAPHS")
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
     logger.info("Using FLASH_DECODING")
+
+
+cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
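Note (illustrative, not part of the diff): the globals hunk above reads FLASH_DECODING from the environment once and derives BLOCK_SIZE from it. A minimal standalone sketch of the same toggle; only the two assignments mirror the hunk, the demo block is illustrative:

# Minimal sketch of the env-var toggle introduced above.
import os

FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16

if __name__ == "__main__":
    # e.g. running with FLASH_DECODING=1 prints: True 256
    print(FLASH_DECODING, BLOCK_SIZE)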
@@ -18,8 +21,6 @@ if cuda_graphs is not None:
         )
 else:
     cuda_graphs = None
-
-
 # sorting the cuda graphs in descending order helps reduce the
 # memory impact and results in less memory usage
 if cuda_graphs is not None: