diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index 0585f1fb..d3f2d766 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -50,19 +50,39 @@ def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
+    accepted_ids = torch.ones_like(top_n_tokens_tensor)

     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
-        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )

-    assert topn_tok_ids[0] == []
-    assert topn_tok_ids[1] == [0, 3]
-    assert topn_tok_ids[2] == [0, 3, 1, 4]
-    assert topn_tok_ids[3] == [0, 3, 1, 4]
-    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]

-    assert topn_tok_logprobs[0] == []
-    assert topn_tok_logprobs[1] == [-1, -2]
-    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
+
+    # Now let's make the second member of the batch speculated
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
+    accepted_ids[1] = 2
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
+    )
+
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3], [0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
+
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 7b10256c..a0067992 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -580,10 +580,13 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True

+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )

         start_decode = time.time_ns()
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b826a46b..53a3d582 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -858,9 +858,8 @@ class FlashCausalLM(Model):
             speculative_logits,
         )

-        speculated_length = batch.speculative_ids.shape[-1] if batch.speculative_ids is not None else 0
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids, speculated_length
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )

         if prefill:
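# Illustration, not part of the patch: a minimal sketch of the new calling
# convention exercised by the updated test. Each request now gets a list of
# lists back from batch_top_tokens, one inner list per accepted token, so a
# request whose speculation accepted two tokens yields two top-token lists.
# The import path mirrors the test module and is assumed here.
import torch

from text_generation_server.utils.tokens import batch_top_tokens

top_n_tokens = [0, 2]                      # request 0 asks for no top tokens, request 1 for two
top_n_tokens_tensor = torch.tensor(top_n_tokens)
# One logprob row per (request, speculated position): 2 requests x 2 positions.
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 4)
accepted_ids = torch.tensor([1, 2])        # request 0 kept 1 token, request 1 kept 2

topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
    top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]                       # top_n == 0 -> empty inner list
assert topn_tok_ids[1] == [[0, 3], [0, 3]]           # two accepted tokens -> two top-2 lists
assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]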
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f2e4cec6..92dddaaf 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )

+        # Speculation is not active for seq2seq
+        accepted_ids = torch.ones_like(batch.input_ids)
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )

         start_decode = time.time_ns()
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 7f5555bb..8761ef3e 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -436,8 +436,8 @@ class HeterogeneousSampling:


 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor, speculative_length: int
-) -> Tuple[List[List[int]], List[List[float]]]:
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.

     When multiple tokens have equal probabilities and they don't all fit, the
@@ -446,13 +446,14 @@ def batch_top_tokens(
     max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
     if max_top_n == 0:
-        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)

-    n = speculative_length + 1
-    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(n)
+    batch_size = accepted_ids.shape[0]
+    speculate_size = logprobs.shape[0] // batch_size
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculative_length + 1)]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]

     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
@@ -477,12 +478,14 @@

     batch_top_token_ids = []
     batch_top_token_logprobs = []
-    accepted_ids = accepted_ids.tolist()
-    for i, n_accepted_ids in enumerate(accepted_ids):
-        _top_indices = top_indices[n * i: n * (i + 1)]
-        _top_values = top_values[n * i: n * (i + 1)]
-        _top_n_ishes = top_n_ishes[n * i: n * (i + 1)]
-        _top_n_tokens = top_n_tokens[n * i: n * (i + 1)]
+    accepted_ids_list = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids_list):
+        start = speculate_size * i
+        stop = speculate_size * (i + 1)
+        _top_indices = top_indices[start: stop]
+        _top_values = top_values[start: stop]
+        _top_n_ishes = top_n_ishes[start: stop]
+        _top_n_tokens = top_n_tokens[start: stop]
         _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
         _top_n_ishes = _top_n_ishes[:n_accepted_ids]
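# Sketch, not part of the patch: why `speculative_length` can be dropped from the
# signature. The speculation width is recoverable from the shapes alone: `logprobs`
# has one row per (request, speculated position) while `accepted_ids` has one entry
# per request, so their ratio is the width the old argument used to carry. The loop
# then keeps only the rows for tokens that were actually accepted. The sizes below
# (3 requests, width 2, vocab 32000) are made up for illustration.
import torch

logprobs = torch.randn(6, 32000)                   # 3 requests x 2 speculated positions
accepted_ids = torch.tensor([1, 2, 1])             # tokens actually accepted per request

batch_size = accepted_ids.shape[0]                 # 3
speculate_size = logprobs.shape[0] // batch_size   # 6 // 3 == 2, replaces speculative_length + 1

for i, n_accepted_ids in enumerate(accepted_ids.tolist()):
    start, stop = speculate_size * i, speculate_size * (i + 1)
    rows = logprobs[start:stop][:n_accepted_ids]   # this request's rows, accepted positions only
    print(i, tuple(rows.shape))                    # (1, 32000), (2, 32000), (1, 32000)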