Implement top-n-tokens for all models

Author: Vincent Brouwers
Date: 2023-07-26 15:12:57 +00:00
Committed by: Nicolas Patry
Parent: 38691f8a28
Commit: 8515999b1d

5 changed files with 110 additions and 82 deletions

View File

@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import inspect

@@ -42,6 +43,7 @@ class CausalLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]

     # Metadata used for padding
     max_input_length: int

@@ -72,6 +74,7 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         prefix_offsets = []
         read_offsets = []
         requests_idx_mapping = {}

@@ -88,6 +91,7 @@ class CausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(

@@ -138,6 +142,7 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,

@@ -163,6 +168,7 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []

         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0

@@ -184,6 +190,7 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )

@@ -235,6 +242,7 @@ class CausalLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens

@@ -262,6 +270,7 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0

         # Batch tensors

@@ -281,6 +290,7 @@ class CausalLMBatch(Batch):
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping

@@ -438,6 +448,7 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,

@@ -549,6 +560,10 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+        )
+
         # Zipped iterator
         iterator = zip(
             batch.requests,

@@ -559,6 +574,9 @@ class CausalLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch

@@ -571,7 +589,19 @@ class CausalLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_input_ids.view(1, -1).tolist(),
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )
+
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_input_ids.view(1, -1), logits[-1:, :]

View File

@@ -1,6 +1,6 @@
 import math
 import itertools
-from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import torch.distributed

@@ -972,25 +972,14 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = []
-
-            if top_n_tokens > 0:
-                top_token_texts = self.decode_tokens(
-                    input_ids=all_input_ids,
-                    new_input_ids=top_token_ids,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
-                )
-                for token_id, (top_token_text, _, _), token_logprob in zip(top_token_ids, top_token_texts, top_token_logprobs):
-                    tok_itm = token_id
-                    top_tokens.append(
-                        TopToken(
-                            token_id=token_id,
-                            token_logprob=token_logprob,
-                            token_text=top_token_text,
-                            token_is_special=tok_itm in self.all_special_ids,
-                        )
-                    )
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_input_ids,
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )

             # Append next token to all tokens
             all_input_ids.append(next_token_id)

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

-from text_generation_server.models.types import Batch, Generation
+from text_generation_server.models.types import Batch, Generation, TopToken
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)

@@ -101,8 +101,12 @@ class Model(ABC):
             input_ids[prefix_offset:read_offset], skip_special_tokens=False
         )

-        new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
-        new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
+        new_sequences = [
+            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
+        ]
+        new_texts = self.tokenizer.batch_decode(
+            new_sequences, skip_special_tokens=False
+        )

         results = []
         for new_text in new_texts:

@@ -117,6 +121,40 @@ class Model(ABC):
                 results.append(("", prefix_offset, read_offset))
         return results

+    def decode_top_tokens(
+        self,
+        input_ids,
+        top_n_tokens,
+        top_token_ids,
+        top_token_logprobs,
+        prefix_offset,
+        read_offset,
+    ):
+        if top_n_tokens == 0:
+            return []
+
+        top_token_texts = self.decode_tokens(
+            input_ids=input_ids,
+            new_input_ids=top_token_ids,
+            prefix_offset=prefix_offset,
+            read_offset=read_offset,
+        )
+
+        top_tokens = []
+        for token_id, (top_token_text, _, _), token_logprob in zip(
+            top_token_ids, top_token_texts, top_token_logprobs
+        ):
+            tok_itm = token_id
+            top_tokens.append(
+                TopToken(
+                    token_id=token_id,
+                    token_logprob=token_logprob,
+                    token_text=top_token_text,
+                    token_is_special=tok_itm in self.all_special_ids,
+                )
+            )
+
+        return top_tokens
+
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
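
Note on the decoding trick that decode_tokens and the new decode_top_tokens rely on: each candidate's surface text is recovered by decoding the current prefix once, decoding prefix-plus-candidate for every candidate in a single batch_decode call, and keeping whatever the longer string adds beyond the prefix. A rough standalone illustration of the idea (the tokenizer name and candidate ids here are placeholders, not taken from this diff):

    from transformers import AutoTokenizer

    # Any tokenizer works for the illustration; "gpt2" is just a placeholder choice.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    input_ids = tokenizer("Deep learning is").input_ids
    prefix_offset, read_offset = 0, len(input_ids)

    # Text of the already-generated prefix.
    prefix_text = tokenizer.decode(
        input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )

    # Stand-ins for one request's top-n candidate ids at the current step.
    candidate_ids = [257, 407, 845]

    # Decode prefix + candidate for every candidate in one batch_decode call,
    # mirroring decode_tokens above.
    new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in candidate_ids]
    new_texts = tokenizer.batch_decode(new_sequences, skip_special_tokens=False)

    # Each candidate's text is whatever the longer decode adds past the prefix,
    # unless the result ends in an incomplete "�" byte sequence.
    for new_text in new_texts:
        if new_text and not new_text.endswith("�"):
            print(repr(new_text[len(prefix_text):]))

Decoding candidates together with their prefix, rather than in isolation, helps byte-level and sentencepiece tokenizers produce the correct leading spaces and multi-byte characters for each candidate.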

View File

@@ -1,4 +1,4 @@
-from text_generation_server.utils.tokens import get_top_tokens
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch

 from dataclasses import dataclass

@@ -49,6 +49,7 @@ class Seq2SeqLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]

     # Metadata used for padding
     max_input_length: int

@@ -79,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         decoder_input_lengths = []
         prefix_offsets = []
         read_offsets = []

@@ -98,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(

@@ -147,6 +149,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,

@@ -174,6 +177,7 @@ class Seq2SeqLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []

         max_input_length = 0
         max_decoder_input_length = 0

@@ -205,6 +209,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )

@@ -255,6 +260,7 @@ class Seq2SeqLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset

@@ -290,6 +296,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0

         # Batch tensors

@@ -313,6 +320,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping

@@ -489,6 +497,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,

@@ -614,6 +623,10 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+        )
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True

@@ -629,6 +642,9 @@ class Seq2SeqLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_decoder_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch

@@ -642,22 +658,24 @@ class Seq2SeqLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_decoder_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_decoder_input_ids.view(1, -1).tolist(),
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )
+
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )

-            top_tokens = get_top_tokens(
-                request.top_n_tokens,
-                logprobs,
-                self.all_special_ids,
-                self.decode_token,
-                all_decoder_input_ids,
-                prefix_offset,
-                read_offset,
-            )
-
             # Append next token to decoder tokens
             all_decoder_input_ids = torch.cat(
                 [all_decoder_input_ids, next_token_id.squeeze(1)]

View File

@@ -8,7 +8,6 @@ from transformers import (
 )

 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.types import TopToken
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (

@@ -361,49 +360,3 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
         [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
         [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
     )
-
-
-def get_top_tokens(
-    requested_n: int,
-    logprobs,
-    special_tokens: List[int],
-    decode_fn: Callable[[List[int], int, int], str],
-    decoder_input_ids: List[int],
-    prefix_offset: int,
-    read_offset: int,
-) -> List[TopToken]:
-    if not requested_n:
-        return []
-
-    # Dirty hack
-    flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]
-
-    # Ensure top_n doesn't exceed vocab size
-    top_n = min(requested_n, flat_scores.size(-1))
-
-    # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
-    nth_highest = torch.topk(flat_scores, top_n)[0][-1]
-    if nth_highest == -float("inf"):
-        nth_highest = torch.finfo(flat_scores.dtype).min
-
-    # Get indices (token ids) of all scores >= nth highest value,
-    # cap length at 4 * top_n as a precaution
-    top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
-
-    top_tokens = []
-    for tid_tensor in top_n_indices:
-        tid_item = tid_tensor[0].item()
-        token_text, _, _ = decode_fn(
-            torch.cat([decoder_input_ids, tid_tensor])
-            if isinstance(decoder_input_ids, torch.Tensor)
-            else decoder_input_ids + [tid_item],
-            prefix_offset,
-            read_offset,
-        )
-        top_tokens.append(
-            TopToken(
-                token_id=tid_item,
-                token_logprob=flat_scores[tid_tensor],
-                token_text=token_text,
-                token_is_special=tid_item in special_tokens,
-            )
-        )
-
-    top_tokens.sort(reverse=True)
-    return top_tokens
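
Only the tail of batch_top_tokens is visible in this hunk (the helper already exists before this commit), while the per-request get_top_tokens above is deleted. As a rough mental model of why one batched helper can replace the per-request loop, a single top-k over the whole batch followed by per-request slicing ends in exactly the return shape kept above. The sketch below is an illustrative stand-in, not the repository's implementation:

    from typing import List, Tuple

    import torch


    def batch_top_tokens_sketch(
        top_n_tokens: List[int], scores: torch.Tensor
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """Illustrative stand-in: one batched top-k over [batch, vocab] scores,
        then per-request slicing down to each request's requested n."""
        max_top_n = max(top_n_tokens)
        if max_top_n == 0:
            # No request asked for alternative tokens.
            return [[] for _ in top_n_tokens], [[] for _ in top_n_tokens]

        # One top-k launch for the whole batch instead of a Python loop per request.
        top_values, top_indices = torch.topk(scores, k=max_top_n, dim=-1)
        top_indices = top_indices.tolist()
        top_values = top_values.tolist()
        return (
            [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
            [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
        )


    # Example: a batch of two requests over a toy 6-token vocabulary.
    ids, vals = batch_top_tokens_sketch([2, 0], torch.softmax(torch.randn(2, 6), -1))
    print(ids)   # e.g. [[4, 1], []]
    print(vals)  # matching scores; empty where top_n_tokens == 0

Compared with the deleted get_top_tokens, which thresholded and decoded inside the Python loop for every request, one batched top-k does the selection for the whole batch at once, leaving only slicing and text decoding in the per-request loop.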