OlivierDehaene 2024-01-15 18:24:22 +01:00 committed by Nicolas Patry
parent 33e94379c8
commit 4fd6e62655
2 changed files with 62 additions and 14 deletions


@@ -66,7 +66,7 @@ class FlashCausalLMBatch(Batch):
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
block_tables: Optional[List[List[int]]]
- # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+ # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: Optional[torch.Tensor]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: Optional[torch.Tensor]
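
A note on the sizing in the updated comment: each sequence needs one block-table entry per block_size tokens of its total length (prompt plus generated tokens), so the tensor's second dimension follows the longest total sequence. A minimal sketch of that arithmetic, assuming an illustrative block size of 16 and made-up sequence lengths (the real block size comes from the cache manager, as in get_cache_manager().block_size below):

import math

block_size = 16                                # illustrative; TGI reads this from the cache manager
total_seq_lens = [5, 40, 130]                  # hypothetical prompt + generated lengths for b = 3 sequences

blocks_per_seq = [math.ceil(s / block_size) for s in total_seq_lens]   # [1, 3, 9]
max_total_seqlen = max(total_seq_lens)
cols = math.ceil(max_total_seqlen / block_size)                        # 9

# block_tables_tensor is then [b, cols] = [3, 9], padded for the shorter rows,
# while slots holds one entry per token, i.e. sum(total_seq_lens) values.
print(blocks_per_seq, cols)
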
@@ -707,6 +707,21 @@ class FlashCausalLM(Model):
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
torch.cuda.synchronize()
+ # Run once outside to warmup
+ self.model.forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=None,
+ kv_cache=kv_cache,
+ block_tables=block_tables,
+ slots=slots,
+ input_lengths=input_lengths,
+ max_s=max_s,
+ lm_head_indices=None,
+ )
+ torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward(
input_ids=input_ids,
@@ -719,6 +734,7 @@ class FlashCausalLM(Model):
max_s=max_s,
lm_head_indices=None,
)
+ torch.cuda.synchronize()
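
The lines added above follow the usual PyTorch CUDA-graph recipe: run the model eagerly once so lazy initialization (workspace allocation, kernel autotuning) happens outside the capture, synchronize, capture the decode forward into a graph, then synchronize again. A self-contained sketch of that pattern with a toy module (not the TGI model) looks like this:

import torch

model = torch.nn.Linear(64, 64).cuda()
static_input = torch.randn(8, 64, device="cuda")

# 1) Eager warmup run so one-off setup work is not recorded into the graph.
model(static_input)
torch.cuda.synchronize()

# 2) Capture: kernels launched inside the context are recorded; the output
#    tensor stays bound to graph-owned memory.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)
torch.cuda.synchronize()

# 3) Replay: refresh the static input in place, replay the recorded kernels,
#    then read the static output.
static_input.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()
print(static_output[0, :4])
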
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
@@ -733,8 +749,8 @@ class FlashCausalLM(Model):
self.dtype,
self.device,
)
- max_s = batch.max_seqlen
max_bt = batch.max_blocks
+ max_s = max_bt * get_cache_manager().block_size
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
@@ -758,8 +774,8 @@ class FlashCausalLM(Model):
)
num_blocks = (
- # Leave 1% for some wiggle room
- int((free_memory * 0.99) // total_cache_size)
+ # Leave 5% for some wiggle room
+ int((free_memory * 0.95) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ cache_manager.num_blocks
)
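
With the headroom bumped from 1% to 5%, the block budget works out as in the following sketch. The numbers are made up purely to show the arithmetic, and the variable names only mirror the diff:

free_memory = 20 * 1024**3        # hypothetical: 20 GiB still free after the warmup pass
total_cache_size = 2 * 1024**2    # hypothetical: bytes of KV cache needed per block
warmup_blocks = 512               # hypothetical: cache_manager.num_blocks used by the warmup batch

num_blocks = int((free_memory * 0.95) // total_cache_size) + warmup_blocks
# 95% of 20 GiB // 2 MiB = 9728 blocks, plus the 512 warmup blocks that were
# already allocated (and therefore counted in the measured peak) = 10240 total.
print(num_blocks)
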
@@ -780,9 +796,8 @@ class FlashCausalLM(Model):
if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
try:
logger.info("Experimental support for Cuda Graphs is enabled")
- # Warmup cuda graphs for all power of twos until 64
- for i in range(6):
- bs = 2**i
+ # Warmup cuda graphs
+ for bs in [1, 2, 4] + [8 * i for i in range(8)]:
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except Exception:
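
For reference, the list expression in the new warmup loop evaluates to the following decode batch sizes (before the speculate guard is applied):

sizes = [1, 2, 4] + [8 * i for i in range(8)]
print(sizes)               # [1, 2, 4, 0, 8, 16, 24, 32, 40, 48, 56]
print(sorted(set(sizes)))  # [0, 1, 2, 4, 8, 16, 24, 32, 40, 48, 56]
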
@@ -790,7 +805,7 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE)
- def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
@@ -843,10 +858,16 @@ class FlashCausalLM(Model):
lm_head_indices = batch.prefill_head_indices
bs = input_ids.shape[0]
- # Ceil next power of two for batch size
- bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+ padded_bs = bs
+ if bs == 3:
+ padded_bs = 4
+ elif 3 < bs <= 8:
+ padded_bs = 8
+ elif bs > 8:
+ padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
- cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+ cuda_graph = self.cuda_graphs.get(padded_bs, None)
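
The padding rule above replaces the old next-power-of-two rounding, so decode batch sizes land on the graph sizes captured during warmup (1, 2, 4, then multiples of 8). A standalone sketch of the same bucketing, for illustration only, with a few sample inputs:

def pad_batch_size(bs: int) -> int:
    # Mirrors the padded_bs logic in the diff.
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8   # round up to the next multiple of 8
    return bs                      # bs of 1 or 2 keeps its own captured graph

for bs in [1, 2, 3, 5, 8, 9, 17, 33]:
    print(bs, "->", pad_batch_size(bs))
# 1 -> 1, 2 -> 2, 3 -> 4, 5 -> 8, 8 -> 8, 9 -> 16, 17 -> 24, 33 -> 40
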
if cu_seqlen_prefill is not None or cuda_graph is None:
return self.model.forward(
@@ -868,7 +889,9 @@ class FlashCausalLM(Model):
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
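
Before replay, the real (unpadded) batch is copied into the front of the graph's static buffers; the fill_(-1) and zero_() calls reset the whole buffer first so the padded tail does not carry values left over from an earlier, larger batch. A sketch of that copy step with hypothetical shapes and dtypes (names mirror the cuda_graph dict, but these tensors are stand-ins):

import torch

padded_bs, bs = 8, 5                                   # graph captured at 8, real batch of 5
static_slots = torch.empty(padded_bs, dtype=torch.int64, device="cuda")
static_input_lengths = torch.empty(padded_bs, dtype=torch.int32, device="cuda")

slots = torch.arange(100, 100 + bs, dtype=torch.int64, device="cuda")    # real slot indices
input_lengths = torch.full((bs,), 12, dtype=torch.int32, device="cuda")  # real lengths

static_slots.fill_(-1)                  # padded rows get the -1 sentinel, as in the diff
static_slots[:bs] = slots
static_input_lengths.zero_()            # padded rows report zero tokens
static_input_lengths[:bs] = input_lengths
# graph.replay() would then run the captured kernels on these refreshed buffers.
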


@@ -377,6 +377,22 @@ class BaseFlashMistral(FlashCausalLM):
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
torch.cuda.synchronize()
+ # Run once outside to warmup
+ self.model.forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=None,
+ kv_cache=kv_cache,
+ block_tables=block_tables,
+ slots=slots,
+ input_lengths=input_lengths,
+ max_s=max_s,
+ prefill_cache_indices=None,
+ lm_head_indices=None,
+ )
+ torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward(
input_ids=input_ids,
@@ -390,6 +406,7 @@ class BaseFlashMistral(FlashCausalLM):
prefill_cache_indices=None,
lm_head_indices=None,
)
+ torch.cuda.synchronize()
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
@@ -447,10 +464,16 @@ class BaseFlashMistral(FlashCausalLM):
max_s = min(self.model.max_past, max_s)
bs = input_ids.shape[0]
- # Ceil next power of two for batch size
- bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+ padded_bs = bs
+ if bs == 3:
+ padded_bs = 4
+ elif 3 < bs <= 8:
+ padded_bs = 8
+ elif bs > 8:
+ padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
- cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+ cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward(
@@ -476,7 +499,9 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph