Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Some more simplification, fix flash_neox cu_seqlen pruning
commit 65ff6a73b3
parent f786d1ddf5
@@ -281,7 +281,7 @@ class CausalLMBatch(Batch):
 
         #TODO maybe a single loop for all these list slices
         slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
-        batch.input_lengths = slice_list(batch.input_lengths)
+        batch.input_lengths = list(slice_list(batch.input_lengths))
         batch.requests = slice_list(batch.requests)
         batch.all_input_ids = slice_list(batch.all_input_ids)
         batch.next_token_choosers = slice_list(batch.next_token_choosers)
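The list(...) wrapper is the substantive change in this hunk: itemgetter returns a tuple when called with several indices, and the single-index lambda fallback also returns a tuple, while the decode path below now writes batch.input_lengths[i] in place. A minimal sketch of the pattern, using made-up stand-in values rather than the repository's batch objects:

from operator import itemgetter

keep_indices = [0, 2]                   # hypothetical: requests kept after filtering
input_lengths = [5, 9, 7]
requests = ["req_a", "req_b", "req_c"]

new_size = len(keep_indices)
# itemgetter(*keys) yields a tuple for several keys; the lambda keeps the
# single-key case a sequence too, so callers always get something iterable.
slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)

kept_requests = slice_list(requests)                   # ("req_a", "req_c"), read-only is fine
kept_input_lengths = list(slice_list(input_lengths))   # list, because it is mutated later
kept_input_lengths[0] += 1                             # a plain tuple here would raise TypeError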
@@ -366,7 +366,6 @@ class CausalLM(Model):
         )
 
         # New values for next forward
-        next_batch_input_lengths = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
 
@@ -405,7 +404,7 @@ class CausalLM(Model):
 
             next_batch_input_ids.append(next_token_id)
             next_batch_all_input_ids.append(all_input_ids)
-            next_batch_input_lengths.append(new_input_length)
+            batch.input_lengths[i] = new_input_length
 
             # Prefill
             if prefill:
@@ -437,7 +436,6 @@ class CausalLM(Model):
         batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
         batch.past_key_values = past
         batch.all_input_ids = next_batch_all_input_ids
-        batch.input_lengths = next_batch_input_lengths
         batch.max_input_length += 1
         batch.padding_right_offset -= 1
 
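Taken together, the three CausalLM hunks replace the per-step next_batch_input_lengths accumulator with in-place updates of batch.input_lengths, which the list(...) change above makes safe. A toy sketch of the pattern, with a hypothetical _Batch stand-in rather than the real CausalLMBatch:

class _Batch:                                    # hypothetical stand-in for CausalLMBatch
    def __init__(self, input_lengths):
        self.input_lengths = input_lengths       # must be a mutable list
        self.max_input_length = max(input_lengths)

batch = _Batch([5, 8])
for i, input_length in enumerate(batch.input_lengths):
    new_input_length = input_length + 1          # each live request grows by one token
    batch.input_lengths[i] = new_input_length    # replaces next_batch_input_lengths.append(...)
batch.max_input_length += 1

assert batch.input_lengths == [6, 9]
assert batch.max_input_length == 9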
|
@@ -209,7 +209,7 @@ class FlashNeoXBatch(Batch):
 
         #TODO maybe a single loop for all these list slices
         slice_list = itemgetter(*keep_indices) if new_size > 1 else lambda l: (l[keep_indices[0]],)
-        batch.input_lengths = slice_list(batch.input_lengths)
+        batch.input_lengths = list(slice_list(batch.input_lengths))
         batch.requests = slice_list(batch.requests)
         batch.all_input_ids = slice_list(batch.all_input_ids)
         batch.next_token_choosers = slice_list(batch.next_token_choosers)
@@ -221,7 +221,11 @@ class FlashNeoXBatch(Batch):
         batch.position_ids = batch.position_ids[keep_indices]
         batch.past_key_values = batch.past_key_values[:, keep_indices] \
             if batch.past_key_values is not None else None
-        batch.cu_seqlens = batch.cu_seqlens[keep_indices]
+
+        # Recalculate cumulative seq lengths
+        new_cu_seqlens = batch.cu_seqlens.new_tensor(batch.input_lengths)
+        torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)
+        batch.cu_seqlens = torch.cat((batch.cu_seqlens[:1], new_cu_seqlens))
 
         return batch
 
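This hunk is the cu_seqlen pruning fix named in the commit title. cu_seqlens stores cumulative sequence lengths and has one more entry than there are requests, so indexing it with keep_indices produces misaligned offsets as soon as a middle request is dropped; the new code rebuilds the running sums from the kept input_lengths, keeping the leading zero. A hedged sketch with toy values, not real batch state:

import torch

input_lengths = [3, 5, 2]                   # per-request lengths before filtering
cu_seqlens = torch.tensor([0, 3, 8, 10])    # cumulative form used by the flash attention kernels

keep_indices = [0, 2]                       # drop the middle request
kept_lengths = [input_lengths[i] for i in keep_indices]        # [3, 2]

# Old behaviour: indexing the cumulative tensor directly loses alignment.
wrong = cu_seqlens[keep_indices]            # tensor([0, 8]), but the kept batch needs [0, 3, 5]

# New behaviour from the diff: recompute the cumulative sums from the kept lengths.
new_cu_seqlens = cu_seqlens.new_tensor(kept_lengths)
torch.cumsum(new_cu_seqlens, dim=0, out=new_cu_seqlens)
fixed = torch.cat((cu_seqlens[:1], new_cu_seqlens))            # tensor([0, 3, 5])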
@@ -300,7 +304,6 @@ class FlashNeoX(Model):
         next_batch_position_ids = []
         next_batch_cu_seqlens = [0]
         next_batch_past_key_values = []
-        next_batch_input_lengths = []
 
         # Cumulative length
         cumulative_length = 0
@@ -365,7 +368,7 @@ class FlashNeoX(Model):
             next_batch_cu_seqlens.append(
                 next_batch_cu_seqlens[-1] + new_input_length
             )
-            next_batch_input_lengths.append(new_input_length)
+            batch.input_lengths[i] = new_input_length
 
             # Prefill
             if prefill:
@@ -406,7 +409,6 @@ class FlashNeoX(Model):
         batch.cu_seqlens = next_batch_cu_seqlens
         batch.max_seqlen += 1
         batch.past_key_values = next_batch_past_key_values
-        batch.input_lengths = next_batch_input_lengths
 
         return generations
 
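In the FlashNeoX decode path the context lines show cu_seqlens being rebuilt every step from a fresh next_batch_cu_seqlens = [0] accumulator, so the separate next_batch_input_lengths list is redundant once the new lengths are written straight into batch.input_lengths. A toy illustration of the invariant the rebuild preserves (values are made up):

input_lengths = [4, 6, 3]                       # stand-in per-request lengths

next_batch_cu_seqlens = [0]
for i, input_length in enumerate(input_lengths):
    new_input_length = input_length + 1         # each request grows by one token
    next_batch_cu_seqlens.append(next_batch_cu_seqlens[-1] + new_input_length)
    input_lengths[i] = new_input_length         # in-place update, as in the diff

# Invariant kept by the rebuild: consecutive differences are the new lengths.
assert next_batch_cu_seqlens == [0, 5, 12, 16]
assert all(
    next_batch_cu_seqlens[i + 1] - next_batch_cu_seqlens[i] == input_lengths[i]
    for i in range(len(input_lengths))
)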
|
@@ -444,7 +444,6 @@ class Seq2SeqLM(Model):
 
         # New values for next forward
         next_batch_decoder_input_ids = []
-        next_batch_decoder_input_lengths = []
 
         # Results
         generations: List[Generation] = []
@@ -452,7 +451,6 @@ class Seq2SeqLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
-            batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
             batch.decoder_input_ids,
@@ -461,7 +459,6 @@ class Seq2SeqLM(Model):
         # For each member of the batch
         for i, (
             request,
-            decoder_input_length,
             logits,
             next_token_chooser,
             decoder_input_ids,
@@ -479,7 +476,7 @@ class Seq2SeqLM(Model):
             next_token_id_squeezed = next_token_id.squeeze()
 
             next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-            next_batch_decoder_input_lengths.append(decoder_input_length + 1)
+            batch.decoder_input_lengths[i] += 1
 
             # Prefill
             if prefill:
@@ -508,7 +505,6 @@ class Seq2SeqLM(Model):
         batch.decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
         batch.encoder_last_hidden_state = encoder_last_hidden_state
         batch.past_key_values = past
-        batch.decoder_input_lengths = next_batch_decoder_input_lengths
         batch.max_decoder_input_length += 1
         batch.padding_right_offset -= 1
 
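The Seq2SeqLM hunks apply the same simplification to the decoder lengths: the per-item length no longer needs to travel through the zipped iterator, because every live request's decoder grows by exactly one token per step and batch.decoder_input_lengths[i] += 1 records that directly. A minimal sketch with stand-in values:

decoder_input_lengths = [2, 4, 3]               # stand-in for batch.decoder_input_lengths
max_decoder_input_length = max(decoder_input_lengths)

for i in range(len(decoder_input_lengths)):
    # replaces next_batch_decoder_input_lengths.append(decoder_input_length + 1)
    decoder_input_lengths[i] += 1
max_decoder_input_length += 1

assert decoder_input_lengths == [3, 5, 4]
assert max_decoder_input_length == 5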
|