Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-20 06:12:07 +00:00
optimize code
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 705cc0b619
commit a84da5b698
@@ -328,6 +328,8 @@ class FlashCausalLMBatch(Batch):
            ### Deactivating it by default seems like the best course.
            if not REQUEST_LOGPROBS:
                r.prefill_logprobs = False
            else:
                assert False, "prefill_logprobs not supported yet"
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

@@ -1847,10 +1849,6 @@ class FlashCausalLM(Model):
                if prefill_logprobs
                else speculative_logits
            )
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
        else:
            prefill_logprobs = None
            next_token_logits = out
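A side note on the `batch.input_ids.new_zeros(len(out))` pattern in the hunk above: `Tensor.new_zeros` allocates the index buffer with the same dtype and device as `batch.input_ids`, so it can later be filled with token ids and used directly as gather indices. A minimal standalone sketch (shapes and values are illustrative, not taken from the repo):

    import torch

    # Stand-ins for batch.input_ids and the flattened prefill logits `out`
    # (one row of logits per prefill token across the whole batch).
    input_ids = torch.tensor([3, 1, 4, 1, 5], dtype=torch.int64)
    out = torch.randn(5, 8)

    # new_zeros inherits dtype and device from input_ids, so the buffer is
    # ready to hold token ids and be used as gather indices without casts.
    prefill_tokens_indices = input_ids.new_zeros(len(out))
    print(prefill_tokens_indices.dtype, prefill_tokens_indices.shape)
    # torch.int64 torch.Size([5])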
@@ -1900,19 +1898,6 @@ class FlashCausalLM(Model):
            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                indices
            ]

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.all_input_ids,
            accepted_ids,
            current_prefilling_mask,
            batch.prefilling_mask,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a HPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time
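The two-loop comment above is the key scheduling idea: the first loop only builds tensor operations, which the device can execute asynchronously, while host-side work that needs concrete values is postponed until a single HPU <-> CPU sync. A small illustrative sketch of that structure (plain CPU tensors here, so the sync is a no-op, but the shape of the code is the same):

    import torch

    def two_pass_update(logits: torch.Tensor, accepted_ids: torch.Tensor):
        # Pass 1: pure tensor ops, queued on the device without forcing a
        # device <-> host synchronization.
        next_ids = logits.argmax(dim=-1)
        cu_accepted = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
        torch.cumsum(accepted_ids, dim=0, out=cu_accepted[1:])

        # Pass 2: .tolist() is the single sync point; delaying it lets the
        # device keep working through pass 1.
        return next_ids.tolist(), cu_accepted.tolist()

    ids, offsets = two_pass_update(torch.randn(4, 16), torch.tensor([1, 2, 1, 3]))
    print(ids, offsets)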
@@ -1921,38 +1906,8 @@ class FlashCausalLM(Model):
        # Cumulative length
        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
        cumulative_length = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            all_input_ids,
            n_accepted_ids,
            request_was_prefilling,
            request_is_prefilling,
        ) in enumerate(iterator):
            # Used to gather prefill logprobs
            # Copy batch.all_input_ids_tensor to prefill_token_indices
            if request.prefill_logprobs and request_was_prefilling:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]

                # Logprobs generated by the model are for the next token
                # So we need to translate the id tensor by 1
                ids = batch.all_input_ids_tensor[
                    i, cache_length + 1 : cache_length + input_length + 1
                ]
                if len(batch) > 1:
                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = ids

            # If the device does not support triton, we copy one by one
            if not request_is_prefilling:
                # Only save tokens if we are done prefilling for this request
        if speculative_logits is not None:
            for i in range(len(batch)):
                batch.all_input_ids_tensor[
                    i,
                    batch.cache_lengths_tensor[i]
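For readers unfamiliar with the `cu_accepted_ids` bookkeeping above: the cumulative sum turns per-request acceptance counts (speculative decoding can accept several tokens per step) into [start, end) offsets into the flat `next_input_ids` tensor. A toy example with made-up values:

    import torch

    # Hypothetical flat tensor of newly accepted tokens for three requests,
    # and how many tokens each request accepted this step.
    next_input_ids = torch.tensor([11, 12, 13, 21, 31, 32])
    accepted_ids = torch.tensor([3, 1, 2])

    # Exclusive prefix sums give each request's [start, end) slice.
    cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
    torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])

    for i in range(len(accepted_ids)):
        chunk = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
        print(i, chunk.tolist())
    # 0 [11, 12, 13]
    # 1 [21]
    # 2 [31, 32]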
@@ -1960,7 +1915,17 @@ class FlashCausalLM(Model):
                    + batch.input_lengths[i]
                    + accepted_ids[i],
                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
            cumulative_length += input_length
        else:
            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
            batch_idx = torch.arange(
                0,
                batch.all_input_ids_tensor.shape[0],
                dtype=torch.long,
                device=batch.input_lengths_tensor.device,
            )
            batch.all_input_ids_tensor.index_put_(
                (batch_idx, index.long()), next_input_ids
            )

        # Update values
        # These values can be updated without a HPU -> CPU sync
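The `else` branch above is the actual optimization: when there are no speculative logits, every request accepts exactly one token, so a single `index_put_` writes all of them at once instead of looping over requests in Python. A toy reproduction of that write (shapes and values are illustrative, not from the repo):

    import torch

    # One accepted token per request, stored at position
    # cache_length + input_length of each row of all_input_ids.
    all_input_ids = torch.zeros(3, 8, dtype=torch.int64)
    cache_lengths = torch.tensor([0, 2, 1])
    input_lengths = torch.tensor([3, 1, 4])
    next_input_ids = torch.tensor([101, 102, 103])

    index = cache_lengths + input_lengths
    batch_idx = torch.arange(0, all_input_ids.shape[0], dtype=torch.long)

    # One fused indexed write replaces the per-request Python loop and keeps
    # the whole update on the device.
    all_input_ids.index_put_((batch_idx, index.long()), next_input_ids)
    print(all_input_ids)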
@@ -1976,16 +1941,6 @@ class FlashCausalLM(Model):
        batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
            torch.log_softmax(out, -1, out=out)
            prefill_logprobs_tensor = out
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # HPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # Does a HPU <-> CPU sync internally
        if prefill and finished_prefilling:
            # adjust segment lengths to account for all request lengths being 1 during decoding
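The prefill-logprobs block above avoids materialising a second (tokens x vocab) tensor by running `log_softmax` in place on `out` and then gathering one value per row; the trailing `.tolist()` is the single HPU <-> CPU sync. A self-contained sketch with toy shapes (not the repo's real dimensions):

    import torch

    # `out`: one row of logits per prefill token; `prefill_tokens_indices`:
    # the token id whose logprob we want for each row.
    out = torch.randn(5, 16)
    prefill_tokens_indices = torch.randint(0, 16, (5,))

    # In-place log-softmax reuses the (potentially huge) logits buffer.
    torch.log_softmax(out, -1, out=out)

    # gather keeps one value per row; .tolist() is the device <-> host sync.
    prefill_logprobs = torch.gather(out, 1, prefill_tokens_indices.view(-1, 1))
    prefill_logprobs = prefill_logprobs.view(-1).tolist()
    print(prefill_logprobs)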