diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0978fc81..57f95603 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -949,6 +949,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
+            batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
             batch_top_token_ids,
@@ -965,26 +966,31 @@ class FlashCausalLM(Model):
             all_input_ids,
             do_sample,
             seed,
+            top_n_tokens,
             next_token_id,
             next_token_logprob,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = []
-            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
-                tok_itm = token_id
-                top_tokens.append(
-                    TopToken(
-                        token_id=token_id,
-                        token_logprob=token_logprob,
-                        token_text=self.decode_token(
-                            all_input_ids=all_input_ids + [tok_itm],
-                            prefix_offset=prefix_offset,
-                            read_offset=read_offset,
-                        )[0],
-                        token_is_special=tok_itm in self.all_special_ids,
-                    )
+
+            if top_n_tokens > 0:
+                top_token_texts = self.decode_tokens(
+                    input_ids=all_input_ids,
+                    new_input_ids=top_token_ids,
+                    prefix_offset=prefix_offset,
+                    read_offset=read_offset,
                 )
+                for token_id, (top_token_text, _, _), token_logprob in zip(top_token_ids, top_token_texts, top_token_logprobs):
+                    tok_itm = token_id
+                    top_tokens.append(
+                        TopToken(
+                            token_id=token_id,
+                            token_logprob=token_logprob,
+                            token_text=top_token_text,
+                            token_is_special=tok_itm in self.all_special_ids,
+                        )
+                    )
 
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 9d74247c..941dc5e1 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -86,6 +86,37 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset
 
+    def decode_tokens(
+        self,
+        input_ids: List[int],
+        new_input_ids: List[int],
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> List[Tuple[str, int, int]]:
+        """Version of decode_token that supports multiple new tokens for the same prefix."""
+
+        # The prefix text is necessary only to defeat cleanup algorithms in the decode
+        # which decide to add a space or not depending on the surrounding ids.
+        prefix_text = self.tokenizer.decode(
+            input_ids[prefix_offset:read_offset], skip_special_tokens=False
+        )
+
+        new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
+        new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
+
+        results = []
+        for new_text in new_texts:
+            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+                # utf-8 char at the end means it's a potential unfinished byte sequence
+                # from byte fallback tokenization.
+                # If it's in the middle, it's probably a real invalid id generated
+                # by the model
+                new_text = new_text[len(prefix_text) :]
+                results.append((new_text, read_offset, len(input_ids) + 1))
+            else:
+                results.append(("", prefix_offset, read_offset))
+        return results
+
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
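Not part of the patch: the snippet below is a self-contained sketch of the batched-decoding idea behind decode_tokens, written against a plain Hugging Face tokenizer instead of a Model instance. The name decode_candidates and the choice of the gpt2 tokenizer are illustrative assumptions, not TGI code; the point is that one batch_decode call covers every top-n candidate for a position, replacing the per-candidate decode_token calls removed above.

# Illustration only: batched decoding of several candidate continuations that
# share one prefix, mirroring the decode_tokens helper added in model.py.
from typing import List, Tuple

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any HF tokenizer works


def decode_candidates(
    input_ids: List[int],
    candidate_ids: List[int],
    prefix_offset: int = 0,
    read_offset: int = 0,
) -> List[Tuple[str, int, int]]:
    # Decode the shared prefix once; it is only needed to neutralize the
    # tokenizer's space-cleanup heuristics when slicing the new text off.
    prefix_text = tokenizer.decode(
        input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    # One batch_decode call handles every candidate instead of one decode per token.
    sequences = [input_ids[prefix_offset:] + [cid] for cid in candidate_ids]
    texts = tokenizer.batch_decode(sequences, skip_special_tokens=False)

    results = []
    for text in texts:
        if len(text) > len(prefix_text) and not text.endswith("�"):
            results.append((text[len(prefix_text):], read_offset, len(input_ids) + 1))
        else:
            # Unfinished byte sequence: emit nothing and keep the old offsets.
            results.append(("", prefix_offset, read_offset))
    return results


prompt_ids = tokenizer("Hello world").input_ids
candidates = [tokenizer(" again").input_ids[0], tokenizer("!").input_ids[0]]
print(decode_candidates(prompt_ids, candidates, read_offset=len(prompt_ids)))
# e.g. [(' again', 2, 3), ('!', 2, 3)]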
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 0682959b..499ff054 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -345,7 +345,7 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
     """Find the top n most likely tokens for a batch of generations."""
     top_n_tokens = torch.tensor(top_n_tokens)
     if top_n_tokens.min() == 0:
-        return [], []
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
 
     # Ensure top_n doesn't exceed vocab size
     top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
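One note on the batch_top_tokens change, since it is easy to miss why the shape matters: the per-request loop in flash_causal_lm.py zips batch_top_token_ids and batch_top_token_logprobs together with the other per-request fields, and a bare pair of empty lists would end that zip after zero iterations. The snippet below is a hypothetical standalone illustration of the two return shapes, not TGI code:

# Toy illustration of the early-return shapes.
requests = ["req-0", "req-1", "req-2"]   # stand-ins for the per-request fields
top_n_tokens = [0, 0, 0]                 # no request asked for top-n tokens

# Old early return: empty lists shrink the zip to zero iterations.
old_ids, old_logprobs = [], []
assert list(zip(requests, old_ids, old_logprobs)) == []

# New early return: one empty list per request keeps the batch aligned.
new_ids = [[]] * len(top_n_tokens)
new_logprobs = [[]] * len(top_n_tokens)
assert list(zip(requests, new_ids, new_logprobs)) == [
    ("req-0", [], []),
    ("req-1", [], []),
    ("req-2", [], []),
]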