diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index a0067992..29e9f8b1 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -581,7 +581,7 @@ class CausalLM(Model):
         stopped = True
 
         # Speculation is not active for causal
-        accepted_ids = torch.ones_like(batch.input_ids)
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
@@ -695,20 +695,24 @@ class CausalLM(Model):
                     prefill_tokens = None
 
                 if top_n_tokens > 0:
-                    toptoken_texts = self.tokenizer.batch_decode(
-                        top_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
-                    ]
-                    top_tokens = Tokens(
-                        top_token_ids,
-                        top_token_logprobs,
-                        toptoken_texts,
-                        special_toptokens,
-                    )
+                    all_top_tokens = []
+                    for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                        toptoken_texts = self.tokenizer.batch_decode(
+                            top_token_ids,
+                            clean_up_tokenization_spaces=False,
+                            skip_special_tokens=False,
+                        )
+                        special_toptokens = [
+                            token_id in self.all_special_ids for token_id in top_token_ids
+                        ]
+                        top_tokens = Tokens(
+                            top_token_ids,
+                            top_token_logprobs,
+                            toptoken_texts,
+                            special_toptokens,
+                        )
+                        all_top_tokens.append(top_tokens)
+                    top_tokens = all_top_tokens
                 else:
                     top_tokens = None
 
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 143d0a3d..8b93aecd 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -641,7 +641,7 @@ class Seq2SeqLM(Model):
         )
 
         # Speculation is not active for seq2seq
-        accepted_ids = torch.ones_like(batch.decoder_input_ids)
+        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
@@ -749,20 +749,24 @@ class Seq2SeqLM(Model):
                     prefill_tokens = None
 
                 if top_n_tokens > 0:
-                    toptoken_texts = self.tokenizer.batch_decode(
-                        top_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
-                    ]
-                    top_tokens = Tokens(
-                        top_token_ids,
-                        top_token_logprobs,
-                        toptoken_texts,
-                        special_toptokens,
-                    )
+                    all_top_tokens = []
+                    for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                        toptoken_texts = self.tokenizer.batch_decode(
+                            top_token_ids,
+                            clean_up_tokenization_spaces=False,
+                            skip_special_tokens=False,
+                        )
+                        special_toptokens = [
+                            token_id in self.all_special_ids for token_id in top_token_ids
+                        ]
+                        top_tokens = Tokens(
+                            top_token_ids,
+                            top_token_logprobs,
+                            toptoken_texts,
+                            special_toptokens,
+                        )
+                        all_top_tokens.append(top_tokens)
+                    top_tokens = all_top_tokens
                 else:
                     top_tokens = None
 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 8761ef3e..270a6990 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -306,13 +306,15 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
+            logprobs = alllogprobs[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+            logprobs = alllogprobs
 
-        next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
         if speculate > 0:
@@ -436,7 +438,7 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
 ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.
 
@@ -486,6 +488,7 @@
         _top_values = top_values[start: stop]
         _top_n_ishes = top_n_ishes[start: stop]
         _top_n_tokens = top_n_tokens[start: stop]
+        _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
         _top_n_ishes = _top_n_ishes[:n_accepted_ids]