pingpong optimization

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-04-08 19:56:10 -07:00
parent 29703dbd27
commit cd900c3b72
3 changed files with 274 additions and 345 deletions
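In short, this change makes `generate_token` stash the logits of each forward pass on the batch (the new `next_token_logits` and `speculative_logits` fields) and defer sampling and detokenization to the start of the next call, so host-side work overlaps with the HPU forward of the following step. A minimal standalone sketch of that deferred-sampling ("pingpong") loop follows; `run_forward`, `choose_tokens`, and `append_tokens` are hypothetical stand-ins, not the real server API.

# Sketch only: the callables below are made up for illustration.
def pingpong_generate(run_forward, choose_tokens, append_tokens, batch, steps):
    pending_logits = None                # logits left over from the previous step
    finished = []
    for _ in range(steps):
        sampled = None
        if pending_logits is not None:
            # Stage 1: sample from the *previous* step's logits. This only
            # queues device ops; nothing is copied back to the host yet.
            sampled = choose_tokens(pending_logits)
            append_tokens(batch, sampled)
        # Stage 2: launch this step's forward pass and keep the logits on the
        # device instead of syncing them immediately.
        pending_logits = run_forward(batch)
        # Stage 3: only now pull the previously sampled tokens to the CPU, so
        # decoding and stopping checks overlap with the forward queued above.
        if sampled is not None:
            finished.append(sampled.cpu().tolist())
    return finished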

@@ -253,6 +253,9 @@ class FlashCausalLMBatch(Batch):
hpu_attn_meta: Optional[HPUPagedAttentionMetadata]
next_token_logits: Optional[torch.Tensor]
speculative_logits: Optional[torch.Tensor]
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
@@ -490,6 +493,8 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor=None,
adapter_meta=None,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
)
@classmethod
@@ -698,6 +703,8 @@ class FlashCausalLMBatch(Batch):
speculative_ids=speculative_ids,
adapter_meta=adapter_meta,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
)
@classmethod
@@ -959,6 +966,8 @@ class FlashCausalLMBatch(Batch):
speculative_ids=speculative_ids,
adapter_meta=adapter_meta,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
)
def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
@@ -1484,7 +1493,7 @@ class FlashCausalLM(Model):
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None:
max_total_tokens = sum(batch.cache_lengths)
max_total_tokens = sum(batch.input_lengths)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
@@ -1531,6 +1540,8 @@ class FlashCausalLM(Model):
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
@@ -1803,6 +1814,144 @@ class FlashCausalLM(Model):
def generate_token(
self, batches: List[FlashCausalLMBatch]
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
# In order to pipeline any actions on CPU we perform the operation in 3 main stages:
# Stage 1. Collect next token ids of any previously started generations
prev_batches = []
requests_to_generate = []
for batch_id, batch in enumerate(batches):
if batch.next_token_logits is not None:
prefill = batch.prefilling
if batch.prefilling:
batch.prefilling = False
batch.prefilling_mask = [False] * len(batch)
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
logprobs,
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_current_length],
batch.next_token_logits,
speculate,
batch.speculative_ids,
batch.speculative_logits,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
logprobs,
accepted_ids,
)
# Since we are done prefilling, all the tensors that concatenated values for all the requests
# instantly become of shape [BATCH_SIZE]
if prefill:
indices = batch.cu_seqlen_prefill[1:] - 1
# pad on the left
if batch.prefill_cache_indices is not None:
batch.position_ids = batch.position_ids[
batch.prefill_cache_indices
][indices]
else:
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = (
batch.adapter_meta.adapter_indices[indices]
)
# For each member of the batch
# Cumulative length
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
if batch.speculative_logits is not None:
for i in range(len(batch)):
batch.all_input_ids_tensor[
i,
batch.cache_lengths_tensor[i]
+ batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+ batch.input_lengths[i]
+ accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
batch_idx = torch.arange(
0,
batch.all_input_ids_tensor.shape[0],
dtype=torch.long,
device=batch.input_lengths_tensor.device,
)
batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids
)
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids
# Does a HPU <-> CPU sync internally
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(
batch.adapter_meta.adapter_indices
)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
prev_batches.append(
{
"next_token_ids": next_input_ids,
"next_token_logprobs": next_token_logprobs,
"accepted_ids": accepted_ids,
}
)
idx = len(prev_batches) - 1
for req_idx, req in enumerate(batch.requests):
requests_to_generate.append(
{
"idx": idx,
"request_id": req.id,
"cache_length": batch.cache_lengths[req_idx],
"input_length": batch.input_lengths[req_idx],
"prefix_offset": batch.prefix_offsets[req_idx],
"read_offset": batch.read_offsets[req_idx],
"stopping_criteria": batch.stopping_criterias[req_idx],
"all_input_ids": batch.all_input_ids[req_idx],
"do_sample": batch.next_token_chooser.do_sample[req_idx],
"seed": batch.next_token_chooser.seeds[req_idx],
"top_n_tokens": batch.top_n_tokens[req_idx],
"top_token_ids": batch_top_token_ids[req_idx],
"top_token_logprobs": batch_top_token_logprobs[req_idx],
}
)
if prefill:
# We do not need prefill tensors anymore
batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.next_token_logits = None
batch.speculative_ids = None
htorch.core.mark_step()
# Stage 2. Prepare new batch for speculative scheduling
if len(batches) > 1:
batch = self.batch_type.concatenate(batches)
else:
@@ -1851,7 +2000,7 @@ class FlashCausalLM(Model):
out, speculative_logits = self.forward(batch, adapter_data)
if prefill:
next_token_logits = (
batch.next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
if speculative_logits is not None:
@@ -1862,263 +2011,43 @@ class FlashCausalLM(Model):
)
else:
prefill_logprobs = None
next_token_logits = out
batch.next_token_logits = out
batch.speculative_logits = speculative_logits
finished_prefilling = True
next_chunk_lengths = []
current_prefilling_mask = batch.prefilling_mask
if prefill:
finished_prefilling = True
next_prefilling_mask = [False] * len(batch)
batch.prefilling = not finished_prefilling
batch.prefilling_mask = next_prefilling_mask
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
logprobs,
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_current_length],
next_token_logits,
speculate,
batch.speculative_ids,
speculative_logits,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)
# Since we are done prefilling, all the tensors that concatenated values for all the requests
# instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling:
indices = batch.cu_seqlen_prefill[1:] - 1
# pad on the left
if batch.prefill_cache_indices is not None:
batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
indices
]
else:
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
indices
]
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
# one, we need to first do a HPU <-> CPU sync
# It is faster if we delay this sync for the maximum amount of time
# For each member of the batch
# Cumulative length
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
if speculative_logits is not None:
for i in range(len(batch)):
batch.all_input_ids_tensor[
i,
batch.cache_lengths_tensor[i]
+ batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+ batch.input_lengths[i]
+ accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
batch_idx = torch.arange(
0,
batch.all_input_ids_tensor.shape[0],
dtype=torch.long,
device=batch.input_lengths_tensor.device,
)
batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids
)
# Update values
# These values can be updated without a HPU -> CPU sync
if not prefill or (prefill and finished_prefilling):
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids
# Does a HPU <-> CPU sync internally
if prefill and finished_prefilling:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
# HPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()
# Update values if we need to continue prefilling
# This represents the `else` case of the `Update values` if above
# but since this requires the `next_token_ids` to be on CPU, it is better to do it here
if prefill and not finished_prefilling:
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert batch.speculative_ids is None
all_postfix_ids = []
for i, (
request_prefilling,
next_token_id,
all_input_ids,
cache_length,
input_length,
next_chunk_length,
) in enumerate(
zip(
batch.prefilling_mask,
next_token_ids,
batch.all_input_ids,
batch.cache_lengths,
batch.input_lengths,
next_chunk_lengths,
)
):
if request_prefilling:
next_cache_length = cache_length + input_length
# Get new prompt IDs to prefill
postfix_ids = all_input_ids[
next_cache_length : next_cache_length + next_chunk_length
]
else:
# This request is done prefilling; the new id is the one selected by the sampling method
postfix_ids = [next_token_id]
all_postfix_ids.append(postfix_ids)
batch.input_ids = all_postfix_ids
# HPU->CPU sync
for prev_batch in prev_batches:
prev_batch["next_token_logprobs"] = prev_batch[
"next_token_logprobs"
].tolist()
prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
start_decode = time.time_ns()
# Stage 3. Finish and return previous generations
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.cache_lengths,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
batch.stopping_criterias,
batch.all_input_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
current_prefilling_mask,
batch.prefilling_mask,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
)
stopped = len(requests_to_generate) > 0
# Reset max_input_length
batch.max_input_length = 0
# For each member of the batch
index = 0
for i, (
request,
prompt_length,
cache_length,
input_length,
prefix_offset,
read_offset,
stopping_criteria,
all_input_ids,
do_sample,
seed,
top_n_tokens,
request_was_prefilling,
request_is_prefilling,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Compute logprobs first as, even though we might skip the token,
# it can still be required to compute the logprobs
# modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
# this state to be stable
if request.id % self.world_size == self.rank:
# Prefill
if request_was_prefilling and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
if not request_is_prefilling:
# The request is done prefilling, meaning that we started generating new tokens
# The last logprob is a logprob for a generated token that was not part of the prompt
# We need to remove it
out_end_index -= 1
indexs = [0] * len(prev_batches)
idx_accept_ids = [0] * len(prev_batches)
for i, req_data in enumerate(requests_to_generate):
idx = req_data["idx"]
request_id = req_data["request_id"]
cache_length = req_data["cache_length"]
input_length = req_data["input_length"]
prefix_offset = req_data["prefix_offset"]
read_offset = req_data["read_offset"]
stopping_criteria = req_data["stopping_criteria"]
all_input_ids = req_data["all_input_ids"]
do_sample = req_data["do_sample"]
seed = req_data["seed"]
top_n_tokens = req_data["top_n_tokens"]
n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
top_token_ids = req_data["top_token_ids"]
top_token_logprobs = req_data["top_token_logprobs"]
request_prefill_logprobs = prefill_logprobs[
out_start_index:out_end_index
]
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
prefill_token_ids = all_input_ids[
cache_length + 1 : cache_length + input_length + 1
]
past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
if past_prefill_logprob_tokens is None:
# add nan for cached prompt tokens/first token
request_prefill_logprobs = [float("nan")] * (
cache_length + 1
) + request_prefill_logprobs
prefill_token_ids = (
all_input_ids[: cache_length + 1] + prefill_token_ids
)
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_logprob_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
if past_prefill_logprob_tokens is not None:
prefill_logprob_tokens = (
past_prefill_logprob_tokens + prefill_logprob_tokens
)
batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
else:
batch.prefill_logprob_tokens[i] = None
# If it is, the tokens we decoded should be ignored
if request_is_prefilling:
# Make sure that we do not stop as even though this request did not create a token, it is still
# processing
stopped = False
new_input_length = next_chunk_lengths[i]
new_cache_length = cache_length + input_length
else:
new_input_length = 1
new_cache_length = cache_length + input_length + n_accepted_ids - 1
# Append next token to all tokens
@@ -2129,9 +2058,10 @@ class FlashCausalLM(Model):
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
index = indexs[idx]
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
next_token_id = prev_batches[idx]["next_token_ids"][j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
@@ -2153,14 +2083,16 @@ class FlashCausalLM(Model):
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
_next_token_ids = prev_batches[idx]["next_token_ids"][
index : index + n_accepted_ids - left
]
_next_token_logprobs = prev_batches[idx]["next_token_logprobs"][
index : index + n_accepted_ids - left
]
# Shard generations
# All generations will be appended in the rust sharded client
if request.id % self.world_size == self.rank:
if request_id % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
@@ -2207,8 +2139,8 @@ class FlashCausalLM(Model):
top_tokens = None
generation = Generation(
request.id,
batch.prefill_logprob_tokens[i],
request_id,
None,
Tokens(
_next_token_ids,
_next_token_logprobs,
@@ -2231,7 +2163,8 @@ class FlashCausalLM(Model):
)
# Update values
index += n_accepted_ids
indexs[idx] += n_accepted_ids
idx_accept_ids[idx] += 1
batch.cache_lengths[i] = new_cache_length
batch.max_input_length = max(batch.max_input_length, new_input_length)
batch.input_lengths[i] = new_input_length
@@ -2248,14 +2181,6 @@ class FlashCausalLM(Model):
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
if prefill and finished_prefilling:
# We do not need prefill tensors anymore
batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
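For reference, the accepted-ids bookkeeping in the hunks above works as follows: `accepted_ids` records how many speculative tokens each request accepted, its cumulative sum slices the flat `next_input_ids` tensor into per-request segments, and the last token of each segment becomes that request's next input id. A small standalone illustration with made-up values:

import torch

# Made-up example: three requests accepted 2, 1 and 3 tokens respectively.
accepted_ids = torch.tensor([2, 1, 3])
next_input_ids = torch.tensor([11, 12, 21, 31, 32, 33])

cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])    # [0, 2, 3, 6]

for i in range(len(accepted_ids)):
    # Tokens that request i appends to its sequence this step.
    segment = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
    print(i, segment.tolist())                                 # [11, 12] / [21] / [31, 32, 33]

# The last accepted token of each request becomes its next input id.
print(next_input_ids[cu_accepted_ids[1:] - 1].tolist())        # [12, 21, 33]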

@@ -456,6 +456,8 @@ class FlashVlmCausalLM(FlashCausalLM):
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)

@@ -368,6 +368,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
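All three warmup loops above gain the same guard: decode buckets whose batch size exceeds the number of available KV-cache blocks are skipped, on the assumption that every decoding sequence holds at least one block, so such shapes can never occur at runtime. A standalone illustration of the guard with made-up bucket values:

# Hypothetical (batch_size, block_num) buckets; values are illustrative only.
decode_buckets = [(1, 64), (8, 64), (32, 16), (64, 8)]

for batch_size, block_num in reversed(decode_buckets):
    if batch_size > block_num:
        # More sequences than blocks can never be scheduled, so skip this shape.
        continue
    print(f"warmup decode bs {batch_size} block_num {block_num}")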