diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 930082cd..b826a46b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -842,6 +842,8 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
+
+        speculate = get_speculate()
 
         (
             next_input_ids,
             next_token_logprobs,
@@ -851,16 +853,16 @@ class FlashCausalLM(Model):
         ) = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen],
             next_token_logits,
-            get_speculate(),
+            speculate,
             batch.speculative_ids,
             speculative_logits,
         )
 
+        speculated_length = batch.speculative_ids.shape[-1] if batch.speculative_ids is not None else 0
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids, speculated_length
         )
 
-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -1062,20 +1064,24 @@ class FlashCausalLM(Model):
                     prefill_tokens = None
 
                 if top_n_tokens > 0:
-                    toptoken_texts = self.tokenizer.batch_decode(
-                        top_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
-                    ]
-                    top_tokens = Tokens(
-                        top_token_ids,
-                        top_token_logprobs,
-                        toptoken_texts,
-                        special_toptokens,
-                    )
+                    all_top_tokens = []
+                    for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                        toptoken_texts = self.tokenizer.batch_decode(
+                            top_token_ids,
+                            clean_up_tokenization_spaces=False,
+                            skip_special_tokens=False,
+                        )
+                        special_toptokens = [
+                            token_id in self.all_special_ids for token_id in top_token_ids
+                        ]
+                        top_tokens = Tokens(
+                            top_token_ids,
+                            top_token_logprobs,
+                            toptoken_texts,
+                            special_toptokens,
+                        )
+                        all_top_tokens.append(top_tokens)
+                    top_tokens = all_top_tokens
                 else:
                     top_tokens = None
 
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index f85f27e5..bc68812e 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -95,5 +95,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
         )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 04cc8d97..7f5555bb 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
-        scores = scores.view(B * S, -1)
+        allscores = scores.view(B * S, -1)
+        alllogprobs = torch.log_softmax(allscores, -1)
 
         if speculated_ids is not None:
             accepted_ids = []
@@ -305,15 +306,14 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
-            scores = scores[indices]
             indices = torch.arange(B, device=input_ids.device) * S
 
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
 
-        logprobs = torch.log_softmax(scores, -1)
-        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+        next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
+
         if speculate > 0:
             if speculative_scores is not None:
@@ -327,7 +327,7 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None
 
-        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
+        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
 
     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -436,7 +436,7 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor, speculative_length: int
 ) -> Tuple[List[List[int]], List[List[float]]]:
     """Find the top n most likely tokens for a batch of generations.
 
@@ -448,12 +448,16 @@
     if max_top_n == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+
+    n = speculative_length + 1
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(n)
 
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculative_length + 1)]
 
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
+
     nth_highest = torch.gather(
         sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
@@ -471,13 +475,30 @@
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()
 
-    return (
-        [
-            idxs[:n] if req_n > 0 else []
-            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
-        ],
-        [
-            vals[:n] if req_n > 0 else []
-            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
-        ],
-    )
+    batch_top_token_ids = []
+    batch_top_token_logprobs = []
+    accepted_ids = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids):
+        _top_indices = top_indices[n * i: n * (i + 1)]
+        _top_values = top_values[n * i: n * (i + 1)]
+        _top_n_ishes = top_n_ishes[n * i: n * (i + 1)]
+        _top_n_tokens = top_n_tokens[n * i: n * (i + 1)]
+        _top_indices = _top_indices[:n_accepted_ids]
+        _top_values = _top_values[:n_accepted_ids]
+        _top_n_ishes = _top_n_ishes[:n_accepted_ids]
+        _top_n_tokens = _top_n_tokens[:n_accepted_ids]
+
+        row_top_token_ids = []
+        row_top_token_logprobs = []
+
+        for idxs, vals, top_n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
+            indices = idxs[:top_n] if req_n > 0 else []
+            values = vals[:top_n] if req_n > 0 else []
+
+            row_top_token_ids.append(indices)
+            row_top_token_logprobs.append(values)
+
+        batch_top_token_ids.append(row_top_token_ids)
+        batch_top_token_logprobs.append(row_top_token_logprobs)
+
+    return batch_top_token_ids, batch_top_token_logprobs
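
For reviewers, a minimal standalone sketch (not the TGI implementation) of the row layout the patched batch_top_tokens relies on: with speculation, logprobs holds speculative_length + 1 rows per request, and only the first accepted_ids[i] rows of request i are kept, so each request now gets one list of top tokens per accepted position. The shapes and example values below (batch_size, vocab_size, top_n, the accepted counts) are illustrative assumptions, not taken from the test suite.

import torch

batch_size, speculative_length, vocab_size, top_n = 2, 2, 8, 3
n = speculative_length + 1  # rows per request in logprobs
logprobs = torch.log_softmax(torch.randn(batch_size * n, vocab_size), dim=-1)
accepted_ids = torch.tensor([3, 1])  # hypothetical accepted counts for this step

# One top-k per row, as in the patched helper.
top_k = torch.topk(logprobs, k=top_n, dim=-1, sorted=True)
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()

batch_top_ids, batch_top_logprobs = [], []
for i, n_accepted in enumerate(accepted_ids.tolist()):
    # Rows n * i .. n * i + n_accepted belong to request i; rejected positions are dropped.
    rows = slice(n * i, n * i + n_accepted)
    batch_top_ids.append(top_indices[rows])
    batch_top_logprobs.append(top_values[rows])

print([len(per_request) for per_request in batch_top_ids])  # -> [3, 1]

This per-request list of per-position results is also why Generation.to_pb now serializes top_tokens as a list of Tokens protobufs instead of a single one.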