Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 14:22:08 +00:00)
optimize code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

commit a84da5b698 (parent 705cc0b619)
@@ -328,6 +328,8 @@ class FlashCausalLMBatch(Batch):
             ### Deactivating it by default seems like the best course.
             if not REQUEST_LOGPROBS:
                 r.prefill_logprobs = False
+            else:
+                assert False, "prefill_logprobs not supported yet"
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
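Note: REQUEST_LOGPROBS is a module-level switch that gates prefill logprobs for the whole backend. A minimal sketch of the guard pattern above, assuming the flag is driven by an environment variable of the same name (the exact wiring may differ) and using a hypothetical stand-in request type:

    import os
    from dataclasses import dataclass

    # Assumption: the flag is read once at import time from an env var of the same name.
    REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}

    @dataclass
    class Request:  # hypothetical stand-in for the protobuf request
        id: int
        prefill_logprobs: bool

    def sanitize(r: Request) -> Request:
        # Prefill logprobs are costly on long prompts and not supported by this
        # backend yet, so they are forced off unless explicitly enabled.
        if not REQUEST_LOGPROBS:
            r.prefill_logprobs = False
        else:
            assert False, "prefill_logprobs not supported yet"
        return r

    print(sanitize(Request(id=0, prefill_logprobs=True)).prefill_logprobs)  # False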
@@ -1847,10 +1849,6 @@ class FlashCausalLM(Model):
                     if prefill_logprobs
                     else speculative_logits
                 )
-            if len(batch) > 1 and prefill_logprobs:
-                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
-                # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
         else:
             prefill_logprobs = None
             next_token_logits = out
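The allocation removed above built the gather-index buffer with new_zeros, which inherits dtype and device from the template tensor; a small self-contained sketch of that behaviour (toy shapes, CPU tensors):

    import torch

    input_ids = torch.tensor([101, 17, 42, 9], dtype=torch.int64)  # stand-in for batch.input_ids
    out = torch.randn(4, 32000)                                    # stand-in for the prefill logits

    # new_zeros allocates with the same dtype/device as input_ids, so the index
    # buffer lands on the accelerator without an explicit device argument.
    prefill_tokens_indices = input_ids.new_zeros(len(out))
    print(prefill_tokens_indices.dtype, prefill_tokens_indices.shape)  # torch.int64 torch.Size([4])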
@@ -1900,19 +1898,6 @@ class FlashCausalLM(Model):
             batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                 indices
             ]
 
-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.prompt_lengths,
-            batch.cache_lengths,
-            batch.input_lengths,
-            batch.all_input_ids,
-            accepted_ids,
-            current_prefilling_mask,
-            batch.prefilling_mask,
-        )
-
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # one, we need to first do a HPU <-> CPU sync
         # It is faster if we delay this sync for the maximum amount of time
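The comments kept here explain why the sync is deferred: device-side tensor work is only queued, while any host read forces a device <-> host transfer. A minimal illustration of that idea, using CUDA when available as a stand-in for the HPU:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 1024, device=device)

    # Pure tensor ops are enqueued asynchronously; Python moves on immediately.
    y = (x @ x).relu()
    y += 1

    # Reading values back on the host (.item(), .tolist(), printing) forces a
    # device -> host copy, i.e. a sync, so it is done once, as late as possible.
    print(len(y[0].tolist()))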
@@ -1921,38 +1906,8 @@ class FlashCausalLM(Model):
         # Cumulative length
         cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
         torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-        cumulative_length = 0
-        for i, (
-            request,
-            prompt_length,
-            cache_length,
-            input_length,
-            all_input_ids,
-            n_accepted_ids,
-            request_was_prefilling,
-            request_is_prefilling,
-        ) in enumerate(iterator):
-            # Used to gather prefill logprobs
-            # Copy batch.all_input_ids_tensor to prefill_token_indices
-            if request.prefill_logprobs and request_was_prefilling:
-                # Indexing metadata
-                out_start_index = batch.prefill_cu_outlens[i]
-                out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                # Logprobs generated by the model are for the next token
-                # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[
-                    i, cache_length + 1 : cache_length + input_length + 1
-                ]
-                if len(batch) > 1:
-                    prefill_tokens_indices[out_start_index:out_end_index] = ids
-                else:
-                    # Set prefill_tokens_indices to the correct slice
-                    prefill_tokens_indices = ids
-
-            # If the device does not support triton, we copy one by one
-            if not request_is_prefilling:
-                # Only save tokens if we are done prefilling for this request
-                batch.all_input_ids_tensor[
+        if speculative_logits is not None:
+            for i in range(len(batch)):
+                batch.all_input_ids_tensor[
                     i,
                     batch.cache_lengths_tensor[i]
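With speculation enabled, each request may accept a different number of tokens per step, so cu_accepted_ids (a leading zero plus a cumulative sum) slices the flat next_input_ids tensor per request. A worked toy example of that indexing (values invented for illustration):

    import torch

    accepted_ids = torch.tensor([2, 1, 3])                    # tokens accepted per request
    next_input_ids = torch.tensor([11, 12, 21, 31, 32, 33])   # flat, all requests concatenated

    cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
    torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
    print(cu_accepted_ids.tolist())                           # [0, 2, 3, 6]

    for i in range(len(accepted_ids)):
        # Request i owns the slice [cu_accepted_ids[i], cu_accepted_ids[i + 1])
        print(i, next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]].tolist())
        # 0 [11, 12]   1 [21]   2 [31, 32, 33]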
@@ -1960,7 +1915,17 @@ class FlashCausalLM(Model):
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
-            cumulative_length += input_length
+        else:
+            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+            batch_idx = torch.arange(
+                0,
+                batch.all_input_ids_tensor.shape[0],
+                dtype=torch.long,
+                device=batch.input_lengths_tensor.device,
+            )
+            batch.all_input_ids_tensor.index_put_(
+                (batch_idx, index.long()), next_input_ids
+            )
 
         # Update values
         # These values can be updated without a HPU -> CPU sync
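The change in the two hunks above is the main optimization: without speculative logits every request accepts exactly one token, so the per-request Python loop can be replaced by a single vectorized index_put_ into all_input_ids_tensor. A toy sketch of the equivalence (shapes and values are made up):

    import torch

    all_input_ids = torch.zeros(3, 8, dtype=torch.long)   # 3 requests, 8 slots each
    cache_lengths = torch.tensor([0, 2, 1])
    input_lengths = torch.tensor([3, 1, 2])
    next_input_ids = torch.tensor([7, 8, 9])               # exactly one new token per request

    # Vectorized path: one scatter, no per-request Python work.
    index = cache_lengths + input_lengths
    batch_idx = torch.arange(0, all_input_ids.shape[0], dtype=torch.long)
    all_input_ids.index_put_((batch_idx, index.long()), next_input_ids)

    # Reference: the per-row loop the vectorized write replaces.
    reference = torch.zeros(3, 8, dtype=torch.long)
    for i in range(len(next_input_ids)):
        reference[i, index[i]] = next_input_ids[i]

    assert torch.equal(all_input_ids, reference)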
@@ -1976,16 +1941,6 @@ class FlashCausalLM(Model):
         batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
         batch.slot_indices += accepted_ids
 
-        if prefill and prefill_logprobs:
-            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
-            torch.log_softmax(out, -1, out=out)
-            prefill_logprobs_tensor = out
-            prefill_logprobs = torch.gather(
-                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
-            )
-            # HPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.view(-1).tolist()
-
         # Does a HPU <-> CPU sync internally
         if prefill and finished_prefilling:
             # adjust segment lengths to account for all request lengths being 1 during decoding
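For reference, the block deleted here was the prefill-logprobs path (which the first hunk forces off for this backend): an in-place log_softmax over the logits to avoid copying a max_batch_prefill_tokens x vocab_size tensor, a gather at the prompt-token indices, and a final .tolist() that triggers the HPU <-> CPU sync. A standalone toy version of that pattern (random logits, hypothetical sizes):

    import torch

    out = torch.randn(5, 32000)                          # stand-in for the prefill logits
    prefill_tokens_indices = torch.randint(0, 32000, (5,))

    torch.log_softmax(out, -1, out=out)                  # in-place: reuses the logits buffer
    prefill_logprobs = torch.gather(out, 1, prefill_tokens_indices.view(-1, 1))
    prefill_logprobs = prefill_logprobs.view(-1).tolist()  # the host read that forces the sync
    print(len(prefill_logprobs))                         # 5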