fix decode timing

OlivierDehaene 2023-12-14 12:41:10 +01:00
parent 701dd7da67
commit 248eda7b20
5 changed files with 34 additions and 26 deletions


@@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generations, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )
@@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch

     for i in range(4):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -228,11 +228,11 @@ def test_batch_concatenate(
     default_multi_requests_seq2seq_lm_batch,
 ):
     next_batch_0 = default_seq2seq_lm_batch
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_seq2seq_lm_batch
-    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)

     # Copy hidden state because it is removed from the concatenated branches
     next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@@ -324,10 +324,10 @@ def test_batch_concatenate(
     )

     for _ in range(3):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3
@@ -342,7 +342,7 @@ def test_batch_concatenate(
         [next_batch.requests[0].id, next_batch.requests[1].id]
     )

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -352,7 +352,7 @@ def test_batch_concatenate(
     next_batch = next_batch.filter([next_batch.requests[1].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
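The only change in these tests is that generate_token now returns a third element, the (forward_ns, decode_ns) pair of nanosecond timings, which the assertions above simply discard with _. A minimal sketch of unpacking it instead, reusing the fixture names from the tests above (the millisecond formatting is illustrative only, not part of the change):

    generations, next_batch, timings = default_seq2seq_lm.generate_token(
        default_seq2seq_lm_batch
    )
    forward_ns, decode_ns = timings  # both measured with time.time_ns()
    print(f"forward: {forward_ns / 1e6:.2f} ms, decode: {decode_ns / 1e6:.2f} ms")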


@@ -586,7 +586,7 @@ class CausalLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()

         # Zipped iterator
         iterator = zip(
@@ -734,7 +734,8 @@ class CausalLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
@@ -751,5 +752,6 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past

-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
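The CausalLM hunks above show the fix itself; the same change repeats in FlashCausalLM, IdeficsCausalLM, and Seq2SeqLM below. Before this commit, decode_ns was measured from the same start timestamp as forward_ns, so the reported decode time also included the forward pass. Stamping start_decode right after the forward pass splits the two phases: forward_ns = start_decode - start and decode_ns = time.time_ns() - start_decode. A standalone sketch of the pattern, with placeholder functions standing in for the real forward and decode steps (not the actual model code):

    import time

    def fake_forward(batch):
        # Placeholder standing in for the model forward pass.
        return [float(x) for x in batch]

    def fake_decode(logits):
        # Placeholder standing in for the per-token decode loop.
        return [x > 0 for x in logits]

    start = time.time_ns()
    logits = fake_forward([1, 2, 3])
    start_decode = time.time_ns()  # boundary stamped right after the forward pass
    generations = fake_decode(logits)
    forward_ns = start_decode - start  # forward-only duration
    decode_ns = time.time_ns() - start_decode  # no longer includes the forward pass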


@@ -943,7 +943,7 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()

         # Zipped iterator
         iterator = zip(
@@ -1105,12 +1105,14 @@ class FlashCausalLM(Model):
         if stopped:
             del batch
             # No need to return a batch if we know that all requests stopped
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None

-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)


@@ -694,7 +694,7 @@ class IdeficsCausalLM(Model):
         # Hardcoded remove image tokens
         logits[:, 32000:32001] = torch.finfo(logits.dtype).min

-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()

         # Results
         generations: List[Generation] = []
@@ -824,7 +824,8 @@ class IdeficsCausalLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
@@ -845,5 +846,6 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states

-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)


@@ -646,7 +646,7 @@ class Seq2SeqLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()

         # Finished requests
         generations: List[Generation] = []
@@ -792,7 +792,8 @@ class Seq2SeqLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

         # We don't need input_ids after the prefill forward
@@ -804,5 +805,6 @@ class Seq2SeqLM(Model):
         batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
         batch.padding_right_offset -= 1

-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)