Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Implement top-n-tokens for all models

This commit is contained in:
parent 38691f8a28
commit 8515999b1d
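
The diff repeatedly calls a batch_top_tokens helper from text_generation_server.utils.tokens, but only the tail of that function appears below (in the last hunk). For orientation, here is a minimal, hypothetical sketch of what such a batched helper could look like, inferred from the call site batch_top_tokens(batch.top_n_tokens, torch.softmax(logits[:, -1], -1)) and from the (ids, logprobs) pair it returns; the name batch_top_tokens_sketch and the exact return types are assumptions, not the repository's implementation.

from typing import List

import torch


def batch_top_tokens_sketch(top_n_tokens: List[int], probs: torch.Tensor):
    """Hypothetical stand-in for utils.tokens.batch_top_tokens.

    For every request in the batch, return the token ids and log-probabilities
    of its n most likely next tokens; requests with n == 0 get empty lists.
    """
    max_top_n = max(top_n_tokens)
    if max_top_n == 0:
        return [[] for _ in top_n_tokens], [[] for _ in top_n_tokens]

    # One batched top-k over the softmax output, then slice per request.
    top_values, top_indices = probs.topk(k=max_top_n, dim=-1)
    top_logprobs = top_values.log()

    return (
        [idxs[:n].tolist() for idxs, n in zip(top_indices, top_n_tokens)],
        [vals[:n].tolist() for vals, n in zip(top_logprobs, top_n_tokens)],
    )

At the call sites below, the tensor argument is torch.softmax(logits[:, -1], -1), i.e. the next-token distribution at the last position of every sequence in the batch; doing one batched top-k there, rather than a per-request scan, is apparently what allows the old per-request get_top_tokens helper (removed in the final hunk) to be dropped.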
@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import inspect
 
@@ -42,6 +43,7 @@ class CausalLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
 
     # Metadata used for padding
     max_input_length: int
@@ -72,6 +74,7 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         prefix_offsets = []
         read_offsets = []
         requests_idx_mapping = {}
@@ -88,6 +91,7 @@ class CausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -138,6 +142,7 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
@@ -163,6 +168,7 @@ class CausalLMBatch(Batch):
 
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
 
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
@@ -184,6 +190,7 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -235,6 +242,7 @@ class CausalLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
@@ -262,6 +270,7 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
        stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0
 
         # Batch tensors
@@ -281,6 +290,7 @@ class CausalLMBatch(Batch):
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)
 
             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -438,6 +448,7 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -549,6 +560,10 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+        )
+
         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -559,6 +574,9 @@ class CausalLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -571,7 +589,19 @@ class CausalLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_input_ids.view(1, -1).tolist(),
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )
+
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_input_ids.view(1, -1), logits[-1:, :]
@@ -1,6 +1,6 @@
 import math
 import itertools
-from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import torch.distributed
 
@@ -972,25 +972,14 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
-            top_tokens = []
-
-            if top_n_tokens > 0:
-                top_token_texts = self.decode_tokens(
-                    input_ids=all_input_ids,
-                    new_input_ids=top_token_ids,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
-                )
-                for token_id, (top_token_text, _, _), token_logprob in zip(top_token_ids, top_token_texts, top_token_logprobs):
-                    tok_itm = token_id
-                    top_tokens.append(
-                        TopToken(
-                            token_id=token_id,
-                            token_logprob=token_logprob,
-                            token_text=top_token_text,
-                            token_is_special=tok_itm in self.all_special_ids,
-                        )
-                    )
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_input_ids,
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )
 
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig
 
-from text_generation_server.models.types import Batch, Generation
+from text_generation_server.models.types import Batch, Generation, TopToken
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
@@ -101,8 +101,12 @@ class Model(ABC):
             input_ids[prefix_offset:read_offset], skip_special_tokens=False
         )
 
-        new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids]
-        new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False)
+        new_sequences = [
+            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
+        ]
+        new_texts = self.tokenizer.batch_decode(
+            new_sequences, skip_special_tokens=False
+        )
 
         results = []
         for new_text in new_texts:
@@ -117,6 +121,40 @@ class Model(ABC):
                 results.append(("", prefix_offset, read_offset))
         return results
 
+    def decode_top_tokens(
+        self,
+        input_ids,
+        top_n_tokens,
+        top_token_ids,
+        top_token_logprobs,
+        prefix_offset,
+        read_offset,
+    ):
+        if top_n_tokens == 0:
+            return []
+
+        top_token_texts = self.decode_tokens(
+            input_ids=input_ids,
+            new_input_ids=top_token_ids,
+            prefix_offset=prefix_offset,
+            read_offset=read_offset,
+        )
+
+        top_tokens = []
+        for token_id, (top_token_text, _, _), token_logprob in zip(
+            top_token_ids, top_token_texts, top_token_logprobs
+        ):
+            tok_itm = token_id
+            top_tokens.append(
+                TopToken(
+                    token_id=token_id,
+                    token_logprob=token_logprob,
+                    token_text=top_token_text,
+                    token_is_special=tok_itm in self.all_special_ids,
+                )
+            )
+        return top_tokens
+
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
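The new decode_top_tokens method constructs TopToken objects imported from text_generation_server.models.types, but that class is not part of this diff. For context only, a minimal sketch of the fields the method assumes; the real definition may carry more (for example a conversion to the protobuf response type), so treat the code below as read off the call site rather than as the actual class.

from dataclasses import dataclass


@dataclass
class TopToken:
    # Hypothetical sketch: only the four fields passed in decode_top_tokens
    # are listed here.
    token_id: int
    token_logprob: float
    token_text: str
    token_is_special: bool

    # The removed get_top_tokens helper called top_tokens.sort(reverse=True),
    # which implies the real class is orderable, most plausibly by logprob.
    def __lt__(self, other: "TopToken") -> bool:
        return self.token_logprob < other.token_logprob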
@@ -1,4 +1,4 @@
-from text_generation_server.utils.tokens import get_top_tokens
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 
 from dataclasses import dataclass
@@ -49,6 +49,7 @@ class Seq2SeqLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
 
     # Metadata used for padding
     max_input_length: int
@@ -79,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         decoder_input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -98,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -147,6 +149,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -174,6 +177,7 @@ class Seq2SeqLMBatch(Batch):
 
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
 
         max_input_length = 0
         max_decoder_input_length = 0
@@ -205,6 +209,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -255,6 +260,7 @@ class Seq2SeqLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -290,6 +296,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0
 
         # Batch tensors
@@ -313,6 +320,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)
 
             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -489,6 +497,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -614,6 +623,10 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+        )
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True
@@ -629,6 +642,9 @@ class Seq2SeqLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_decoder_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -642,22 +658,24 @@ class Seq2SeqLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_decoder_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = self.decode_top_tokens(
+                input_ids=all_decoder_input_ids.view(1, -1).tolist(),
+                top_n_tokens=top_n_tokens,
+                top_token_ids=top_token_ids,
+                top_token_logprobs=top_token_logprobs,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+            )
+
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )
 
-            top_tokens = get_top_tokens(
-                request.top_n_tokens,
-                logprobs,
-                self.all_special_ids,
-                self.decode_token,
-                all_decoder_input_ids,
-                prefix_offset,
-                read_offset,
-            )
-
             # Append next token to decoder tokens
             all_decoder_input_ids = torch.cat(
                 [all_decoder_input_ids, next_token_id.squeeze(1)]
@@ -8,7 +8,6 @@ from transformers import (
 )
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.types import TopToken
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
@@ -361,49 +360,3 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor):
         [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
         [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
     )
-
-
-def get_top_tokens(
-    requested_n: int,
-    logprobs,
-    special_tokens: List[int],
-    decode_fn: Callable[[List[int], int, int], str],
-    decoder_input_ids: List[int],
-    prefix_offset: int,
-    read_offset: int,
-) -> List[TopToken]:
-    if not requested_n:
-        return []
-
-    # Dirty hack
-    flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]
-    # Ensure top_n doesn't exceed vocab size
-    top_n = min(requested_n, flat_scores.size(-1))
-    # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
-    nth_highest = torch.topk(flat_scores, top_n)[0][-1]
-    if nth_highest == -float("inf"):
-        nth_highest = torch.finfo(flat_scores.dtype).min
-    # Get indices (token ids) of all scores >= nth highest value,
-    # cap length at 4 * top_n as a precaution
-    top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
-    top_tokens = []
-    for tid_tensor in top_n_indices:
-        tid_item = tid_tensor[0].item()
-        token_text, _, _ = decode_fn(
-            torch.cat([decoder_input_ids, tid_tensor])
-            if isinstance(decoder_input_ids, torch.Tensor)
-            else decoder_input_ids + [tid_item],
-            prefix_offset,
-            read_offset,
-        )
-        top_tokens.append(
-            TopToken(
-                token_id=tid_item,
-                token_logprob=flat_scores[tid_tensor],
-                token_text=token_text,
-                token_is_special=tid_item in special_tokens,
-            )
-        )
-
-    top_tokens.sort(reverse=True)
-    return top_tokens