Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Some more simplification, fix flash_neox cu_seqlen pruning
Commit 65ff6a73b3 (parent f786d1ddf5)
@@ -281,7 +281,7 @@ class CausalLMBatch(Batch):
  #TODO maybe a single loop for all these list slices
  slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
- batch.input_lengths = slice_list(batch.input_lengths)
+ batch.input_lengths = list(slice_list(batch.input_lengths))
  batch.requests = slice_list(batch.requests)
  batch.all_input_ids = slice_list(batch.all_input_ids)
  batch.next_token_choosers = slice_list(batch.next_token_choosers)
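The list() wrapper added above is what enables the in-place length updates later in this diff. A minimal standalone sketch (invented toy values, not repository code): operator.itemgetter with several indices returns a tuple, which does not support item assignment, so input_lengths is converted to a list while the other pruned fields stay as the tuples itemgetter produces.

from operator import itemgetter

# Invented toy values; keep_indices plays the same role as in the hunk above.
input_lengths = [5, 9, 3, 7]
keep_indices = [1, 3]
slice_list = itemgetter(*keep_indices) if len(keep_indices) > 1 else lambda l: (l[keep_indices[0]],)

pruned = slice_list(input_lengths)
print(pruned)              # (9, 7) -- itemgetter returns an immutable tuple

pruned = list(slice_list(input_lengths))
pruned[0] += 1             # index assignment works on the list, mirroring
print(pruned)              # [10, 7]   batch.input_lengths[i] = new_input_length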
@@ -366,7 +366,6 @@ class CausalLM(Model):
  )

  # New values for next forward
- next_batch_input_lengths = []
  next_batch_input_ids = []
  next_batch_all_input_ids = []
@@ -405,7 +404,7 @@ class CausalLM(Model):

  next_batch_input_ids.append(next_token_id)
  next_batch_all_input_ids.append(all_input_ids)
- next_batch_input_lengths.append(new_input_length)
+ batch.input_lengths[i] = new_input_length

  # Prefill
  if prefill:
@@ -437,7 +436,6 @@ class CausalLM(Model):
  batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
  batch.past_key_values = past
  batch.all_input_ids = next_batch_all_input_ids
- batch.input_lengths = next_batch_input_lengths
  batch.max_input_length += 1
  batch.padding_right_offset -= 1
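Taken together, the three causal_lm hunks drop the next_batch_input_lengths accumulator: since input_lengths is now guaranteed to be a plain list, the decode loop can update it in place instead of rebuilding a parallel list and reassigning it afterwards. A rough sketch of the pattern with a hypothetical stand-in Batch class (not the repository classes):

# Hypothetical stand-in for a batch with two requests; not the repository classes.
class Batch:
    def __init__(self, input_lengths):
        self.input_lengths = input_lengths

batch = Batch([5, 9])

# Before this commit: accumulate into a parallel list and reassign after the loop.
next_batch_input_lengths = []
for input_length in batch.input_lengths:
    next_batch_input_lengths.append(input_length + 1)  # each request grows by one token
batch.input_lengths = next_batch_input_lengths

# After this commit: update the (now guaranteed mutable) list in place.
for i, input_length in enumerate(batch.input_lengths):
    batch.input_lengths[i] = input_length + 1

print(batch.input_lengths)  # [7, 11] -- two decode steps applied to [5, 9]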
@@ -209,7 +209,7 @@ class FlashNeoXBatch(Batch):
  #TODO maybe a single loop for all these list slices
  slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
- batch.input_lengths = slice_list(batch.input_lengths)
+ batch.input_lengths = list(slice_list(batch.input_lengths))
  batch.requests = slice_list(batch.requests)
  batch.all_input_ids = slice_list(batch.all_input_ids)
  batch.next_token_choosers = slice_list(batch.next_token_choosers)
@@ -221,7 +221,11 @@ class FlashNeoXBatch(Batch):
  batch.position_ids = batch.position_ids[keep_indices]
  batch.past_key_values = batch.past_key_values[:, keep_indices] \
      if batch.past_key_values is not None else None
- batch.cu_seqlens = batch.cu_seqlens[keep_indices]
+
+ # Recalculate cumulative seq lengths
+ new_cu_seqlens = batch.cu_seqlens.new_tensor(batch.input_lengths)
+ torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)
+ batch.cu_seqlens = torch.cat((batch.cu_seqlens[:1], new_cu_seqlens))

  return batch
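This hunk is the cu_seqlen fix named in the commit message. cu_seqlens holds the prefix-sum offsets ([0, len_1, len_1 + len_2, ...]) that the flash attention kernels use to locate each sequence in the packed batch, so after dropping finished requests it has to be rebuilt from the surviving lengths; indexing it with keep_indices keeps stale offsets. A standalone sketch with made-up lengths (not repository code):

import torch

# Made-up batch of three requests.
input_lengths = [4, 2, 5]
cu_seqlens = torch.tensor([0, 4, 6, 11])   # [0] + prefix sums of input_lengths

keep_indices = [0, 2]                      # request 1 has finished and is dropped

# Buggy pruning: selecting entries by request index yields [0, 6], which no
# longer matches the surviving lengths 4 and 5, so kernels would read the
# wrong slices of the packed tensors.
wrong = cu_seqlens[keep_indices]

# Fixed pruning: rebuild the prefix sum from the surviving lengths, mirroring
# the added lines in the hunk above.
kept_lengths = [input_lengths[i] for i in keep_indices]          # [4, 5]
new_cu_seqlens = cu_seqlens.new_tensor(kept_lengths)
torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)          # [4, 9]
fixed = torch.cat((cu_seqlens[:1], new_cu_seqlens))              # [0, 4, 9]

print(wrong.tolist(), fixed.tolist())      # [0, 6] [0, 4, 9]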
@@ -300,7 +304,6 @@ class FlashNeoX(Model):
  next_batch_position_ids = []
  next_batch_cu_seqlens = [0]
  next_batch_past_key_values = []
- next_batch_input_lengths = []

  # Cumulative length
  cumulative_length = 0
@@ -365,7 +368,7 @@ class FlashNeoX(Model):
  next_batch_cu_seqlens.append(
      next_batch_cu_seqlens[-1] + new_input_length
  )
- next_batch_input_lengths.append(new_input_length)
+ batch.input_lengths[i] = new_input_length

  # Prefill
  if prefill:
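Keeping batch.input_lengths current inside the decode loop (the added line above) is what lets the pruning path recompute cu_seqlens from it: the running sum built here and the cumsum over the stored lengths describe the same packed layout. A small sketch with invented lengths illustrating that equivalence (assumption: the per-request lengths already include the newly generated token):

import torch

# Invented lengths after the current decode step.
input_lengths = [5, 3]

# Decode path: incremental running sum, as next_batch_cu_seqlens is built above.
cu_seqlens = [0]
for new_input_length in input_lengths:
    cu_seqlens.append(cu_seqlens[-1] + new_input_length)

# Pruning path: prefix sums over the stored input_lengths list (see the
# FlashNeoXBatch hunk above).
lengths = torch.tensor(input_lengths)
cu_from_cumsum = torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0)))

print(cu_seqlens, cu_from_cumsum.tolist())  # [0, 5, 8] [0, 5, 8]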
@@ -406,7 +409,6 @@ class FlashNeoX(Model):
  batch.cu_seqlens = next_batch_cu_seqlens
  batch.max_seqlen += 1
  batch.past_key_values = next_batch_past_key_values
- batch.input_lengths = next_batch_input_lengths

  return generations
@@ -444,7 +444,6 @@ class Seq2SeqLM(Model):

  # New values for next forward
  next_batch_decoder_input_ids = []
- next_batch_decoder_input_lengths = []

  # Results
  generations: List[Generation] = []
@@ -452,7 +451,6 @@ class Seq2SeqLM(Model):
  # Zipped iterator
  iterator = zip(
      batch.requests,
-     batch.decoder_input_lengths,
      logits,
      batch.next_token_choosers,
      batch.decoder_input_ids,
@@ -461,7 +459,6 @@ class Seq2SeqLM(Model):
  # For each member of the batch
  for i, (
      request,
-     decoder_input_length,
      logits,
      next_token_chooser,
      decoder_input_ids,
@@ -479,7 +476,7 @@ class Seq2SeqLM(Model):
  next_token_id_squeezed = next_token_id.squeeze()

  next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
- next_batch_decoder_input_lengths.append(decoder_input_length + 1)
+ batch.decoder_input_lengths[i] += 1

  # Prefill
  if prefill:
@@ -508,7 +505,6 @@ class Seq2SeqLM(Model):
  batch.decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
  batch.encoder_last_hidden_state = encoder_last_hidden_state
  batch.past_key_values = past
- batch.decoder_input_lengths = next_batch_decoder_input_lengths
  batch.max_decoder_input_length += 1
  batch.padding_right_offset -= 1