From a84da5b698d6c20c5b18d983f1fc8de959d6635d Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Wed, 2 Apr 2025 00:56:15 -0700
Subject: [PATCH] optimize code

Signed-off-by: Wang, Yi A
---
 .../models/flash_causal_lm.py                 | 75 ++++---------
 1 file changed, 15 insertions(+), 60 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 48165256c..52a2ea613 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -328,6 +328,8 @@ class FlashCausalLMBatch(Batch):
             ### Deactivating it by default seems like the best course.
             if not REQUEST_LOGPROBS:
                 r.prefill_logprobs = False
+            else:
+                assert False, "prefill_logprobs not supported yet"
 
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
@@ -1847,10 +1849,6 @@ class FlashCausalLM(Model):
                 if prefill_logprobs
                 else speculative_logits
             )
-            if len(batch) > 1 and prefill_logprobs:
-                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
-                # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
         else:
             prefill_logprobs = None
             next_token_logits = out
@@ -1900,19 +1898,6 @@ class FlashCausalLM(Model):
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
             ]
-
-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.prompt_lengths,
-            batch.cache_lengths,
-            batch.input_lengths,
-            batch.all_input_ids,
-            accepted_ids,
-            current_prefilling_mask,
-            batch.prefilling_mask,
-        )
-
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # one, we need to first do a HPU <-> CPU sync
         # It is faster if we delay this sync for the maximum amount of time
@@ -1921,38 +1906,8 @@ class FlashCausalLM(Model):
         # Cumulative length
         cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
         torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-        cumulative_length = 0
-        for i, (
-            request,
-            prompt_length,
-            cache_length,
-            input_length,
-            all_input_ids,
-            n_accepted_ids,
-            request_was_prefilling,
-            request_is_prefilling,
-        ) in enumerate(iterator):
-            # Used to gather prefill logprobs
-            # Copy batch.all_input_ids_tensor to prefill_token_indices
-            if request.prefill_logprobs and request_was_prefilling:
-                # Indexing metadata
-                out_start_index = batch.prefill_cu_outlens[i]
-                out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                # Logprobs generated by the model are for the next token
-                # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[
-                    i, cache_length + 1 : cache_length + input_length + 1
-                ]
-                if len(batch) > 1:
-                    prefill_tokens_indices[out_start_index:out_end_index] = ids
-                else:
-                    # Set prefill_tokens_indices to the correct slice
-                    prefill_tokens_indices = ids
-
-            # If the device does not support triton, we copy one by one
-            if not request_is_prefilling:
-                # Only save tokens if we are done prefilling for this request
+        if speculative_logits is not None:
+            for i in range(len(batch)):
                 batch.all_input_ids_tensor[
                     i,
                     batch.cache_lengths_tensor[i]
@@ -1960,7 +1915,17 @@ class FlashCausalLM(Model):
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
-            cumulative_length += input_length
+        else:
+            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+            batch_idx = torch.arange(
+                0,
+                batch.all_input_ids_tensor.shape[0],
+                dtype=torch.long,
+                device=batch.input_lengths_tensor.device,
+            )
+            batch.all_input_ids_tensor.index_put_(
+                (batch_idx, index.long()), next_input_ids
+            )
 
         # Update values
         # These values can be updated without a HPU -> CPU sync
@@ -1976,16 +1941,6 @@ class FlashCausalLM(Model):
             batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
             batch.slot_indices += accepted_ids
 
-        if prefill and prefill_logprobs:
-            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
-            torch.log_softmax(out, -1, out=out)
-            prefill_logprobs_tensor = out
-            prefill_logprobs = torch.gather(
-                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
-            )
-            # HPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.view(-1).tolist()
-
         # Does a HPU <-> CPU sync internally
         if prefill and finished_prefilling:
             # adjust segment lengths to account for all request lengths being 1 during decoding
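
Note on the decode-path change: when there are no speculative logits, the patch replaces the
per-request Python loop with a single batched index_put_ that writes each new token id at
column cache_length + input_length of its row in all_input_ids_tensor. The standalone sketch
below uses made-up toy tensors (not TGI's real batch objects) to show that the scatter is
equivalent to the loop it replaces.

    # Standalone sketch (toy tensors, hypothetical values): vectorised write-back of one
    # new token per request, equivalent to the removed per-request loop.
    import torch

    batch_size, max_len = 4, 16
    all_input_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
    cache_lengths = torch.tensor([3, 0, 5, 2])       # tokens already in the KV cache
    input_lengths = torch.tensor([2, 4, 1, 3])       # tokens processed this step
    next_input_ids = torch.tensor([11, 22, 33, 44])  # one new token per request

    # Reference: old style, one small copy per request.
    expected = all_input_ids.clone()
    for i in range(batch_size):
        expected[i, cache_lengths[i] + input_lengths[i]] = next_input_ids[i]

    # Patched style: a single scatter over the whole batch.
    index = (cache_lengths + input_lengths).long()
    batch_idx = torch.arange(0, all_input_ids.shape[0], dtype=torch.long)
    all_input_ids.index_put_((batch_idx, index), next_input_ids)

    assert torch.equal(all_input_ids, expected)

The speculative branch keeps the loop because each request may accept a different number of
tokens (cu_accepted_ids gives per-request slices of next_input_ids), so a one-element-per-row
scatter does not cover it.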