optimize code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A <yi.a.wang@intel.com>
Date: 2025-04-02 00:56:15 -07:00
parent 705cc0b619
commit a84da5b698

@@ -328,6 +328,8 @@ class FlashCausalLMBatch(Batch):
             ### Deactivating it by default seems like the best course.
             if not REQUEST_LOGPROBS:
                 r.prefill_logprobs = False
+            else:
+                assert False, "prefill_logprobs not supported yet"
 
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
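
Net effect of this hunk: per-request prefill logprobs are always forced off, and opting in via the module-level REQUEST_LOGPROBS flag now fails fast on this backend instead of being silently honored. A minimal sketch of the resulting guard, assuming REQUEST_LOGPROBS is an environment-driven flag; the env parsing, the Request dataclass, and the sanitize helper below are illustrative stand-ins, not the file's actual definitions:

    import os
    from dataclasses import dataclass

    # Assumption: REQUEST_LOGPROBS is driven by an environment variable;
    # the exact parsing in the real module may differ.
    REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in ("1", "true")

    @dataclass
    class Request:  # hypothetical stand-in for the request protobuf
        prefill_logprobs: bool

    def sanitize(r: Request) -> Request:
        if not REQUEST_LOGPROBS:
            # Default path: drop the per-request flag.
            r.prefill_logprobs = False
        else:
            # Opt-in path: fail fast, the backend cannot serve prefill logprobs yet.
            raise AssertionError("prefill_logprobs not supported yet")
        return r
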
@@ -1847,10 +1849,6 @@ class FlashCausalLM(Model):
                 if prefill_logprobs
                 else speculative_logits
             )
-            if len(batch) > 1 and prefill_logprobs:
-                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
-                # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
         else:
             prefill_logprobs = None
             next_token_logits = out
@@ -1900,19 +1898,6 @@ class FlashCausalLM(Model):
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
             ]
-
-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.prompt_lengths,
-            batch.cache_lengths,
-            batch.input_lengths,
-            batch.all_input_ids,
-            accepted_ids,
-            current_prefilling_mask,
-            batch.prefilling_mask,
-        )
-
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # one, we need to first do a HPU <-> CPU sync
         # It is faster if we delay this sync for the maximum amount of time
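
The comment kept above states the key constraint behind this commit: device-side tensor ops are queued asynchronously, while anything that materializes values on the host (.tolist(), .item(), .cpu()) blocks until the accelerator has caught up. A small generic PyTorch sketch of that pattern, with illustrative names rather than the actual function from this file:

    import torch

    def decode_step(logits: torch.Tensor) -> list:
        # Device-side work: queued asynchronously, the host thread is not blocked.
        next_ids = torch.argmax(logits, dim=-1)
        logprobs = torch.log_softmax(logits, dim=-1)
        chosen = torch.gather(logprobs, 1, next_ids.view(-1, 1))

        # Host-side materialization forces a device <-> host sync, so it is done
        # once, as late as possible, rather than inside a per-request loop.
        return chosen.view(-1).tolist()
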
@@ -1921,38 +1906,8 @@ class FlashCausalLM(Model):
         # Cumulative length
         cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
         torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-        cumulative_length = 0
-        for i, (
-            request,
-            prompt_length,
-            cache_length,
-            input_length,
-            all_input_ids,
-            n_accepted_ids,
-            request_was_prefilling,
-            request_is_prefilling,
-        ) in enumerate(iterator):
-            # Used to gather prefill logprobs
-            # Copy batch.all_input_ids_tensor to prefill_token_indices
-            if request.prefill_logprobs and request_was_prefilling:
-                # Indexing metadata
-                out_start_index = batch.prefill_cu_outlens[i]
-                out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                # Logprobs generated by the model are for the next token
-                # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[
-                    i, cache_length + 1 : cache_length + input_length + 1
-                ]
-                if len(batch) > 1:
-                    prefill_tokens_indices[out_start_index:out_end_index] = ids
-                else:
-                    # Set prefill_tokens_indices to the correct slice
-                    prefill_tokens_indices = ids
-
-            # If the device does not support triton, we copy one by one
-            if not request_is_prefilling:
-                # Only save tokens if we are done prefilling for this request
+        if speculative_logits is not None:
+            for i in range(len(batch)):
                 batch.all_input_ids_tensor[
                     i,
                     batch.cache_lengths_tensor[i]
@@ -1960,7 +1915,17 @@ class FlashCausalLM(Model):
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
-            cumulative_length += input_length
+        else:
+            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+            batch_idx = torch.arange(
+                0,
+                batch.all_input_ids_tensor.shape[0],
+                dtype=torch.long,
+                device=batch.input_lengths_tensor.device,
+            )
+            batch.all_input_ids_tensor.index_put_(
+                (batch_idx, index.long()), next_input_ids
+            )
 
         # Update values
         # These values can be updated without a HPU -> CPU sync
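
The two hunks above are the core of the change. In the non-speculative path, the per-request Python loop over the old zipped iterator is replaced by a single batched index_put_: row i of all_input_ids_tensor gets its next token written at column cache_lengths[i] + input_lengths[i]. The speculative path keeps the per-row loop because each row may accept a variable number of tokens, which a single one-value-per-row index_put_ cannot express. A standalone sketch of the equivalence, with made-up shapes and illustrative names rather than the batch object from this file:

    import torch

    batch_size, max_len = 4, 16
    all_input_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
    cache_lengths = torch.tensor([0, 2, 5, 1])
    input_lengths = torch.tensor([3, 4, 2, 6])
    next_input_ids = torch.tensor([11, 22, 33, 44])

    # Old-style per-request loop: one small indexing op per request.
    expected = all_input_ids.clone()
    for i in range(batch_size):
        expected[i, cache_lengths[i] + input_lengths[i]] = next_input_ids[i]

    # New-style batched write: one index_put_ covers every row at once.
    index = cache_lengths + input_lengths
    batch_idx = torch.arange(batch_size, dtype=torch.long)
    all_input_ids.index_put_((batch_idx, index.long()), next_input_ids)

    assert torch.equal(all_input_ids, expected)
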
@@ -1976,16 +1941,6 @@ class FlashCausalLM(Model):
         batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
         batch.slot_indices += accepted_ids
 
-        if prefill and prefill_logprobs:
-            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
-            torch.log_softmax(out, -1, out=out)
-            prefill_logprobs_tensor = out
-            prefill_logprobs = torch.gather(
-                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
-            )
-            # HPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.view(-1).tolist()
-
         # Does a HPU <-> CPU sync internally
         if prefill and finished_prefilling:
             # adjust segment lengths to account for all request lengths being 1 during decoding
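
The block removed in this last hunk was the generic prefill-logprobs path: an in-place log_softmax over the prefill logits, a gather at the prompt-token indices collected earlier, and a final .tolist() that forces an HPU <-> CPU sync. With prefill_logprobs rejected up front (the hunk at line 328), that gather and its sync are dead code on this backend. A standalone sketch of what the removed pattern computed, with toy shapes:

    import torch

    # out: logits for every prefill position, shape [num_prefill_tokens, vocab_size]
    out = torch.randn(6, 32)
    # For each position, the id of the next prompt token, i.e. the token whose
    # logprob the model assigned at that position.
    prefill_tokens_indices = torch.randint(0, 32, (6,))

    # In-place log-softmax avoids copying the large logits tensor.
    torch.log_softmax(out, -1, out=out)

    # Pick the logprob of the actual next token at every prefill position.
    prefill_logprobs = torch.gather(out, 1, prefill_tokens_indices.view(-1, 1))

    # .tolist() materializes on the host: this is the device <-> host sync that
    # the decode path now avoids when prefill logprobs are disabled.
    prefill_logprobs = prefill_logprobs.view(-1).tolist()
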