Mirror of https://github.com/huggingface/text-generation-inference.git
commit 4fd6e62655
parent 33e94379c8

    fix
@@ -66,7 +66,7 @@ class FlashCausalLMBatch(Batch):
     # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
     block_tables: Optional[List[List[int]]]
-    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: Optional[torch.Tensor]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: Optional[torch.Tensor]
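Not part of the diff: a minimal sketch of the padded tensor the updated comment describes, built from per-sequence block lists. The block values and the zero-padding choice below are invented for illustration.

import torch

# Per-sequence block tables: list of length b, each of length s_i // block_size (invented values)
block_tables = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]
max_blocks = max(len(bt) for bt in block_tables)

# Padded layout of shape [b, max_total_seqlen // block_size]: one row per sequence
block_tables_tensor = torch.zeros((len(block_tables), max_blocks), dtype=torch.int32)
for i, bt in enumerate(block_tables):
    block_tables_tensor[i, : len(bt)] = torch.tensor(bt, dtype=torch.int32)
print(block_tables_tensor)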
@@ -707,6 +707,21 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
         with torch.cuda.graph(graph, pool=MEM_POOL):
             self.cuda_graphs[bs]["logits"] = self.model.forward(
                 input_ids=input_ids,
@@ -719,6 +734,7 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+        torch.cuda.synchronize()

     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
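For context on these two hunks: the added lines follow the usual CUDA-graph recipe of running the decode forward once eagerly (so lazy initialization and kernel selection happen outside capture), synchronizing, and only then recording the call into the graph. A self-contained sketch of that recipe, using a stand-in linear layer rather than the server's model (requires a CUDA device; names here are illustrative only):

import torch

model = torch.nn.Linear(64, 64).cuda()            # stand-in for the decode forward
static_input = torch.randn(8, 64, device="cuda")  # static buffer reused by the graph

# 1) Run once eagerly as warmup, outside capture.
torch.cuda.synchronize()
with torch.no_grad():
    model(static_input)
torch.cuda.synchronize()

# 2) Capture the same call; tensors produced here become the graph's static outputs.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    with torch.no_grad():
        static_logits = model(static_input)
torch.cuda.synchronize()

# 3) Later: copy fresh data into the static input and replay the recorded kernels.
static_input.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()
print(static_logits[:2, :4])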
@@ -733,8 +749,8 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
-            max_s = batch.max_seqlen
+            max_bt = batch.max_blocks
+            max_s = max_bt * get_cache_manager().block_size
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -758,8 +774,8 @@ class FlashCausalLM(Model):
         )

         num_blocks = (
-            # Leave 1% for some wiggle room
-            int((free_memory * 0.99) // total_cache_size)
+            # Leave 5% for some wiggle room
+            int((free_memory * 0.95) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + cache_manager.num_blocks
         )
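A back-of-the-envelope sketch of the arithmetic in this hunk: the KV cache is sized from whatever memory is still free after warmup, minus the safety margin this commit widens from 1% to 5%. The model shape, dtype size, and free-memory figure below are assumptions for illustration, not values from the repository.

BLOCK_SIZE = 16     # tokens per paged-attention block (assumed)
num_layers = 32     # assumed model shape
num_kv_heads = 8
head_size = 128
dtype_size = 2      # fp16/bf16 bytes per element

# Bytes needed for one block across all layers (K and V => factor 2).
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size

free_memory = 40 * 1024**3       # pretend 40 GiB remain free after the warmup pass
already_allocated_blocks = 1024  # stand-in for cache_manager.num_blocks

num_blocks = (
    # Leave 5% for some wiggle room
    int((free_memory * 0.95) // total_cache_size)
    # Blocks allocated for the warmup batch are part of the peak, so add them back.
    + already_allocated_blocks
)
print(num_blocks, "blocks ->", num_blocks * BLOCK_SIZE, "cacheable tokens")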
@@ -780,9 +796,8 @@ class FlashCausalLM(Model):
         if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
-                # Warmup cuda graphs for all power of twos until 64
-                for i in range(6):
-                    bs = 2**i
+                # Warmup cuda graphs
+                for bs in [1, 2, 4] + [8 * i for i in range(8)]:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
@@ -790,7 +805,7 @@ class FlashCausalLM(Model):

         return int(num_blocks * BLOCK_SIZE)

-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -843,10 +858,16 @@ class FlashCausalLM(Model):
             lm_head_indices = batch.prefill_head_indices

         bs = input_ids.shape[0]
-        # Ceil next power of two for batch size
-        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)

         if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
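The new padding rule from this hunk, pulled out as a standalone helper to show which captured graph a given batch size will select; the helper name and the sample values are ours, not the repository's.

def pad_batch_size(bs: int) -> int:
    # Same rule as the diff: 1 and 2 are used as-is, 3 snaps to 4,
    # 4-8 snap to 8, anything larger rounds up to the next multiple of 8.
    padded_bs = bs
    if bs == 3:
        padded_bs = 4
    elif 3 < bs <= 8:
        padded_bs = 8
    elif bs > 8:
        padded_bs = (bs + 7) // 8 * 8
    return padded_bs

# (1, 1), (2, 2), (3, 4), (5, 8), (9, 16), (17, 24), (57, 64)
print([(bs, pad_batch_size(bs)) for bs in (1, 2, 3, 5, 9, 17, 57)])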
@@ -868,7 +889,9 @@ class FlashCausalLM(Model):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

         # Replay the graph
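Why the two added fill_(-1) / zero_() lines matter on the replay path: the graph's static buffers are sized for the padded batch, so the padding rows must not carry stale values when the live batch is copied into the front of each buffer. A CPU-only sketch with invented shapes and values:

import torch

padded_bs = 8  # graph captured for 8 sequences; the current batch only has 3

# Static buffers owned by the captured graph (illustrative contents).
cuda_graph = {
    "slots": torch.zeros(padded_bs, dtype=torch.int64),
    "input_lengths": torch.full((padded_bs,), 7, dtype=torch.int32),  # stale values
}

slots = torch.tensor([12, 13, 14], dtype=torch.int64)
input_lengths = torch.tensor([5, 9, 2], dtype=torch.int32)

# Reset the padding entries first, then copy the live batch in.
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

print(cuda_graph["slots"])          # tensor([12, 13, 14, -1, -1, -1, -1, -1])
print(cuda_graph["input_lengths"])  # tensor([5, 9, 2, 0, 0, 0, 0, 0], dtype=torch.int32)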
@@ -377,6 +377,22 @@ class BaseFlashMistral(FlashCausalLM):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
         with torch.cuda.graph(graph, pool=MEM_POOL):
             self.cuda_graphs[bs]["logits"] = self.model.forward(
                 input_ids=input_ids,
@@ -390,6 +406,7 @@ class BaseFlashMistral(FlashCausalLM):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+        torch.cuda.synchronize()

     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
@@ -447,10 +464,16 @@ class BaseFlashMistral(FlashCausalLM):
             max_s = min(self.model.max_past, max_s)

         bs = input_ids.shape[0]
-        # Ceil next power of two for batch size
-        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)

         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
@@ -476,7 +499,9 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

         # Replay the graph