mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix decode timing

commit 248eda7b20 (parent 701dd7da67)
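Summary of the change: in each model's `generate_token` (the diff touches the seq2seq LM tests plus `CausalLM`, `FlashCausalLM`, `IdeficsCausalLM`, and `Seq2SeqLM`), the old code computed `decode_ns = time.time_ns() - start` at return time, so the reported decode time also included the forward pass. The fix records a `start_decode` timestamp right after the forward pass and derives `forward_ns = start_decode - start` and `decode_ns = time.time_ns() - start_decode`, returning the pair as a third element of `generate_token`. Below is a minimal sketch of that pattern, assuming hypothetical `run_forward` / `run_decode` stand-ins; it is not the repository's actual `generate_token` body.

import time

def generate_token_timing_sketch(run_forward, run_decode, batch):
    # Minimal sketch of the timing split introduced by this commit.
    # run_forward / run_decode are hypothetical stand-ins for the model
    # forward pass and the per-request decode bookkeeping.
    start = time.time_ns()

    logits = run_forward(batch)  # forward pass

    start_decode = time.time_ns()  # mark where decode-side work begins
    generations, next_batch = run_decode(batch, logits)

    forward_ns = start_decode - start          # forward time only
    decode_ns = time.time_ns() - start_decode  # decode time only
    return generations, next_batch, (forward_ns, decode_ns)

# Example invocation with trivial stand-ins:
gens, nb, (fwd_ns, dec_ns) = generate_token_timing_sketch(
    run_forward=lambda b: None,
    run_decode=lambda b, logits: (["tok"], None),
    batch=object(),
)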
@@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generations, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )
 
@@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
     assert len(generations) == 1
@@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch
 
     for i in range(4):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
     assert len(generations) == 2
@@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
 
     next_batch = next_batch.filter([next_batch.requests[0].id])
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
     assert len(generations) == 1
@@ -228,11 +228,11 @@ def test_batch_concatenate(
     default_multi_requests_seq2seq_lm_batch,
 ):
     next_batch_0 = default_seq2seq_lm_batch
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
 
     next_batch_1 = default_multi_requests_seq2seq_lm_batch
-    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
 
     # Copy hidden state because it is removed from the concatenated branches
     next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@@ -324,10 +324,10 @@ def test_batch_concatenate(
     )
 
     for _ in range(3):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
     assert len(generations) == 3
@@ -342,7 +342,7 @@ def test_batch_concatenate(
         [next_batch.requests[0].id, next_batch.requests[1].id]
     )
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
 
     assert len(generations) == 2
@@ -352,7 +352,7 @@ def test_batch_concatenate(
 
     next_batch = next_batch.filter([next_batch.requests[1].id])
 
-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
 
     assert len(generations) == 1
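In the tests above, the only change is on the caller side: `generate_token` now returns three values, and the new third element (the `(forward_ns, decode_ns)` pair) is discarded with `_`. The following small, self-contained sketch shows how a caller might consume the new return shape; `_StubModel` and the accumulation loop are illustrative assumptions, not code from the repository.

import time
from typing import Any, List, Optional, Tuple

class _StubModel:
    """Hypothetical stand-in for a Model exposing the new generate_token contract."""

    def __init__(self, steps: int):
        self._remaining = steps

    def generate_token(
        self, batch: Any
    ) -> Tuple[List[str], Optional[Any], Tuple[int, int]]:
        start = time.time_ns()
        # ... the forward pass would run here ...
        start_decode = time.time_ns()
        # ... decode bookkeeping would run here ...
        self._remaining -= 1
        next_batch = batch if self._remaining > 0 else None
        timings = (start_decode - start, time.time_ns() - start_decode)
        return ["token"], next_batch, timings

# Caller loop mirroring the updated tests: unpack three values, keep the timings.
model, batch = _StubModel(steps=3), object()
total_forward_ns = total_decode_ns = 0
while batch is not None:
    generations, batch, (forward_ns, decode_ns) = model.generate_token(batch)
    total_forward_ns += forward_ns
    total_decode_ns += decode_ns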
@@ -586,7 +586,7 @@ class CausalLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )
 
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
         # Zipped iterator
         iterator = zip(
@@ -734,7 +734,8 @@ class CausalLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         # Slice unused values from prefill
@@ -751,5 +752,6 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
@@ -943,7 +943,7 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
         # Zipped iterator
         iterator = zip(
@@ -1105,12 +1105,14 @@ class FlashCausalLM(Model):
         if stopped:
             del batch
             # No need to return a batch if we know that all requests stopped
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
@@ -694,7 +694,7 @@ class IdeficsCausalLM(Model):
         # Hardcoded remove image tokens
         logits[:, 32000:32001] = torch.finfo(logits.dtype).min
 
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
         # Results
         generations: List[Generation] = []
@@ -824,7 +824,8 @@ class IdeficsCausalLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         # Slice unused values from prefill
@@ -845,5 +846,6 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
@@ -646,7 +646,7 @@ class Seq2SeqLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )
 
-        forward_ns = time.time_ns() - start
+        start_decode = time.time_ns()
 
         # Finished requests
         generations: List[Generation] = []
@@ -792,7 +792,8 @@ class Seq2SeqLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            decode_ns = time.time_ns() - start
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
         # We don't need input_ids after the prefill forward
@@ -804,5 +805,6 @@ class Seq2SeqLM(Model):
             batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
         batch.padding_right_offset -= 1
 
-        decode_ns = time.time_ns() - start
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)