Implement top-n-tokens for all models

Vincent Brouwers 2023-07-26 15:12:57 +00:00
parent 494e6b1c61
commit 50d05fa20d
5 changed files with 110 additions and 82 deletions

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import inspect
@ -42,6 +43,7 @@ class CausalLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
# Metadata used for padding
max_input_length: int
@ -72,6 +74,7 @@ class CausalLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
@ -88,6 +91,7 @@ class CausalLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -138,6 +142,7 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
@ -163,6 +168,7 @@ class CausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
@ -184,6 +190,7 @@ class CausalLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -235,6 +242,7 @@ class CausalLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
@ -262,6 +270,7 @@ class CausalLMBatch(Batch):
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@ -281,6 +290,7 @@ class CausalLMBatch(Batch):
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@ -438,6 +448,7 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
@ -549,6 +560,10 @@ class CausalLM(Model):
generations: List[Generation] = []
stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
)
# Zipped iterator
iterator = zip(
batch.requests,
@ -559,6 +574,9 @@ class CausalLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -571,7 +589,19 @@ class CausalLM(Model):
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
top_tokens = self.decode_top_tokens(
input_ids=all_input_ids.view(1, -1).tolist(),
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
prefix_offset=prefix_offset,
read_offset=read_offset,
)
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
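
The batch_top_tokens call added above runs once per batch instead of once per request: it takes the per-request top_n_tokens counts plus the softmax of the last-position logits, and returns per-request lists that line up with the zipped iterator. A small illustration of the expected shapes, using toy sizes that are not from the diff:

import torch

# Toy sizes for illustration only: 2 requests, 5 positions, vocabulary of 10.
batch_size, seq_len, vocab_size = 2, 5, 10
logits = torch.randn(batch_size, seq_len, vocab_size)

# logits[:, -1] selects the scores of the last position for every request;
# softmax turns them into probabilities of shape [batch_size, vocab_size].
probs = torch.softmax(logits[:, -1], -1)

# One requested top-n per request; 0 means "no top tokens for this request".
top_n_tokens = [3, 0]

# batch_top_tokens(top_n_tokens, probs) is expected to return two lists of
# length batch_size (top token ids and their scores), each entry truncated to
# that request's own top-n, so they can be zipped with the rest of the batch
# state in the per-request loop above.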

View File

@ -1,6 +1,6 @@
import math
import itertools
from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
@ -972,25 +972,14 @@ class FlashCausalLM(Model):
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
top_tokens = []
if top_n_tokens > 0:
top_token_texts = self.decode_tokens(
input_ids=all_input_ids,
new_input_ids=top_token_ids,
prefix_offset=prefix_offset,
read_offset=read_offset,
)
for token_id, (top_token_text, _, _), token_logprob in zip(top_token_ids, top_token_texts, top_token_logprobs):
tok_itm = token_id
top_tokens.append(
TopToken(
token_id=token_id,
token_logprob=token_logprob,
token_text=top_token_text,
token_is_special=tok_itm in self.all_special_ids,
)
)
top_tokens = self.decode_top_tokens(
input_ids=all_input_ids,
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
prefix_offset=prefix_offset,
read_offset=read_offset,
)
# Append next token to all tokens
all_input_ids.append(next_token_id)

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.models.types import Batch, GeneratedText, TopToken
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
@ -101,8 +101,12 @@ class Model(ABC):
input_ids[prefix_offset:read_offset], skip_special_tokens=False
)
new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
new_sequences = [
input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
]
new_texts = self.tokenizer.batch_decode(
new_sequences, skip_special_tokens=False
)
results = []
for new_text in new_texts:
@ -117,6 +121,40 @@ class Model(ABC):
results.append(("", prefix_offset, read_offset))
return results
def decode_top_tokens(
self,
input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
prefix_offset,
read_offset,
):
if top_n_tokens == 0:
return []
top_token_texts = self.decode_tokens(
input_ids=input_ids,
new_input_ids=top_token_ids,
prefix_offset=prefix_offset,
read_offset=read_offset,
)
top_tokens = []
for token_id, (top_token_text, _, _), token_logprob in zip(
top_token_ids, top_token_texts, top_token_logprobs
):
tok_itm = token_id
top_tokens.append(
TopToken(
token_id=token_id,
token_logprob=token_logprob,
token_text=top_token_text,
token_is_special=tok_itm in self.all_special_ids,
)
)
return top_tokens
def check_initialized(self):
uninitialized_parameters = []
for n, p in self.model.named_parameters():
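
decode_top_tokens builds TopToken objects, but the TopToken definition itself is not part of this diff (it is imported from text_generation_server.models.types). Based purely on the fields used above, a plausible sketch of that dataclass could look like the following; the actual definition may differ:

from dataclasses import dataclass

@dataclass
class TopToken:
    token_id: int
    token_logprob: float
    token_text: str
    token_is_special: bool

    # The removed get_top_tokens helper (last file below) sorts TopToken
    # instances, which suggests the real class defines an ordering, presumably
    # by logprob; this comparison method is an assumption.
    def __gt__(self, other: "TopToken") -> bool:
        return self.token_logprob > other.token_logprob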

View File

@ -1,4 +1,4 @@
from text_generation_server.utils.tokens import get_top_tokens
from text_generation_server.utils.tokens import batch_top_tokens
import torch
from dataclasses import dataclass
@ -49,6 +49,7 @@ class Seq2SeqLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
# Metadata used for padding
max_input_length: int
@ -79,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
@ -98,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -147,6 +149,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
@ -174,6 +177,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_input_length = 0
max_decoder_input_length = 0
@ -205,6 +209,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -255,6 +260,7 @@ class Seq2SeqLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
@ -290,6 +296,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@ -313,6 +320,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@ -489,6 +497,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@ -614,6 +623,10 @@ class Seq2SeqLM(Model):
batch.past_key_values,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
)
# Finished requests
generations: List[Generation] = []
stopped = True
@ -629,6 +642,9 @@ class Seq2SeqLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_decoder_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -642,22 +658,24 @@ class Seq2SeqLM(Model):
next_token_chooser,
stopping_criteria,
all_decoder_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
top_tokens = self.decode_top_tokens(
input_ids=all_decoder_input_ids.view(1, -1).tolist(),
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
prefix_offset=prefix_offset,
read_offset=read_offset,
)
# Select next token
next_token_id, logprobs = next_token_chooser(
all_decoder_input_ids.view(1, -1), logits[-1:, :]
)
top_tokens = get_top_tokens(
request.top_n_tokens,
logprobs,
self.all_special_ids,
self.decode_token,
all_decoder_input_ids,
prefix_offset,
read_offset,
)
# Append next token to decoder tokens
all_decoder_input_ids = torch.cat(
[all_decoder_input_ids, next_token_id.squeeze(1)]

View File

@ -8,7 +8,6 @@ from transformers import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.types import TopToken
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from text_generation_server.utils.logits_process import (
@ -361,49 +360,3 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
[idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
[vals[:n] for vals, n in zip(top_values, top_n_tokens)],
)
def get_top_tokens(
requested_n: int,
logprobs,
special_tokens: List[int],
decode_fn: Callable[[List[int], int, int], str],
decoder_input_ids: List[int],
prefix_offset: int,
read_offset: int,
) -> List[TopToken]:
if not requested_n:
return []
# Dirty hack
flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]
# Ensure top_n doesn't exceed vocab size
top_n = min(requested_n, flat_scores.size(-1))
# Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
nth_highest = torch.topk(flat_scores, top_n)[0][-1]
if nth_highest == -float("inf"):
nth_highest = torch.finfo(flat_scores.dtype).min
# Get indices (token ids) of all scores >= nth highest value,
# cap length at 4 * top_n as a precaution
top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
top_tokens = []
for tid_tensor in top_n_indices:
tid_item = tid_tensor[0].item()
token_text, _, _ = decode_fn(
torch.cat([decoder_input_ids, tid_tensor])
if isinstance(decoder_input_ids, torch.Tensor)
else decoder_input_ids + [tid_item],
prefix_offset,
read_offset,
)
top_tokens.append(
TopToken(
token_id=tid_item,
token_logprob=flat_scores[tid_tensor],
token_text=token_text,
token_is_special=tid_item in special_tokens,
)
)
top_tokens.sort(reverse=True)
return top_tokens
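
Only the tail of batch_top_tokens is visible in this diff (its header annotates top_n_tokens as a torch.Tensor, while the callers above pass a plain list). Below is a minimal sketch consistent with that tail, assuming a single batched torch.topk over the scores and that every requested top-n is at most the vocabulary size; the real implementation likely handles more edge cases:

from typing import List, Tuple

import torch


def batch_top_tokens_sketch(
    top_n_tokens: List[int], logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
    # Early exit when no request in the batch asked for top tokens.
    max_top_n = max(top_n_tokens)
    if max_top_n == 0:
        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

    # One topk over the whole [batch_size, vocab_size] tensor covers every
    # request; each row is then truncated to that request's own top-n below.
    top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True)
    top_indices = top_k.indices.tolist()
    top_values = top_k.values.tolist()

    return (
        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
    )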