Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Commit 4fd6e62655 ("fix")
Parent: 33e94379c8
@@ -66,7 +66,7 @@ class FlashCausalLMBatch(Batch):
     # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
     block_tables: Optional[List[List[int]]]
-    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: Optional[torch.Tensor]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: Optional[torch.Tensor]
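Note: the comment fix above only corrects the documented shape; per-sequence block lists are padded to the longest sequence's block count, i.e. max_total_seqlen // block_size columns. A minimal sketch of how such a padded block-table tensor could be built from the per-sequence lists; build_block_tables and the block size below are illustrative assumptions, not TGI code.

import torch
from typing import List

BLOCK_SIZE = 16  # assumed paged-attention block size, for illustration only

def build_block_tables(block_tables: List[List[int]]) -> torch.Tensor:
    # Pad every sequence's block list to the longest one, giving a
    # [b, max_total_seqlen // block_size] tensor as the comment describes.
    max_blocks = max(len(blocks) for blocks in block_tables)
    out = torch.zeros(len(block_tables), max_blocks, dtype=torch.int32)
    for i, blocks in enumerate(block_tables):
        out[i, : len(blocks)] = torch.tensor(blocks, dtype=torch.int32)
    return out

# Example: two sequences owning 2 and 3 cache blocks respectively.
print(build_block_tables([[0, 1], [2, 3, 4]]).shape)  # torch.Size([2, 3])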
@@ -707,6 +707,21 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
         with torch.cuda.graph(graph, pool=MEM_POOL):
             self.cuda_graphs[bs]["logits"] = self.model.forward(
                 input_ids=input_ids,
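Note: the added code runs one eager forward pass, bracketed by torch.cuda.synchronize(), before capturing the graph, so one-time setup work (kernel selection, handle creation, allocator growth) happens outside the capture. A self-contained sketch of the same warmup-then-capture-then-replay pattern on a toy module; the Linear model and shapes are placeholders, not TGI's decode graph.

import torch

model = torch.nn.Linear(64, 64).cuda()
static_input = torch.randn(8, 64, device="cuda")

# 1) Warm up eagerly so one-time setup work is not recorded in the graph.
torch.cuda.synchronize()
static_output = model(static_input)
torch.cuda.synchronize()

# 2) Capture: every kernel launched inside this block is recorded.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# 3) Replay: copy fresh data into the static input buffer, then replay.
static_input.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()
print(static_output.shape)  # torch.Size([8, 64])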
@@ -719,6 +734,7 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+        torch.cuda.synchronize()
 
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
@@ -733,8 +749,8 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
-            max_s = batch.max_seqlen
             max_bt = batch.max_blocks
+            max_s = max_bt * get_cache_manager().block_size
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
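Note: max_s is now derived from the number of allocated cache blocks rather than from batch.max_seqlen, so the warmup is sized for the full block-aligned capacity. A tiny arithmetic sketch with made-up values, assuming a block size of 16:

BLOCK_SIZE = 16          # assumed cache block size
max_bt = 70              # hypothetical batch.max_blocks after warmup allocation

# Old: bounded by the warmup batch's longest prompt.
max_s_old = 1105         # hypothetical batch.max_seqlen

# New: bounded by the block-aligned capacity actually reserved.
max_s_new = max_bt * BLOCK_SIZE
print(max_s_new)         # 1120, i.e. 70 blocks of 16 slots each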
@@ -758,8 +774,8 @@ class FlashCausalLM(Model):
         )
 
         num_blocks = (
-            # Leave 1% for some wiggle room
-            int((free_memory * 0.99) // total_cache_size)
+            # Leave 5% for some wiggle room
+            int((free_memory * 0.95) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + cache_manager.num_blocks
         )
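Note: the wiggle room grows from 1% to 5% of the measured free memory. A worked example of the block-budget formula with made-up numbers (free memory, per-block cache size, and warmup block count are all hypothetical):

# Hypothetical values, only to illustrate the formula.
free_memory = 10 * 1000**3        # 10 GB reported free after warmup
total_cache_size = 2 * 1024**2    # bytes of KV cache per block (assumed)
warmup_blocks = 1024              # stand-in for cache_manager.num_blocks

num_blocks = (
    # Leave 5% for some wiggle room
    int((free_memory * 0.95) // total_cache_size)
    # Add the warmup allocation back since it is part of the measured peak.
    + warmup_blocks
)
print(num_blocks)  # 5553: 4529 blocks from free memory plus the 1024 warmup blocks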
@@ -780,9 +796,8 @@ class FlashCausalLM(Model):
         if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
-                # Warmup cuda graphs for all power of twos until 64
-                for i in range(6):
-                    bs = 2**i
+                # Warmup cuda graphs
+                for bs in [1, 2, 4] + [8 * i for i in range(8)]:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
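Note: decode graphs are now pre-captured for a fixed list of batch sizes instead of the powers of two 2**i for i in range(6). Evaluating the new expression as written (sizes that fail the speculate guard are still skipped, and any failure is caught by the surrounding try/except):

batch_sizes = [1, 2, 4] + [8 * i for i in range(8)]
print(batch_sizes)  # [1, 2, 4, 0, 8, 16, 24, 32, 40, 48, 56]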
@@ -790,7 +805,7 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -843,10 +858,16 @@ class FlashCausalLM(Model):
             lm_head_indices = batch.prefill_head_indices
 
         bs = input_ids.shape[0]
-        # Ceil next power of two for batch size
-        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
 
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
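Note: the decode batch size is now padded to the nearest captured size (4, 8, or the next multiple of 8) rather than to the next power of two, matching the warmup list above. A standalone sketch of the same rounding rule with a few example inputs; pad_batch_size is a hypothetical helper name:

def pad_batch_size(bs: int) -> int:
    # Mirror of the padding rule used to look up a captured CUDA graph.
    padded_bs = bs
    if bs == 3:
        padded_bs = 4
    elif 3 < bs <= 8:
        padded_bs = 8
    elif bs > 8:
        padded_bs = (bs + 7) // 8 * 8
    return padded_bs

for bs in [1, 2, 3, 5, 8, 9, 17, 33]:
    print(bs, "->", pad_batch_size(bs))
# 1 -> 1, 2 -> 2, 3 -> 4, 5 -> 8, 8 -> 8, 9 -> 16, 17 -> 24, 33 -> 40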
@@ -868,7 +889,9 @@ class FlashCausalLM(Model):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
 
         # Replay the graph
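Note: the captured graph always runs over static buffers sized for padded_bs, so the rows beyond the live batch must hold harmless values; filling slots with -1 and zeroing input_lengths before the partial copy does that. A minimal sketch of the reset-then-copy pattern on hypothetical static buffers:

import torch

padded_bs, bs = 8, 3  # graph captured for 8 sequences, only 3 are live

static_slots = torch.empty(padded_bs, dtype=torch.int64, device="cuda")
static_input_lengths = torch.empty(padded_bs, dtype=torch.int32, device="cuda")

slots = torch.tensor([10, 11, 12], device="cuda")
input_lengths = torch.tensor([5, 7, 2], dtype=torch.int32, device="cuda")

# Reset the padding rows so stale values from a previous replay are not reused,
# then copy the live batch into the front of each static buffer.
static_slots.fill_(-1)
static_slots[: slots.shape[0]] = slots
static_input_lengths.zero_()
static_input_lengths[: input_lengths.shape[0]] = input_lengths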
@@ -377,6 +377,22 @@ class BaseFlashMistral(FlashCausalLM):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
         with torch.cuda.graph(graph, pool=MEM_POOL):
             self.cuda_graphs[bs]["logits"] = self.model.forward(
                 input_ids=input_ids,
@@ -390,6 +406,7 @@ class BaseFlashMistral(FlashCausalLM):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+        torch.cuda.synchronize()
 
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
@@ -447,10 +464,16 @@ class BaseFlashMistral(FlashCausalLM):
             max_s = min(self.model.max_past, max_s)
 
         bs = input_ids.shape[0]
-        # Ceil next power of two for batch size
-        bs_next_power_of_two = 2 ** math.ceil(math.log2(bs))
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
 
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(bs_next_power_of_two, None)
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
@@ -476,7 +499,9 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["block_tables"][
             : block_tables.shape[0], : block_tables.shape[1]
         ] = block_tables
+        cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
 
         # Replay the graph