From cd900c3b729c9fecea930b34ce26972d2e527e82 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 8 Apr 2025 19:56:10 -0700 Subject: [PATCH] pingpong optimization Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 615 ++++++++---------- .../models/flash_vlm_causal_lm.py | 2 + .../models/mllama_causal_lm.py | 2 + 3 files changed, 274 insertions(+), 345 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 23a40016..51adffc7 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -253,6 +253,9 @@ class FlashCausalLMBatch(Batch): hpu_attn_meta: Optional[HPUPagedAttentionMetadata] + next_token_logits: Optional[torch.Tensor] + speculative_logits: Optional[torch.Tensor] + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, @@ -490,6 +493,8 @@ class FlashCausalLMBatch(Batch): input_lengths_tensor=None, adapter_meta=None, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) @classmethod @@ -698,6 +703,8 @@ class FlashCausalLMBatch(Batch): speculative_ids=speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) @classmethod @@ -959,6 +966,8 @@ class FlashCausalLMBatch(Batch): speculative_ids=speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx): @@ -1484,7 +1493,7 @@ class FlashCausalLM(Model): log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") if max_total_tokens is None: - max_total_tokens = sum(batch.cache_lengths) + max_total_tokens = sum(batch.input_lengths) if max_input_tokens is None: max_input_tokens = max_total_tokens - 1 @@ -1531,6 +1540,8 @@ class FlashCausalLM(Model): for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + if batch_size > block_num: + continue log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) @@ -1803,6 +1814,144 @@ class FlashCausalLM(Model): def generate_token( self, batches: List[FlashCausalLMBatch] ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. 
Collect next token ids of any previously started generations + prev_batches = [] + requests_to_generate = [] + for batch_id, batch in enumerate(batches): + if batch.next_token_logits is not None: + prefill = batch.prefilling + if batch.prefilling: + batch.prefilling = False + batch.prefilling_mask = [False] * len(batch) + + speculate = get_speculate() + ( + next_input_ids, + next_token_logprobs, + logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_current_length], + batch.next_token_logits, + speculate, + batch.speculative_ids, + batch.speculative_logits, + ) + + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) + + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill: + indices = batch.cu_seqlen_prefill[1:] - 1 + # pad in left + if batch.prefill_cache_indices is not None: + batch.position_ids = batch.position_ids[ + batch.prefill_cache_indices + ][indices] + else: + batch.position_ids = batch.position_ids[indices] + + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = ( + batch.adapter_meta.adapter_indices[indices] + ) + # For each member of the batch + # Cumulative length + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) + if batch.speculative_logits is not None: + for i in range(len(batch)): + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] + else: + index = batch.cache_lengths_tensor + batch.input_lengths_tensor + batch_idx = torch.arange( + 0, + batch.all_input_ids_tensor.shape[0], + dtype=torch.long, + device=batch.input_lengths_tensor.device, + ) + batch.all_input_ids_tensor.index_put_( + (batch_idx, index.long()), next_input_ids + ) + + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.speculative_ids = speculative_ids + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += ( + batch.input_lengths_tensor + accepted_ids - 1 + ) + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) + batch.slot_indices += accepted_ids + + # Does a HPU <-> CPU sync internally + if prefill: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments( + batch.adapter_meta.adapter_indices + ) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) + prev_batches.append( + { + "next_token_ids": next_input_ids, + "next_token_logprobs": next_token_logprobs, + "accepted_ids": accepted_ids, + } + ) + idx = len(prev_batches) - 1 + + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append( + { + "idx": idx, + "request_id": req.id, + "cache_length": batch.cache_lengths[req_idx], + "input_length": batch.input_lengths[req_idx], + "prefix_offset": batch.prefix_offsets[req_idx], + "read_offset": batch.read_offsets[req_idx], + "stopping_criteria": batch.stopping_criterias[req_idx], + "all_input_ids": 
batch.all_input_ids[req_idx], + "do_sample": batch.next_token_chooser.do_sample[req_idx], + "seed": batch.next_token_chooser.seeds[req_idx], + "top_n_tokens": batch.top_n_tokens[req_idx], + "top_token_ids": batch_top_token_ids[req_idx], + "top_token_logprobs": batch_top_token_logprobs[req_idx], + } + ) + if prefill: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None + batch.next_token_logits = None + batch.speculative_ids = None + + htorch.core.mark_step() + # Stage 2. Prepare new batch for speculative scheduling if len(batches) > 1: batch = self.batch_type.concatenate(batches) else: @@ -1851,7 +2000,7 @@ class FlashCausalLM(Model): out, speculative_logits = self.forward(batch, adapter_data) if prefill: - next_token_logits = ( + batch.next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) if speculative_logits is not None: @@ -1862,364 +2011,147 @@ class FlashCausalLM(Model): ) else: prefill_logprobs = None - next_token_logits = out + batch.next_token_logits = out + batch.speculative_logits = speculative_logits - finished_prefilling = True - next_chunk_lengths = [] - current_prefilling_mask = batch.prefilling_mask - if prefill: - finished_prefilling = True - next_prefilling_mask = [False] * len(batch) - - batch.prefilling = not finished_prefilling - batch.prefilling_mask = next_prefilling_mask - - speculate = get_speculate() - ( - next_input_ids, - next_token_logprobs, - logprobs, - accepted_ids, - speculative_ids, - ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], - next_token_logits, - speculate, - batch.speculative_ids, - speculative_logits, - ) - - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids - ) - - # Since we are done prefilling, all the tensors that were concatenating values for all the requests - # instantly become of shape [BATCH_SIZE] - if prefill and finished_prefilling: - indices = batch.cu_seqlen_prefill[1:] - 1 - # pad in left - if batch.prefill_cache_indices is not None: - batch.position_ids = batch.position_ids[batch.prefill_cache_indices][ - indices - ] - else: - batch.position_ids = batch.position_ids[indices] - - batch.slot_indices = batch.slot_indices[indices] - batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ - indices - ] - # We do two for loops as the first one can run completely asynchronously from the GPU while for the second - # one, we need to first do a HPU <-> CPU sync - # It is faster if we delay this sync for the maximum amount of time - - # For each member of the batch - # Cumulative length - cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) - torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) - if speculative_logits is not None: - for i in range(len(batch)): - batch.all_input_ids_tensor[ - i, - batch.cache_lengths_tensor[i] - + batch.input_lengths[i] : batch.cache_lengths_tensor[i] - + batch.input_lengths[i] - + accepted_ids[i], - ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] - else: - index = batch.cache_lengths_tensor + batch.input_lengths_tensor - batch_idx = torch.arange( - 0, - batch.all_input_ids_tensor.shape[0], - dtype=torch.long, - device=batch.input_lengths_tensor.device, - ) - batch.all_input_ids_tensor.index_put_( - (batch_idx, 
index.long()), next_input_ids - ) - - # Update values - # These values can be updated without a HPU -> CPU sync - if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] - batch.speculative_ids = speculative_ids - if batch.position_ids.dim() == 2: - # Qwen2_vl case: - batch.position_ids += accepted_ids.unsqueeze(-1) - else: - batch.position_ids += accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 - batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) - batch.slot_indices += accepted_ids - - # Does a HPU <-> CPU sync internally - if prefill and finished_prefilling: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, - ) - - # HPU <-> CPU sync - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids.tolist() - accepted_ids = accepted_ids.tolist() - - # Update values if we need to continue prefilling - # This represents the `else` case of the `Update values` if above - # but since this require the `next_token_ids` to be on CPU, it is better to do it here - if prefill and not finished_prefilling: - # Speculation must be ignored while we prefill even with chunking - # it simplifies everything - assert batch.speculative_ids is None - - all_postfix_ids = [] - for i, ( - request_prefilling, - next_token_id, - all_input_ids, - cache_length, - input_length, - next_chunk_length, - ) in enumerate( - zip( - batch.prefilling_mask, - next_token_ids, - batch.all_input_ids, - batch.cache_lengths, - batch.input_lengths, - next_chunk_lengths, - ) - ): - if request_prefilling: - next_cache_length = cache_length + input_length - # Get new prompt IDs to prefill - postfix_ids = all_input_ids[ - next_cache_length : next_cache_length + next_chunk_length - ] - else: - # This request is done prefilling, the new id is the one selected the sampling method - postfix_ids = [next_token_id] - - all_postfix_ids.append(postfix_ids) - - batch.input_ids = all_postfix_ids + # HPU->CPU sync + for prev_batch in prev_batches: + prev_batch["next_token_logprobs"] = prev_batch[ + "next_token_logprobs" + ].tolist() + prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist() + prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist() start_decode = time.time_ns() - + # Stage 3. 
Finish and return previous generations # Results generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - batch.stopping_criterias, - batch.all_input_ids, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - batch.top_n_tokens, - current_prefilling_mask, - batch.prefilling_mask, - accepted_ids, - batch_top_token_ids, - batch_top_token_logprobs, - ) - + stopped = len(requests_to_generate) > 0 # Reset max_input_length batch.max_input_length = 0 # For each member of the batch - index = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - top_n_tokens, - request_was_prefilling, - request_is_prefilling, - n_accepted_ids, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Compute logprobs first as, even though we might skip the token, - # it can still be required to compute the logprobs - # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need - # this state to be stable - if request.id % self.world_size == self.rank: - # Prefill - if request_was_prefilling and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - if not request_is_prefilling: - # The request is dones prefilling, meaning that we started generating new tokens - # The last logprob is a logprob for a generated token that was not part of the prompt - # We need to remove it - out_end_index -= 1 + indexs = [0] * len(prev_batches) + idx_accept_ids = [0] * len(prev_batches) + for i, req_data in enumerate(requests_to_generate): + idx = req_data["idx"] + request_id = req_data["request_id"] + cache_length = req_data["cache_length"] + input_length = req_data["input_length"] + prefix_offset = req_data["prefix_offset"] + read_offset = req_data["read_offset"] + stopping_criteria = req_data["stopping_criteria"] + all_input_ids = req_data["all_input_ids"] + do_sample = req_data["do_sample"] + seed = req_data["seed"] + top_n_tokens = req_data["top_n_tokens"] + n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]] + top_token_ids = req_data["top_token_ids"] + top_token_logprobs = req_data["top_token_logprobs"] - request_prefill_logprobs = prefill_logprobs[ - out_start_index:out_end_index - ] - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - prefill_token_ids = all_input_ids[ - cache_length + 1 : cache_length + input_length + 1 - ] + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 + # Append next token to all tokens + next_token_texts = [] + left = 0 - past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - if past_prefill_logprob_tokens is None: - # add nan for cached prompt tokens/first token - request_prefill_logprobs = [float("nan")] * ( - cache_length + 1 - ) + request_prefill_logprobs - prefill_token_ids = ( - all_input_ids[: cache_length + 1] + prefill_token_ids - ) + current_stopped = False + index = indexs[idx] + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = prev_batches[idx]["next_token_ids"][j] + all_input_ids.append(next_token_id) + 
next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) - prefill_logprob_tokens = Tokens( - prefill_token_ids, - request_prefill_logprobs, - prefill_texts, - is_special=[], - ) - if past_prefill_logprob_tokens is not None: - prefill_logprob_tokens = ( - past_prefill_logprob_tokens + prefill_logprob_tokens - ) - - batch.prefill_logprob_tokens[i] = prefill_logprob_tokens + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break else: - batch.prefill_logprob_tokens[i] = None + current_stopped = False + stopped = stopped and current_stopped - # If it is, the tokens we decoded should be ignored - if request_is_prefilling: - # Make sure that we do not stop as even though this request did not create a token, it is still - # processing - stopped = False - new_input_length = next_chunk_lengths[i] - new_cache_length = cache_length + input_length - else: - new_input_length = 1 - new_cache_length = cache_length + input_length + n_accepted_ids - 1 - # Append next token to all tokens - next_token_texts = [] - left = 0 + _next_token_ids = prev_batches[idx]["next_token_ids"][ + index : index + n_accepted_ids - left + ] + _next_token_logprobs = prev_batches[idx]["next_token_logprobs"][ + index : index + n_accepted_ids - left + ] - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( + # Shard generations + # All generations will be appended in the rust sharded client + if request_id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( all_input_ids, - prefix_offset, - read_offset, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) + else: + generated_text = None - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - - # Shard generations - # All generations will be appended in the rust sharded client - if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + 
top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, ) - else: - generated_text = None + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None + generation = Generation( + request_id, + None, + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) - generation = Generation( - request.id, - batch.prefill_logprob_tokens[i], - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) + generations.append(generation) # accept each new token for this specific request since we may # have more than one new token per request with speculative decoding @@ -2231,7 +2163,8 @@ class FlashCausalLM(Model): ) # Update values - index += n_accepted_ids + indexs[idx] += n_accepted_ids + idx_accept_ids[idx] += 1 batch.cache_lengths[i] = new_cache_length batch.max_input_length = max(batch.max_input_length, new_input_length) batch.input_lengths[i] = new_input_length @@ -2248,14 +2181,6 @@ class FlashCausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - if prefill and finished_prefilling: - # We do not need prefill tensors anymore - batch.cu_seqlen_prefill = None - batch.prefill_cache_indices = None - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None - forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 725e7517..cdda751a 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -456,6 +456,8 @@ class FlashVlmCausalLM(FlashCausalLM): for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + if batch_size > block_num: + continue log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 940ee1b0..d21cc39d 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ 
b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -368,6 +368,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + if batch_size > block_num: + continue log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" )
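
The core of this patch is a ping-pong pipeline inside generate_token. The forward pass no longer feeds token selection in the same call; instead next_token_logits and speculative_logits are stashed on the batch (the two new Optional fields on FlashCausalLMBatch), and the next call runs next_token_chooser on them (Stage 1), submits the new forward and stashes its output (Stage 2), and only then performs the HPU -> CPU .tolist() sync and host-side post-processing for the previous step (Stage 3). Detokenization and stopping criteria for step N therefore overlap with the device executing step N+1. Below is a minimal, self-contained sketch of that schedule, not the TGI code itself: Step, run_forward and select_tokens are hypothetical stand-ins, and it runs plain PyTorch on CPU (the comment marks where htorch.core.mark_step() is issued on Gaudi).

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Step:
    # Hypothetical stand-in for FlashCausalLMBatch: current token ids plus the
    # logits stashed by the previous forward pass.
    input_ids: torch.Tensor
    next_token_logits: Optional[torch.Tensor] = None


def run_forward(step: Step, vocab_size: int = 32) -> torch.Tensor:
    # Stand-in for the model forward; returns logits of shape [batch, vocab].
    return torch.randn(step.input_ids.shape[0], vocab_size)


def select_tokens(logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for next_token_chooser (greedy here); result stays a device tensor.
    return logits.argmax(dim=-1)


def generate_token(step: Step) -> List[int]:
    pending: Optional[torch.Tensor] = None

    # Stage 1: finish token selection for the logits left behind by the previous call.
    if step.next_token_logits is not None:
        pending = select_tokens(step.next_token_logits)
        step.input_ids = pending
        # On Gaudi, htorch.core.mark_step() is issued at this point.

    # Stage 2: enqueue the next forward and stash its logits for the following call.
    step.next_token_logits = run_forward(step)

    # Stage 3: only now sync the previous step's tokens to the host and post-process
    # them (detokenize, check stopping criteria, ...), overlapping with the new forward.
    return pending.tolist() if pending is not None else []


step = Step(input_ids=torch.zeros(2, dtype=torch.long))
for _ in range(4):
    print(generate_token(step))

The first call returns no tokens because there is no previous step to finish, which mirrors how the patched generate_token only emits Generations for batches that arrive with next_token_logits already stashed.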
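
Stage 1's non-speculative path writes each request's accepted token into all_input_ids_tensor with a single vectorized index_put_ at column cache_length + input_length (the per-request loop is only needed when speculative tokens are present). A toy illustration of that indexing, with made-up shapes and values:

import torch

all_input_ids = torch.zeros(3, 8, dtype=torch.long)
cache_lengths = torch.tensor([2, 0, 4])
input_lengths = torch.tensor([1, 3, 1])
next_input_ids = torch.tensor([11, 22, 33])

batch_idx = torch.arange(all_input_ids.shape[0])
index = (cache_lengths + input_lengths).long()
# Scatter one token per row: row i receives its token at column index[i].
all_input_ids.index_put_((batch_idx, index), next_input_ids)
print(all_input_ids)  # 11 lands at [0, 3], 22 at [1, 3], 33 at [2, 5]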
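
The "if batch_size > block_num: continue" guard added to the three warmup loops skips decode buckets that cannot arise at runtime, since every decoding sequence holds at least one KV-cache block and a batch can therefore never exceed the number of available blocks. A small illustration with hypothetical bucket values:

# Hypothetical (batch_size, block_num) decode buckets; pairs where the batch is
# larger than the block budget are unreachable, so warmup skips them.
decode_buckets = [(1, 4), (8, 4), (4, 16), (32, 16)]
warmed = [(bs, bn) for bs, bn in decode_buckets if bs <= bn]
print(warmed)  # [(1, 4), (4, 16)]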