Mirror of https://github.com/huggingface/text-generation-inference.git
Implement top-n-tokens for all models

commit 8515999b1d
parent 38691f8a28
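Each decode step below calls batch_top_tokens(batch.top_n_tokens, torch.softmax(logits[:, -1], -1)) to compute the per-request alternative tokens in one batched operation; only the tail of that helper is visible in the utils/tokens.py hunk at the end of this diff. As a rough, self-contained sketch of what such a helper plausibly does (the _sketch name, the plain-list argument, and the single-topk body are illustrative assumptions, not the repository code; only the call signature and the sliced return mirror the diff):

import torch
from typing import List, Tuple


def batch_top_tokens_sketch(
    top_n_tokens: List[int], probs: torch.Tensor
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # probs has shape [batch_size, vocab_size], e.g. torch.softmax(logits[:, -1], -1)
    max_top_n = max(top_n_tokens)
    if max_top_n == 0:
        # No request asked for alternative tokens
        empty_ids = [torch.empty(0, dtype=torch.long)] * len(top_n_tokens)
        empty_vals = [torch.empty(0)] * len(top_n_tokens)
        return empty_ids, empty_vals
    # One batched topk for the largest requested n...
    top_values, top_indices = probs.topk(k=max_top_n, dim=-1)
    # ...then slice every row down to that request's own n, matching the return
    # statement visible in the utils/tokens.py hunk at the end of this diff
    return (
        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
    )


# Example: a batch of two requests asking for 2 and 0 alternative tokens
probs = torch.softmax(torch.randn(2, 8), dim=-1)
ids, vals = batch_top_tokens_sketch([2, 0], probs)
print(ids[0].tolist(), vals[0].tolist(), ids[1].tolist())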
server/text_generation_server/models/causal_lm.py
@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import inspect

@@ -42,6 +43,7 @@ class CausalLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]

     # Metadata used for padding
     max_input_length: int
@@ -72,6 +74,7 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         prefix_offsets = []
         read_offsets = []
         requests_idx_mapping = {}
@@ -88,6 +91,7 @@ class CausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -138,6 +142,7 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
@@ -163,6 +168,7 @@ class CausalLMBatch(Batch):

         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []

         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
@@ -184,6 +190,7 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -235,6 +242,7 @@ class CausalLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
@@ -262,6 +270,7 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0

         # Batch tensors
@@ -281,6 +290,7 @@ class CausalLMBatch(Batch):
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -438,6 +448,7 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -549,6 +560,10 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+        )
+
         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -559,6 +574,9 @@ class CausalLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch
@@ -571,7 +589,19 @@ class CausalLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_input_ids.view(1, -1).tolist(),
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )
+
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_input_ids.view(1, -1), logits[-1:, :]
server/text_generation_server/models/flash_causal_lm.py
@@ -1,6 +1,6 @@
 import math
 import itertools
-from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import torch.distributed

@@ -972,25 +972,14 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = []
-
-            if top_n_tokens > 0:
-                top_token_texts = self.decode_tokens(
-                    input_ids=all_input_ids,
-                    new_input_ids=top_token_ids,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
-                )
-                for token_id, (top_token_text, _, _), token_logprob in zip(top_token_ids, top_token_texts, top_token_logprobs):
-                    tok_itm = token_id
-                    top_tokens.append(
-                        TopToken(
-                            token_id=token_id,
-                            token_logprob=token_logprob,
-                            token_text=top_token_text,
-                            token_is_special=tok_itm in self.all_special_ids,
-                        )
-                    )
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_input_ids,
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )

             # Append next token to all tokens
             all_input_ids.append(next_token_id)
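The inlined loop removed above is the same logic that moves into Model.decode_top_tokens in the next file. As a self-contained illustration of what that loop builds per request, with TopTokenStub standing in for text_generation_server.models.types.TopToken (which is not shown in this diff):

from dataclasses import dataclass
from typing import List, Sequence, Set


@dataclass
class TopTokenStub:
    # Stand-in for text_generation_server.models.types.TopToken (not part of this diff)
    token_id: int
    token_logprob: float
    token_text: str
    token_is_special: bool


def build_top_tokens(
    top_token_ids: Sequence[int],
    top_token_texts: Sequence[str],
    top_token_logprobs: Sequence[float],
    all_special_ids: Set[int],
) -> List[TopTokenStub]:
    # One entry per candidate token, mirroring the removed loop above
    return [
        TopTokenStub(
            token_id=token_id,
            token_logprob=token_logprob,
            token_text=text,
            token_is_special=token_id in all_special_ids,
        )
        for token_id, text, token_logprob in zip(
            top_token_ids, top_token_texts, top_token_logprobs
        )
    ]


# Example: three candidate continuations for one sequence, id 2 being a special token
print(build_top_tokens([464, 7, 2], [" The", " a", "</s>"], [-0.21, -1.4, -5.3], {2}))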
server/text_generation_server/models/model.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

-from text_generation_server.models.types import Batch, Generation
+from text_generation_server.models.types import Batch, Generation, TopToken
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -101,8 +101,12 @@ class Model(ABC):
             input_ids[prefix_offset:read_offset], skip_special_tokens=False
         )

-        new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
-        new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
+        new_sequences = [
+            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
+        ]
+        new_texts = self.tokenizer.batch_decode(
+            new_sequences, skip_special_tokens=False
+        )

         results = []
         for new_text in new_texts:
@@ -117,6 +121,40 @@ class Model(ABC):
                 results.append(("", prefix_offset, read_offset))
         return results

+    def decode_top_tokens(
+        self,
+        input_ids,
+        top_n_tokens,
+        top_token_ids,
+        top_token_logprobs,
+        prefix_offset,
+        read_offset,
+    ):
+        if top_n_tokens == 0:
+            return []
+
+        top_token_texts = self.decode_tokens(
+            input_ids=input_ids,
+            new_input_ids=top_token_ids,
+            prefix_offset=prefix_offset,
+            read_offset=read_offset,
+        )
+
+        top_tokens = []
+        for token_id, (top_token_text, _, _), token_logprob in zip(
+            top_token_ids, top_token_texts, top_token_logprobs
+        ):
+            tok_itm = token_id
+            top_tokens.append(
+                TopToken(
+                    token_id=token_id,
+                    token_logprob=token_logprob,
+                    token_text=top_token_text,
+                    token_is_special=tok_itm in self.all_special_ids,
+                )
+            )
+        return top_tokens
+
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
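decode_top_tokens reuses the existing decode_tokens helper, whose core is the reformatted new_sequences / batch_decode pair above: each candidate id is appended to the shared prefix and all candidates are decoded in a single tokenizer.batch_decode call. A rough standalone sketch of that pattern; the gpt2 tokenizer and the example strings are placeholders, not part of this commit:

from transformers import AutoTokenizer

# Rough sketch of the batched candidate decoding inside Model.decode_tokens;
# "gpt2" and the example strings are placeholders, not taken from this commit.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown")["input_ids"]
prefix_offset = 0  # in the server this comes from the running decode offsets
candidate_ids = [
    tokenizer(" fox", add_special_tokens=False)["input_ids"][0],
    tokenizer(" dog", add_special_tokens=False)["input_ids"][0],
]

# Shared prefix + one candidate id per sequence, decoded in a single call
new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in candidate_ids]
new_texts = tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
print(new_texts)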
server/text_generation_server/models/seq2seq_lm.py
@@ -1,4 +1,4 @@
-from text_generation_server.utils.tokens import get_top_tokens
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch

 from dataclasses import dataclass
@@ -49,6 +49,7 @@ class Seq2SeqLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]

     # Metadata used for padding
     max_input_length: int
@@ -79,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-
+        top_n_tokens = []
         decoder_input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -98,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -147,6 +149,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -174,6 +177,7 @@ class Seq2SeqLMBatch(Batch):

         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []

         max_input_length = 0
         max_decoder_input_length = 0
@@ -205,6 +209,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -255,6 +260,7 @@ class Seq2SeqLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -290,6 +296,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0

         # Batch tensors
@@ -313,6 +320,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)

             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -489,6 +497,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -614,6 +623,10 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )

+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+        )
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True
@@ -629,6 +642,9 @@ class Seq2SeqLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_decoder_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )

         # For each member of the batch
@@ -642,22 +658,24 @@ class Seq2SeqLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_decoder_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_decoder_input_ids.view(1, -1).tolist(),
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )
+
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )

-            top_tokens = get_top_tokens(
-                request.top_n_tokens,
-                logprobs,
-                self.all_special_ids,
-                self.decode_token,
-                all_decoder_input_ids,
-                prefix_offset,
-                read_offset,
-            )
-
             # Append next token to decoder tokens
             all_decoder_input_ids = torch.cat(
                 [all_decoder_input_ids, next_token_id.squeeze(1)]
server/text_generation_server/utils/tokens.py
@@ -8,7 +8,6 @@ from transformers import (
 )

 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.types import TopToken
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
@@ -361,49 +360,3 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
         [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
         [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
     )
-
-
-def get_top_tokens(
-    requested_n: int,
-    logprobs,
-    special_tokens: List[int],
-    decode_fn: Callable[[List[int], int, int], str],
-    decoder_input_ids: List[int],
-    prefix_offset: int,
-    read_offset: int,
-) -> List[TopToken]:
-    if not requested_n:
-        return []
-
-    # Dirty hack
-    flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]
-    # Ensure top_n doesn't exceed vocab size
-    top_n = min(requested_n, flat_scores.size(-1))
-    # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
-    nth_highest = torch.topk(flat_scores, top_n)[0][-1]
-    if nth_highest == -float("inf"):
-        nth_highest = torch.finfo(flat_scores.dtype).min
-    # Get indices (token ids) of all scores >= nth highest value,
-    # cap length at 4 * top_n as a precaution
-    top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
-    top_tokens = []
-    for tid_tensor in top_n_indices:
-        tid_item = tid_tensor[0].item()
-        token_text, _, _ = decode_fn(
-            torch.cat([decoder_input_ids, tid_tensor])
-            if isinstance(decoder_input_ids, torch.Tensor)
-            else decoder_input_ids + [tid_item],
-            prefix_offset,
-            read_offset,
-        )
-        top_tokens.append(
-            TopToken(
-                token_id=tid_item,
-                token_logprob=flat_scores[tid_tensor],
-                token_text=token_text,
-                token_is_special=tid_item in special_tokens,
-            )
-        )
-
-    top_tokens.sort(reverse=True)
-    return top_tokens
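For comparison, the removed get_top_tokens selected candidates per request by thresholding at the n-th highest score (keeping ties, capped at 4 * n) and decoded them one by one, while the new path does a single batched top-k (batch_top_tokens) plus a single batched decode (Model.decode_top_tokens) per step. The old selection trick can be reproduced standalone; a minimal sketch with made-up scores:

import torch

# Reproduces the candidate selection from the removed get_top_tokens:
# keep every token whose score ties with the n-th highest, capped at 4 * n.
scores = torch.log_softmax(torch.tensor([2.0, 1.0, 1.0, 0.5, -3.0]), dim=-1)
requested_n = 3
top_n = min(requested_n, scores.size(-1))
nth_highest = torch.topk(scores, top_n)[0][-1]
if nth_highest == -float("inf"):
    # Guard for when top_n exceeds the number of finite scores (e.g. after top_k filtering)
    nth_highest = torch.finfo(scores.dtype).min
candidate_ids = (scores >= nth_highest).nonzero()[: top_n * 4]
print(candidate_ids.squeeze(-1).tolist())  # -> [0, 1, 2]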