Defer building top-token objects to Rust

Vincent Brouwers 2023-08-01 15:02:30 +00:00 committed by Nicolas Patry
parent 6429695228
commit af0adb8c71
7 changed files with 116 additions and 147 deletions
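In short: instead of emitting one TopToken protobuf message per candidate token, each decoding step now carries a single columnar TopTokens message with parallel arrays (ids, logprobs, texts, is_special), and the Rust router zips those arrays back into per-token objects when building the response. A minimal Python sketch of that round trip follows; the class and function names are made up for illustration, and only the field layout mirrors the new proto message.

from dataclasses import dataclass
from typing import List

@dataclass
class TopTokensColumns:
    """Stand-in for the columnar TopTokens message: four parallel arrays."""
    ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

@dataclass
class TokenRow:
    """Stand-in for the router-side Token object rebuilt from the columns."""
    id: int
    text: str
    logprob: float
    special: bool

def to_rows(cols: TopTokensColumns) -> List[TokenRow]:
    # Mirrors the router's zip-based reconstruction. The columns are built
    # together on the Python side, so they are assumed to have equal length.
    return [
        TokenRow(id=i, text=t, logprob=lp, special=sp)
        for i, lp, t, sp in zip(cols.ids, cols.logprobs, cols.texts, cols.is_special)
    ]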

View File

@ -143,6 +143,17 @@ message PrefillTokens {
repeated string texts = 3;
}
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
repeated bool is_special = 6;
}
message TopToken {
/// Token ID
uint32 token_id = 3;
@ -170,7 +181,8 @@ message Generation {
/// Complete generated text
optional GeneratedText generated_text = 7;
/// Top tokens
repeated TopToken top_tokens = 8;
// repeated TopToken top_tokens = 8;
TopTokens top_tokens = 8;
}
message FilterBatchRequest {

View File

@ -167,10 +167,7 @@ impl Infer {
.collect();
}
// Push last token
InferStreamResponse::Intermediate{
token,
top_tokens,
} => {
InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
@ -182,7 +179,6 @@ impl Infer {
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
@ -203,7 +199,11 @@ impl Infer {
generated_text,
queued,
start,
top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
top_tokens: if use_top_tokens {
result_top_tokens
} else {
Vec::new()
},
})
} else {
let err = InferError::IncompleteGeneration;
@ -533,16 +533,24 @@ fn send_responses(
special: generation.token_is_special,
};
// generation.top_tokens
let mut top_tokens = Vec::new();
for top_token in generation.top_tokens {
top_tokens.push(Token{
id: top_token.token_id,
text: top_token.token_text,
logprob: top_token.token_logprob,
special: top_token.token_is_special,
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
})
)
}
if let Some(generated_text) = generation.generated_text {

View File

@ -13,6 +13,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -126,7 +127,9 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -574,7 +577,9 @@ class CausalLM(Model):
stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Zipped iterator
@ -652,8 +657,7 @@ class CausalLM(Model):
generated_text = None
# Prefill
prefill = stopping_criteria.current_tokens == 1
if prefill and request.prefill_logprobs:
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
@ -672,15 +676,20 @@ class CausalLM(Model):
else:
prefill_tokens = None
# Todo: Make optional for prefill
if not prefill and top_n_tokens > 0:
top_tokens = self.decode_top_tokens(
input_ids=all_input_ids[:-1].view(-1).tolist(),
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
prefix_offset=prefix_offset,
read_offset=read_offset,
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
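For context, the batch_top_tokens call earlier in this file is expected to return, per request, the ids and logprobs of the highest-probability tokens given the softmaxed logits. The helper below is only a rough sketch of that contract, not the actual server implementation (which also takes a tensor copy of the per-request counts).

from typing import List, Tuple
import torch

def batch_top_tokens_sketch(
    top_n_tokens: List[int],   # requested number of top tokens per request (0 disables)
    probs: torch.Tensor,       # [batch, vocab] softmax probabilities
) -> Tuple[List[List[int]], List[List[float]]]:
    max_n = max(top_n_tokens)
    if max_n == 0:
        return [[] for _ in top_n_tokens], [[] for _ in top_n_tokens]
    # One batched top-k over the whole batch, then slice each row to its own n.
    top_probs, top_ids = probs.topk(k=max_n, dim=-1)
    top_logprobs = top_probs.log()
    ids = [top_ids[i, :n].tolist() for i, n in enumerate(top_n_tokens)]
    logprobs = [top_logprobs[i, :n].tolist() for i, n in enumerate(top_n_tokens)]
    return ids, logprobs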

View File

@ -17,7 +17,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopToken,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@ -358,7 +358,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls(
batch_id=pb.id,
@ -1039,15 +1041,20 @@ class FlashCausalLM(Model):
else:
prefill_tokens = None
# Todo: Make optional for prefill
if not prefill and top_n_tokens > 0:
top_tokens = self.decode_top_tokens(
input_ids=all_input_ids[:-1],
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
prefix_offset=prefix_offset,
read_offset=read_offset,
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None

View File

@ -5,7 +5,8 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, Generation, TopToken
from text_generation_server.models.types import Batch, Generation
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
@ -86,76 +87,6 @@ class Model(ABC):
else:
return "", prefix_offset, read_offset
def decode_tokens(
self,
input_ids: List[int],
new_input_ids: List[int],
prefix_offset: int = 0,
read_offset: int = 0,
) -> Tuple[str, int, int]:
"""Version of decode_token that supports multiple new tokens for the same prefix."""
# The prefix text is necessary only to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
prefix_text = self.tokenizer.decode(
input_ids[prefix_offset:read_offset], skip_special_tokens=False
)
new_sequences = [
input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
]
new_texts = self.tokenizer.batch_decode(
new_sequences, skip_special_tokens=False
)
prefix_len = len(prefix_text)
results = []
for new_text in new_texts:
if len(new_text) > prefix_len and not new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
new_text = new_text[prefix_len:]
results.append((new_text, read_offset, len(input_ids) + 1))
else:
results.append(("", prefix_offset, read_offset))
return results
def decode_top_tokens(
self,
input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
prefix_offset,
read_offset,
):
if top_n_tokens == 0:
return []
top_token_texts = self.decode_tokens(
input_ids=input_ids,
new_input_ids=top_token_ids,
prefix_offset=prefix_offset,
read_offset=read_offset,
)
top_tokens = []
for token_id, (top_token_text, _, _), token_logprob in zip(
top_token_ids, top_token_texts, top_token_logprobs
):
tok_itm = token_id
top_tokens.append(
TopToken(
token_id=token_id,
token_logprob=token_logprob,
token_text=top_token_text,
token_is_special=tok_itm in self.all_special_ids,
)
)
return top_tokens
def check_initialized(self):
uninitialized_parameters = []
for n, p in self.model.named_parameters():
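The decode_tokens helper removed above skipped candidates whose decoded text ends with the Unicode replacement character, since that usually signals an unfinished multi-byte sequence from byte-fallback tokenization rather than real text. A tiny, tokenizer-free illustration of why that check works:

# "€" is three UTF-8 bytes: 0xE2 0x82 0xAC. Decoding only the first two bytes
# yields the replacement character U+FFFD, so the text ends with "�" and an
# incremental decoder should wait for more bytes before emitting anything.
partial = b"\xe2\x82".decode("utf-8", errors="replace")
complete = b"\xe2\x82\xac".decode("utf-8", errors="replace")
print(partial)   # �
print(complete)  # €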

View File

@ -12,6 +12,7 @@ from text_generation_server.models.types import (
Batch,
Generation,
PrefillTokens,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -130,7 +131,9 @@ class Seq2SeqLMBatch(Batch):
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -637,7 +640,9 @@ class Seq2SeqLM(Model):
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Finished requests
@ -722,8 +727,7 @@ class Seq2SeqLM(Model):
generated_text = None
# Prefill
prefill = stopping_criteria.current_tokens == 1
if prefill and request.prefill_logprobs:
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
prefill_tokens = PrefillTokens(
[self.tokenizer.bos_token_id],
[float("nan")],
@ -732,15 +736,20 @@ class Seq2SeqLM(Model):
else:
prefill_tokens = None
# Todo: Make optional for prefill. How to implement in API?
if not prefill and top_n_tokens > 0:
top_tokens = self.decode_top_tokens(
input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
prefix_offset=prefix_offset,
read_offset=read_offset,
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None

View File

@ -72,28 +72,23 @@ class PrefillTokens:
return len(self.token_ids)
@dataclass(eq=True)
@total_ordering
class TopToken:
token_id: int
token_logprob: float
token_text: str
token_is_special: bool
@dataclass
class TopTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
is_special: List[bool]
def __gt__(self, other):
# We tiebreak equal logprobs with the _lower_ token_id to align with
# greedy ordering (torch.argmax)
return self.token_logprob > other.token_logprob or (
self.token_logprob == other.token_logprob and self.token_id < other.token_id
def to_pb(self) -> generate_pb2.TopTokens:
return generate_pb2.TopTokens(
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
)
def to_pb(self) -> generate_pb2.TopToken:
return generate_pb2.TopToken(
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
token_is_special=self.token_is_special,
)
def __len__(self):
return len(self.token_ids)
@dataclass
@ -106,7 +101,7 @@ class Generation:
token_is_special: bool
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[List[TopToken]]
top_tokens: Optional[TopTokens]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@ -121,7 +116,5 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
if self.top_tokens
else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
)
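Putting the new dataclass to use looks roughly like this (the values are made up; it assumes the text_generation_server package is importable):

from text_generation_server.models.types import TopTokens

top_tokens = TopTokens(
    [42, 7],                 # token_ids
    [-0.1, -2.3],            # logprobs
    [" hello", " world"],    # texts
    [False, False],          # is_special
)
assert len(top_tokens) == 2
pb = top_tokens.to_pb()      # a single generate_pb2.TopTokens message per step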