mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix decode timing
This commit is contained in:
parent
701dd7da67
commit
248eda7b20
@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
|
||||
|
||||
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
|
||||
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(
|
||||
default_seq2seq_lm_batch
|
||||
)
|
||||
|
||||
@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
|
||||
):
|
||||
next_batch = default_seq2seq_lm_batch
|
||||
for _ in range(6):
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
|
||||
next_batch = default_multi_requests_seq2seq_lm_batch
|
||||
|
||||
for i in range(4):
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 2
|
||||
@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id])
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
@ -228,11 +228,11 @@ def test_batch_concatenate(
|
||||
default_multi_requests_seq2seq_lm_batch,
|
||||
):
|
||||
next_batch_0 = default_seq2seq_lm_batch
|
||||
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
|
||||
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
|
||||
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
|
||||
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
|
||||
|
||||
next_batch_1 = default_multi_requests_seq2seq_lm_batch
|
||||
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
|
||||
_, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
|
||||
|
||||
# Copy hidden state because it is removed from the concatenated branches
|
||||
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
|
||||
@ -324,10 +324,10 @@ def test_batch_concatenate(
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 3
|
||||
@ -342,7 +342,7 @@ def test_batch_concatenate(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
)
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 2
|
||||
@ -352,7 +352,7 @@ def test_batch_concatenate(
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[1].id])
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
|
@ -586,7 +586,7 @@ class CausalLM(Model):
|
||||
torch.log_softmax(logits[:, -1], -1),
|
||||
)
|
||||
|
||||
forward_ns = time.time_ns() - start
|
||||
start_decode = time.time_ns()
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
@ -734,7 +734,8 @@ class CausalLM(Model):
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
# Slice unused values from prefill
|
||||
@ -751,5 +752,6 @@ class CausalLM(Model):
|
||||
# Update past key values
|
||||
batch.past_key_values = past
|
||||
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
||||
|
@ -943,7 +943,7 @@ class FlashCausalLM(Model):
|
||||
# GPU <-> CPU sync
|
||||
next_token_logprobs = next_token_logprobs.tolist()
|
||||
next_token_ids = next_input_ids.tolist()
|
||||
forward_ns = time.time_ns() - start
|
||||
start_decode = time.time_ns()
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
@ -1105,12 +1105,14 @@ class FlashCausalLM(Model):
|
||||
if stopped:
|
||||
del batch
|
||||
# No need to return a batch if we know that all requests stopped
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
batch.prefill_cu_outlens = None
|
||||
batch.prefill_head_indices = None
|
||||
batch.prefill_next_token_indices = None
|
||||
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
||||
|
@ -694,7 +694,7 @@ class IdeficsCausalLM(Model):
|
||||
# Hardcoded remove image tokens
|
||||
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
|
||||
|
||||
forward_ns = time.time_ns() - start
|
||||
start_decode = time.time_ns()
|
||||
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
@ -824,7 +824,8 @@ class IdeficsCausalLM(Model):
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
# Slice unused values from prefill
|
||||
@ -845,5 +846,6 @@ class IdeficsCausalLM(Model):
|
||||
batch.past_key_values = past
|
||||
batch.image_hidden_states = image_hidden_states
|
||||
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
||||
|
@ -646,7 +646,7 @@ class Seq2SeqLM(Model):
|
||||
torch.log_softmax(logits[:, -1], -1),
|
||||
)
|
||||
|
||||
forward_ns = time.time_ns() - start
|
||||
start_decode = time.time_ns()
|
||||
|
||||
# Finished requests
|
||||
generations: List[Generation] = []
|
||||
@ -792,7 +792,8 @@ class Seq2SeqLM(Model):
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
# We don't need input_ids after the prefill forward
|
||||
@ -804,5 +805,6 @@ class Seq2SeqLM(Model):
|
||||
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
|
||||
batch.padding_right_offset -= 1
|
||||
|
||||
decode_ns = time.time_ns() - start
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
||||
|
Loading…
Reference in New Issue
Block a user