Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Use a block size of 1 for FlashInfer
commit 4562c16048 (parent 8fb8e1da78)
@@ -40,7 +40,18 @@ impl BackendV3 {
         } else {
             false
         };
-        let block_size = if flashdecoding { 256 } else { 16 };
+        let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
+            matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding {
+            256
+        } else if flashinfer {
+            1
+        } else {
+            16
+        };
 
         let queue = Queue::new(
             requires_padding,
@@ -45,7 +45,19 @@ impl BackendV2 {
         } else {
             false
         };
-        let block_size = if flashdecoding { 256 } else { 16 };
+        let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
+            matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding {
+            256
+        } else if flashinfer {
+            1
+        } else {
+            16
+        };
+
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
 
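Both backend hunks (BackendV3 and BackendV2) add the same selection logic: FLASH_INFER is read from the environment and, when set, the block size drops to 1, while flash-decoding keeps 256 and the default stays 16. Below is a minimal standalone sketch of that logic; the helper names env_flag and block_size are illustrative only, not the repository's API.

// Sketch of the flag parsing and block-size priority introduced above.
// Helper names are hypothetical; values (256 / 1 / 16) come from the diff.
use std::env;

/// Interpret an environment variable as a boolean flag ("1" or "true",
/// case-insensitive), mirroring the `matches!` check in the diff.
fn env_flag(name: &str) -> bool {
    match env::var(name) {
        Ok(value) => matches!(value.to_lowercase().as_str(), "1" | "true"),
        Err(_) => false,
    }
}

/// Pick the block size with the same priority as the diff:
/// flash-decoding first, then FlashInfer, then the default.
fn block_size(flashdecoding: bool, flashinfer: bool) -> u32 {
    if flashdecoding {
        256
    } else if flashinfer {
        1
    } else {
        16
    }
}

fn main() {
    // The three outcomes encoded in the if/else chain.
    assert_eq!(block_size(true, false), 256);
    assert_eq!(block_size(false, true), 1);
    assert_eq!(block_size(false, false), 16);

    // Reading the flags from the environment, as the backends do.
    let flashdecoding = env_flag("FLASH_DECODING");
    let flashinfer = env_flag("FLASH_INFER");
    println!("block_size = {}", block_size(flashdecoding, flashinfer));
}

With this ordering, FLASH_INFER=1 yields a block size of 1 unless FLASH_DECODING is also set, in which case flash-decoding's 256 takes precedence, matching the if/else chain in the hunks above.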
@@ -5,16 +5,22 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+# This is overridden by the cli
+FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+if FLASH_DECODING:
+    log_master(logger.info, "Using FLASH_DECODING")
+
 FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
 if FLASH_INFER:
     log_master(logger.info, "Using FLASH_INFER")
 
-MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-# This is overridden by the cli
-FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
-    log_master(logger.info, "Using FLASH_DECODING")
+    BLOCK_SIZE = 256
+elif FLASH_INFER:
+    BLOCK_SIZE = 1
+else:
+    BLOCK_SIZE = 16
 
 
 cuda_graphs = os.getenv("CUDA_GRAPHS")