diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 373867c7..735ab5eb 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generations, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )
 
@@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
     assert len(generations) == 1
@@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch
 
     for i in range(4):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
     assert len(generations) == 2
@@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
 
     next_batch = next_batch.filter([next_batch.requests[0].id])
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
     assert len(generations) == 1
@@ -228,11 +228,11 @@ def test_batch_concatenate(
     default_multi_requests_seq2seq_lm_batch,
 ):
     next_batch_0 = default_seq2seq_lm_batch
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
 
     next_batch_1 = default_multi_requests_seq2seq_lm_batch
-    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
 
     # Copy hidden state because it is removed from the concatenated branches
     next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@@ -324,10 +324,10 @@ def test_batch_concatenate(
     )
 
     for _ in range(3):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
     assert len(generations) == 3
@@ -342,7 +342,7 @@ def test_batch_concatenate(
         [next_batch.requests[0].id, next_batch.requests[1].id]
     )
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
     assert len(generations) == 2
@@ -352,7 +352,7 @@ def test_batch_concatenate(
 
     next_batch = next_batch.filter([next_batch.requests[1].id])
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
     assert len(generations) == 1
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 86f4461d..7b10256c 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -586,7 +586,7 @@ class CausalLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )
 
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
         # Zipped iterator
         iterator = zip(
@@ -734,7 +734,8 @@ class CausalLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         # Slice unused values from prefill
@@ -751,5 +752,6 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 642a1127..cc0c8a32 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -943,7 +943,7 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
         # Zipped iterator
         iterator = zip(
@@ -1105,12 +1105,14 @@ class FlashCausalLM(Model):
         if stopped:
             del batch
             # No need to return a batch if we know that all requests stopped
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 0e9afeaf..2f28688d 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -694,7 +694,7 @@ class IdeficsCausalLM(Model):
         # Hardcoded remove image tokens
         logits[:, 32000:32001] = torch.finfo(logits.dtype).min
 
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
         # Results
         generations: List[Generation] = []
@@ -824,7 +824,8 @@ class IdeficsCausalLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         # Slice unused values from prefill
@@ -845,5 +846,6 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 818a3eb1..f2e4cec6 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -646,7 +646,7 @@ class Seq2SeqLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )
 
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
        # Finished requests
         generations: List[Generation] = []
@@ -792,7 +792,8 @@ class Seq2SeqLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         # We don't need input_ids after the prefill forward
@@ -804,5 +805,6 @@ class Seq2SeqLM(Model):
         batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
         batch.padding_right_offset -= 1
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
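With this diff, generate_token returns a three-element tuple: the generations, the next batch (or None when every request has stopped), and a (forward_ns, decode_ns) timing pair measured with time.time_ns(). A minimal caller-side sketch follows; the names model and batch are placeholders, not part of this diff, and it only illustrates unpacking the new return value.

# Hypothetical caller-side sketch: consume the (forward_ns, decode_ns) timings
# that generate_token now returns alongside generations and the next batch.
generations, next_batch, timings = model.generate_token(batch)
forward_ns, decode_ns = timings
# Convert nanoseconds to milliseconds for logging or metrics.
print(f"forward: {forward_ns / 1e6:.3f} ms, decode: {decode_ns / 1e6:.3f} ms")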