pingpong optimization

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A <yi.a.wang@intel.com>
Date: 2025-04-08 19:56:10 -07:00
parent 29703dbd27
commit cd900c3b72
3 changed files with 274 additions and 345 deletions


@@ -253,6 +253,9 @@ class FlashCausalLMBatch(Batch):
hpu_attn_meta: Optional[HPUPagedAttentionMetadata]
next_token_logits: Optional[torch.Tensor]
speculative_logits: Optional[torch.Tensor]
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
@@ -490,6 +493,8 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor=None,
adapter_meta=None,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
)
@classmethod
@@ -698,6 +703,8 @@ class FlashCausalLMBatch(Batch):
speculative_ids=speculative_ids,
adapter_meta=adapter_meta,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
)
@classmethod
@@ -959,6 +966,8 @@ class FlashCausalLMBatch(Batch):
speculative_ids=speculative_ids,
adapter_meta=adapter_meta,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
)
def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
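The two new Optional fields above exist so that a batch object can carry the raw logits produced by one call to generate_token into the following call, where sampling happens. A toy sketch of that handoff, using an invented TinyBatch type rather than the real FlashCausalLMBatch:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TinyBatch:
    # stand-ins for the new Optional fields on FlashCausalLMBatch
    next_token_logits: Optional[torch.Tensor] = None
    speculative_logits: Optional[torch.Tensor] = None


batch = TinyBatch()
# end of step N: stash the forward output instead of sampling right away
batch.next_token_logits = torch.randn(1, 32000)
# start of step N+1: sample from the stashed logits, then clear them
if batch.next_token_logits is not None:
    next_id = int(batch.next_token_logits.argmax(dim=-1))
    batch.next_token_logits = None
print(next_id)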
@@ -1484,7 +1493,7 @@ class FlashCausalLM(Model):
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None:
- max_total_tokens = sum(batch.cache_lengths)
max_total_tokens = sum(batch.input_lengths)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
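An illustration of why the default switches from cache_lengths to input_lengths, assuming (not stated in the diff) that at this point the warmup batch has not been prefilled yet, so its cache lengths are still all zeros while input_lengths already holds the prompt sizes:

# invented numbers, only to show the difference between the two sums
cache_lengths = [0, 0, 0]
input_lengths = [512, 256, 128]
print(sum(cache_lengths))  # 0   -> old default for max_total_tokens
print(sum(input_lengths))  # 896 -> new default for max_total_tokens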
@@ -1531,6 +1540,8 @@ class FlashCausalLM(Model):
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
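The new guard skips decode buckets whose batch size exceeds the number of KV-cache blocks, presumably because every running sequence needs at least one block, so such shapes can never occur at runtime and warming them up would only cost extra graph compilations. A small stand-alone illustration with invented bucket values:

decode_buckets = [(1, 4), (8, 4), (16, 16), (64, 32), (64, 128)]

# keep only shapes that are actually reachable: batch_size <= block_num
warmup_plan = [
    (batch_size, block_num)
    for batch_size, block_num in reversed(decode_buckets)
    if batch_size <= block_num
]
print(warmup_plan)  # [(64, 128), (16, 16), (1, 4)] -- (64, 32) and (8, 4) are skipped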
@@ -1803,6 +1814,144 @@ class FlashCausalLM(Model):
def generate_token(
self, batches: List[FlashCausalLMBatch]
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
# In order to pipeline any actions on CPU we perform the operation in 3 main stages:
# Stage 1. Collect next token ids of any previously started generations
prev_batches = []
requests_to_generate = []
for batch_id, batch in enumerate(batches):
if batch.next_token_logits is not None:
prefill = batch.prefilling
if batch.prefilling:
batch.prefilling = False
batch.prefilling_mask = [False] * len(batch)
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
logprobs,
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_current_length],
batch.next_token_logits,
speculate,
batch.speculative_ids,
batch.speculative_logits,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
logprobs,
accepted_ids,
)
# Since we are done prefilling, all the tensors that were concatenating values for all the requests
# instantly become of shape [BATCH_SIZE]
if prefill:
indices = batch.cu_seqlen_prefill[1:] - 1
# pad in left
if batch.prefill_cache_indices is not None:
batch.position_ids = batch.position_ids[
batch.prefill_cache_indices
][indices]
else:
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = (
batch.adapter_meta.adapter_indices[indices]
)
# For each member of the batch
# Cumulative length
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
if batch.speculative_logits is not None:
for i in range(len(batch)):
batch.all_input_ids_tensor[
i,
batch.cache_lengths_tensor[i]
+ batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+ batch.input_lengths[i]
+ accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
batch_idx = torch.arange(
0,
batch.all_input_ids_tensor.shape[0],
dtype=torch.long,
device=batch.input_lengths_tensor.device,
)
batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids
)
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids
# Does a HPU <-> CPU sync internally
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(
batch.adapter_meta.adapter_indices
)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
prev_batches.append(
{
"next_token_ids": next_input_ids,
"next_token_logprobs": next_token_logprobs,
"accepted_ids": accepted_ids,
}
)
idx = len(prev_batches) - 1
for req_idx, req in enumerate(batch.requests):
requests_to_generate.append(
{
"idx": idx,
"request_id": req.id,
"cache_length": batch.cache_lengths[req_idx],
"input_length": batch.input_lengths[req_idx],
"prefix_offset": batch.prefix_offsets[req_idx],
"read_offset": batch.read_offsets[req_idx],
"stopping_criteria": batch.stopping_criterias[req_idx],
"all_input_ids": batch.all_input_ids[req_idx],
"do_sample": batch.next_token_chooser.do_sample[req_idx],
"seed": batch.next_token_chooser.seeds[req_idx],
"top_n_tokens": batch.top_n_tokens[req_idx],
"top_token_ids": batch_top_token_ids[req_idx],
"top_token_logprobs": batch_top_token_logprobs[req_idx],
}
)
if prefill:
# We do not need prefill tensors anymore
batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.next_token_logits = None
batch.speculative_ids = None
htorch.core.mark_step()
# Stage 2. Prepare new batch for speculative scheduling
if len(batches) > 1:
batch = self.batch_type.concatenate(batches)
else:
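Taken together, the new Stage 1/2/3 structure is the "pingpong" part of this change: the logits produced by the forward pass launched in one generate_token call are only turned into CPU-side tokens at the beginning of the next call, so host-side postprocessing of step N overlaps with the device-side forward of step N+1. A toy model of that schedule (plain Python, no HPU involved, invented names):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyBatch:
    tokens: List[int] = field(default_factory=list)
    next_token_logits: Optional[int] = None  # stand-in for a device tensor


def forward(batch: ToyBatch) -> int:
    # pretend this is the HPU forward pass producing logits
    return len(batch.tokens) + 100


def generate_token(batch: ToyBatch) -> None:
    # Stage 1: consume the logits stashed by the previous call (host-side work)
    if batch.next_token_logits is not None:
        batch.tokens.append(batch.next_token_logits)  # "sampling"
    # Stage 2: run the next forward and stash its output for the next call
    batch.next_token_logits = forward(batch)
    # Stage 3 (in the real code): build Generation objects for the tokens
    # produced by the previous step while the new forward is in flight


b = ToyBatch()
for _ in range(3):
    generate_token(b)
print(b.tokens)  # [100, 101] -- output always lags one step behind the forward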
@@ -1851,7 +2000,7 @@ class FlashCausalLM(Model):
out, speculative_logits = self.forward(batch, adapter_data)
if prefill:
- next_token_logits = (
batch.next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
if speculative_logits is not None:
@@ -1862,364 +2011,147 @@ class FlashCausalLM(Model):
)
else:
prefill_logprobs = None
- next_token_logits = out
- finished_prefilling = True
- next_chunk_lengths = []
- current_prefilling_mask = batch.prefilling_mask
- if prefill:
- finished_prefilling = True
- next_prefilling_mask = [False] * len(batch)
- batch.prefilling = not finished_prefilling
- batch.prefilling_mask = next_prefilling_mask
- speculate = get_speculate()
- (
- next_input_ids,
- next_token_logprobs,
- logprobs,
- accepted_ids,
- speculative_ids,
- ) = batch.next_token_chooser(
- batch.all_input_ids_tensor[:, : batch.max_current_length],
- next_token_logits,
- speculate,
- batch.speculative_ids,
- speculative_logits,
- )
- batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
- batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
- )
- # Since we are done prefilling, all the tensors that were concatenating values for all the requests
- # instantly become of shape [BATCH_SIZE]
- if prefill and finished_prefilling:
- indices = batch.cu_seqlen_prefill[1:] - 1
- # pad in left
- if batch.prefill_cache_indices is not None:
- batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
- indices
- ]
- else:
- batch.position_ids = batch.position_ids[indices]
- batch.slot_indices = batch.slot_indices[indices]
- batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
- indices
- ]
- # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
- # one, we need to first do a HPU <-> CPU sync
- # It is faster if we delay this sync for the maximum amount of time
- # For each member of the batch
- # Cumulative length
- cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
- torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
- if speculative_logits is not None:
- for i in range(len(batch)):
- batch.all_input_ids_tensor[
- i,
- batch.cache_lengths_tensor[i]
- + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
- + batch.input_lengths[i]
- + accepted_ids[i],
- ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
- else:
- index = batch.cache_lengths_tensor + batch.input_lengths_tensor
- batch_idx = torch.arange(
- 0,
- batch.all_input_ids_tensor.shape[0],
- dtype=torch.long,
- device=batch.input_lengths_tensor.device,
- )
- batch.all_input_ids_tensor.index_put_(
- (batch_idx, index.long()), next_input_ids
- )
- # Update values
- # These values can be updated without a HPU -> CPU sync
- if not prefill or (prefill and finished_prefilling):
- batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
- batch.speculative_ids = speculative_ids
- if batch.position_ids.dim() == 2:
- # Qwen2_vl case:
- batch.position_ids += accepted_ids.unsqueeze(-1)
- else:
- batch.position_ids += accepted_ids
- batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
- batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
- batch.slot_indices += accepted_ids
- # Does a HPU <-> CPU sync internally
- if prefill and finished_prefilling:
- # adjust segment lengths to account for all request lengths being 1 during decoding
- adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
- batch.adapter_meta.adapter_segments = torch.tensor(
- adapter_segments,
- dtype=torch.int32,
- device=batch.adapter_meta.adapter_segments.device,
- )
- # HPU <-> CPU sync
- next_token_logprobs = next_token_logprobs.tolist()
- next_token_ids = next_input_ids.tolist()
- accepted_ids = accepted_ids.tolist()
- # Update values if we need to continue prefilling
- # This represents the `else` case of the `Update values` if above
- # but since this require the `next_token_ids` to be on CPU, it is better to do it here
- if prefill and not finished_prefilling:
- # Speculation must be ignored while we prefill even with chunking
- # it simplifies everything
- assert batch.speculative_ids is None
- all_postfix_ids = []
- for i, (
- request_prefilling,
- next_token_id,
- all_input_ids,
- cache_length,
- input_length,
- next_chunk_length,
- ) in enumerate(
- zip(
- batch.prefilling_mask,
- next_token_ids,
- batch.all_input_ids,
- batch.cache_lengths,
- batch.input_lengths,
- next_chunk_lengths,
- )
- ):
- if request_prefilling:
- next_cache_length = cache_length + input_length
- # Get new prompt IDs to prefill
- postfix_ids = all_input_ids[
- next_cache_length : next_cache_length + next_chunk_length
- ]
- else:
- # This request is done prefilling, the new id is the one selected the sampling method
- postfix_ids = [next_token_id]
- all_postfix_ids.append(postfix_ids)
- batch.input_ids = all_postfix_ids
batch.next_token_logits = out
batch.speculative_logits = speculative_logits
# HPU->CPU sync
for prev_batch in prev_batches:
prev_batch["next_token_logprobs"] = prev_batch[
"next_token_logprobs"
].tolist()
prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
start_decode = time.time_ns()
# Stage 3. Finish and return previous generations
# Results
generations: List[Generation] = []
- stopped = True
- # Zipped iterator
- iterator = zip(
- batch.requests,
- batch.prompt_lengths,
- batch.cache_lengths,
- batch.input_lengths,
- batch.prefix_offsets,
- batch.read_offsets,
- batch.stopping_criterias,
- batch.all_input_ids,
- batch.next_token_chooser.do_sample,
- batch.next_token_chooser.seeds,
- batch.top_n_tokens,
- current_prefilling_mask,
- batch.prefilling_mask,
- accepted_ids,
- batch_top_token_ids,
- batch_top_token_logprobs,
- )
stopped = len(requests_to_generate) > 0
# Reset max_input_length
batch.max_input_length = 0
# For each member of the batch
- index = 0
- for i, (
- request,
- prompt_length,
- cache_length,
- input_length,
- prefix_offset,
- read_offset,
- stopping_criteria,
- all_input_ids,
- do_sample,
- seed,
- top_n_tokens,
- request_was_prefilling,
- request_is_prefilling,
- n_accepted_ids,
- top_token_ids,
- top_token_logprobs,
- ) in enumerate(iterator):
- # Compute logprobs first as, even though we might skip the token,
- # it can still be required to compute the logprobs
- # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
- # this state to be stable
- if request.id % self.world_size == self.rank:
- # Prefill
- if request_was_prefilling and request.prefill_logprobs:
- out_start_index = batch.prefill_cu_outlens[i]
- out_end_index = batch.prefill_cu_outlens[i + 1]
- if not request_is_prefilling:
- # The request is dones prefilling, meaning that we started generating new tokens
- # The last logprob is a logprob for a generated token that was not part of the prompt
- # We need to remove it
- out_end_index -= 1
- request_prefill_logprobs = prefill_logprobs[
- out_start_index:out_end_index
- ]
- # Logprobs generated by the model are for the next token
- # So we need to translate the id tensor by 1
- prefill_token_ids = all_input_ids[
- cache_length + 1 : cache_length + input_length + 1
- ]
- past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
- if past_prefill_logprob_tokens is None:
- # add nan for cached prompt tokens/first token
- request_prefill_logprobs = [float("nan")] * (
- cache_length + 1
- ) + request_prefill_logprobs
- prefill_token_ids = (
- all_input_ids[: cache_length + 1] + prefill_token_ids
- )
- prefill_texts = self.tokenizer.batch_decode(
- prefill_token_ids,
- clean_up_tokenization_spaces=False,
- skip_special_tokens=False,
- )
- prefill_logprob_tokens = Tokens(
- prefill_token_ids,
- request_prefill_logprobs,
- prefill_texts,
- is_special=[],
- )
- if past_prefill_logprob_tokens is not None:
- prefill_logprob_tokens = (
- past_prefill_logprob_tokens + prefill_logprob_tokens
- )
- batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
- else:
- batch.prefill_logprob_tokens[i] = None
- # If it is, the tokens we decoded should be ignored
- if request_is_prefilling:
- # Make sure that we do not stop as even though this request did not create a token, it is still
- # processing
- stopped = False
- new_input_length = next_chunk_lengths[i]
- new_cache_length = cache_length + input_length
- else:
- new_input_length = 1
- new_cache_length = cache_length + input_length + n_accepted_ids - 1
- # Append next token to all tokens
- next_token_texts = []
- left = 0
- if n_accepted_ids > 1:
- log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
- current_stopped = False
- for j in range(index, index + n_accepted_ids):
- # Generated token
- next_token_id = next_token_ids[j]
- all_input_ids.append(next_token_id)
- next_token_text, prefix_offset, read_offset = self.decode_token(
- all_input_ids,
- prefix_offset,
- read_offset,
- )
- next_token_texts.append(next_token_text)
- stop, reason = stopping_criteria(
- next_token_id,
- next_token_text,
- )
- if stop:
- left = index + n_accepted_ids - j - 1
- current_stopped = True
- break
- else:
- current_stopped = False
- stopped = stopped and current_stopped
- _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
- _next_token_logprobs = next_token_logprobs[
- index : index + n_accepted_ids - left
- ]
- # Shard generations
- # All generations will be appended in the rust sharded client
- if request.id % self.world_size == self.rank:
- if stop:
- # Decode generated tokens
- output_text, _, _ = self.decode_token(
- all_input_ids,
- prefix_offset=len(all_input_ids)
- - stopping_criteria.current_tokens
- - 1,
- read_offset=len(all_input_ids)
- - stopping_criteria.current_tokens,
- skip_special_tokens=True,
- )
- generated_text = GeneratedText(
- output_text,
- stopping_criteria.current_tokens,
- reason,
- seed if do_sample else None,
- )
- else:
- generated_text = None
- if top_n_tokens > 0:
- all_top_tokens = []
- for top_token_ids, top_token_logprobs in zip(
- top_token_ids, top_token_logprobs
- ):
- toptoken_texts = self.tokenizer.batch_decode(
- top_token_ids,
- clean_up_tokenization_spaces=False,
- skip_special_tokens=False,
- )
- special_toptokens = [
- token_id in self.all_special_ids
- for token_id in top_token_ids
- ]
- top_tokens = Tokens(
- top_token_ids,
- top_token_logprobs,
- toptoken_texts,
- special_toptokens,
- )
- all_top_tokens.append(top_tokens)
- top_tokens = all_top_tokens
- else:
- top_tokens = None
- generation = Generation(
- request.id,
- batch.prefill_logprob_tokens[i],
- Tokens(
- _next_token_ids,
- _next_token_logprobs,
- next_token_texts,
- [nid in self.all_special_ids for nid in _next_token_ids],
- ),
- generated_text,
- top_tokens,
- )
- generations.append(generation)
indexs = [0] * len(prev_batches)
idx_accept_ids = [0] * len(prev_batches)
for i, req_data in enumerate(requests_to_generate):
idx = req_data["idx"]
request_id = req_data["request_id"]
cache_length = req_data["cache_length"]
input_length = req_data["input_length"]
prefix_offset = req_data["prefix_offset"]
read_offset = req_data["read_offset"]
stopping_criteria = req_data["stopping_criteria"]
all_input_ids = req_data["all_input_ids"]
do_sample = req_data["do_sample"]
seed = req_data["seed"]
top_n_tokens = req_data["top_n_tokens"]
n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
top_token_ids = req_data["top_token_ids"]
top_token_logprobs = req_data["top_token_logprobs"]
new_input_length = 1
new_cache_length = cache_length + input_length + n_accepted_ids - 1
# Append next token to all tokens
next_token_texts = []
left = 0
if n_accepted_ids > 1:
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
index = indexs[idx]
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = prev_batches[idx]["next_token_ids"][j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = prev_batches[idx]["next_token_ids"][
index : index + n_accepted_ids - left
]
_next_token_logprobs = prev_batches[idx]["next_token_logprobs"][
index : index + n_accepted_ids - left
]
# Shard generations
# All generations will be appended in the rust sharded client
if request_id % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None
generation = Generation(
request_id,
None,
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
@@ -2231,7 +2163,8 @@ class FlashCausalLM(Model):
)
# Update values
- index += n_accepted_ids
indexs[idx] += n_accepted_ids
idx_accept_ids[idx] += 1
batch.cache_lengths[i] = new_cache_length
batch.max_input_length = max(batch.max_input_length, new_input_length)
batch.input_lengths[i] = new_input_length
@@ -2248,14 +2181,6 @@ class FlashCausalLM(Model):
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
- if prefill and finished_prefilling:
- # We do not need prefill tensors anymore
- batch.cu_seqlen_prefill = None
- batch.prefill_cache_indices = None
- batch.prefill_cu_outlens = None
- batch.prefill_head_indices = None
- batch.prefill_next_token_indices = None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
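One detail worth calling out in the rewritten loop: each entry of prev_batches stores a flat next_token_ids list covering every request of that batch, plus a per-request accepted_ids count, so two cursors are needed: indexs[idx] walks the flat token list while idx_accept_ids[idx] walks the per-request counts. A self-contained illustration with invented numbers:

prev_batches = [
    # one previously started batch with three requests; speculative decoding
    # accepted 2, 1 and 3 tokens for them respectively
    {"next_token_ids": [11, 12, 21, 31, 32, 33], "accepted_ids": [2, 1, 3]}
]
indexs = [0] * len(prev_batches)
idx_accept_ids = [0] * len(prev_batches)

for idx in (0, 0, 0):  # the three requests all point at prev_batches[0]
    n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
    index = indexs[idx]
    print(prev_batches[idx]["next_token_ids"][index : index + n_accepted_ids])
    indexs[idx] += n_accepted_ids
    idx_accept_ids[idx] += 1
# prints [11, 12], then [21], then [31, 32, 33]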


@@ -456,6 +456,8 @@ class FlashVlmCausalLM(FlashCausalLM):
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)


@@ -368,6 +368,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)