fix decode timing

This commit is contained in:
OlivierDehaene 2023-12-14 12:41:10 +01:00
parent 701dd7da67
commit 248eda7b20
5 changed files with 34 additions and 26 deletions

View File

@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generations, next_batch = default_seq2seq_lm.generate_token(
generations, next_batch, _ = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = next_batch.filter([next_batch.requests[0].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -228,11 +228,11 @@ def test_batch_concatenate(
default_multi_requests_seq2seq_lm_batch,
):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
_, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
# Copy hidden state because it is removed from the concatenated branches
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@ -324,10 +324,10 @@ def test_batch_concatenate(
)
for _ in range(3):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
@ -342,7 +342,7 @@ def test_batch_concatenate(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -352,7 +352,7 @@ def test_batch_concatenate(
next_batch = next_batch.filter([next_batch.requests[1].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -586,7 +586,7 @@ class CausalLM(Model):
torch.log_softmax(logits[:, -1], -1),
)
forward_ns = time.time_ns() - start
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
@ -734,7 +734,8 @@ class CausalLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
@ -751,5 +752,6 @@ class CausalLM(Model):
# Update past key values
batch.past_key_values = past
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -943,7 +943,7 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
forward_ns = time.time_ns() - start
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
@ -1105,12 +1105,14 @@ class FlashCausalLM(Model):
if stopped:
del batch
# No need to return a batch if we know that all requests stopped
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -694,7 +694,7 @@ class IdeficsCausalLM(Model):
# Hardcoded remove image tokens
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
forward_ns = time.time_ns() - start
start_decode = time.time_ns()
# Results
generations: List[Generation] = []
@ -824,7 +824,8 @@ class IdeficsCausalLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
@ -845,5 +846,6 @@ class IdeficsCausalLM(Model):
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -646,7 +646,7 @@ class Seq2SeqLM(Model):
torch.log_softmax(logits[:, -1], -1),
)
forward_ns = time.time_ns() - start
start_decode = time.time_ns()
# Finished requests
generations: List[Generation] = []
@ -792,7 +792,8 @@ class Seq2SeqLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# We don't need input_ids after the prefill forward
@ -804,5 +805,6 @@ class Seq2SeqLM(Model):
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
batch.padding_right_offset -= 1
decode_ns = time.time_ns() - start
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)