Defer building top-token objects to Rust

Vincent Brouwers 2023-08-01 15:02:30 +00:00
parent 730d86f1d0
commit 8471e1862d
7 changed files with 115 additions and 147 deletions

proto/generate.proto

@@ -143,6 +143,17 @@ message PrefillTokens {
     repeated string texts = 3;
 }
 
+message TopTokens {
+    /// Top Token IDs
+    repeated uint32 ids = 1;
+    /// Top Logprobs
+    repeated float logprobs = 2;
+    /// Top Token Texts
+    repeated string texts = 3;
+    /// If the tokens are special
+    repeated bool is_special = 6;
+}
+
 message TopToken {
     /// Token ID
     uint32 token_id = 3;
@@ -170,7 +181,8 @@ message Generation {
     /// Complete generated text
     optional GeneratedText generated_text = 7;
     /// Top tokens
-    repeated TopToken top_tokens = 8;
+    // repeated TopToken top_tokens = 8;
+    TopTokens top_tokens = 8;
 }
 
 message FilterBatchRequest {
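The new TopTokens message is columnar: one entry per candidate token across four parallel repeated fields, instead of one TopToken message per candidate. As a rough illustration (not part of the commit), this is how the Python server side could populate such a message through the generated generate_pb2 bindings; the candidate ids, logprobs and texts are made-up values.

# Illustrative sketch only: build the columnar TopTokens message for one
# generated position. The example values are invented.
from text_generation_server.pb import generate_pb2

top_tokens = generate_pb2.TopTokens(
    ids=[42, 7],
    logprobs=[-0.12, -1.85],
    texts=["Hello", " Hi"],
    is_special=[False, False],
)
# All four repeated fields stay parallel: index i describes candidate i.
assert len(top_tokens.ids) == len(top_tokens.logprobs) == len(top_tokens.texts)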

router/src/infer.rs

@@ -167,10 +167,7 @@ impl Infer {
                         .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate{
-                    token,
-                    top_tokens,
-                } => {
+                InferStreamResponse::Intermediate { token, top_tokens } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);
                 }
@@ -182,7 +179,6 @@ impl Infer {
                     start,
                     queued,
                     top_tokens,
                 } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);
@@ -203,7 +199,11 @@ impl Infer {
                generated_text,
                queued,
                start,
-                top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
+                top_tokens: if use_top_tokens {
+                    result_top_tokens
+                } else {
+                    Vec::new()
+                },
            })
        } else {
            let err = InferError::IncompleteGeneration;
@@ -533,16 +533,24 @@ fn send_responses(
         special: generation.token_is_special,
     };
 
     // generation.top_tokens
 
     let mut top_tokens = Vec::new();
-    for top_token in generation.top_tokens {
-        top_tokens.push(Token{
-            id: top_token.token_id,
-            text: top_token.token_text,
-            logprob: top_token.token_logprob,
-            special: top_token.token_is_special,
-        })
+    if let Some(top_tokens_) = generation.top_tokens {
+        top_tokens.extend(
+            top_tokens_
+                .ids
+                .into_iter()
+                .zip(top_tokens_.logprobs.into_iter())
+                .zip(top_tokens_.texts.into_iter())
+                .zip(top_tokens_.is_special.into_iter())
+                .map(|(((id, logprob), text), special)| Token {
+                    id,
+                    text,
+                    logprob,
+                    special,
+                }),
+        )
     }
 
     if let Some(generated_text) = generation.generated_text {
@@ -562,7 +570,7 @@ fn send_responses(
     } else {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Intermediate{token, top_tokens}),
+            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
             Duration::from_millis(10),
         )?;
     }
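The router-side hunk above rebuilds per-token Token values by zipping the four parallel vectors of the incoming TopTokens message. For readers less used to the nested zip pattern, here is the same reshaping sketched in Python (not part of the commit); top_tokens_ stands for a decoded TopTokens message, and the dict shape only mirrors the router's Token struct for illustration.

# Illustrative sketch only: reshape the columnar message back into
# per-candidate records, exactly what the Rust iterator chain above does.
def unpack_top_tokens(top_tokens_):
    return [
        {"id": id_, "logprob": logprob, "text": text, "special": special}
        for id_, logprob, text, special in zip(
            top_tokens_.ids,
            top_tokens_.logprobs,
            top_tokens_.texts,
            top_tokens_.is_special,
        )
    ]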

server/text_generation_server/models/causal_lm.py

@@ -13,6 +13,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -126,7 +127,9 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -574,7 +577,9 @@ class CausalLM(Model):
             stopped = True
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )
 
         # Zipped iterator
@@ -652,8 +657,7 @@ class CausalLM(Model):
                 generated_text = None
 
             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
@@ -672,15 +676,20 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None
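Instead of the removed decode_top_tokens helper, the model now just gathers raw pieces: one batch_decode call turns the candidate ids into texts, a comprehension flags special tokens, and the parallel lists are handed to the TopTokens dataclass from types.py. A standalone sketch of that pattern (not part of the commit), assuming a Hugging Face tokenizer — gpt2 here purely as an example — and made-up candidate ids and logprobs:

# Standalone sketch only: decode candidate ids in one batched call and flag
# special tokens, keeping everything as parallel lists.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer, an assumption

top_token_ids = [1212, 50256, 318]       # made-up candidate ids for one position
top_token_logprobs = [-0.3, -2.1, -4.7]  # made-up logprobs, parallel to the ids

toptoken_texts = tokenizer.batch_decode(
    top_token_ids,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)
special_toptokens = [
    token_id in tokenizer.all_special_ids for token_id in top_token_ids
]
# These parallel lists are what the TopTokens dataclass (see types.py below) carries.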

server/text_generation_server/models/flash_causal_lm.py

@@ -17,7 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
-    TopToken,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -358,7 +358,9 @@ class FlashCausalLMBatch(Batch):
         prefill_next_token_indices = torch.tensor(
             prefill_next_token_indices, dtype=torch.int64, device=device
         )
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         return cls(
             batch_id=pb.id,
@@ -1039,15 +1041,20 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1],
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None

server/text_generation_server/models/model.py

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig
 
-from text_generation_server.models.types import Batch, GeneratedText, TopToken
+from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
@@ -86,76 +86,6 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset
 
-    def decode_tokens(
-        self,
-        input_ids: List[int],
-        new_input_ids: List[int],
-        prefix_offset: int = 0,
-        read_offset: int = 0,
-    ) -> Tuple[str, int, int]:
-        """Version of decode_token that supports multiple new tokens for the same prefix."""
-        # The prefix text is necessary only to defeat cleanup algorithms in the decode
-        # which decide to add a space or not depending on the surrounding ids.
-        prefix_text = self.tokenizer.decode(
-            input_ids[prefix_offset:read_offset], skip_special_tokens=False
-        )
-
-        new_sequences = [
-            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
-        ]
-        new_texts = self.tokenizer.batch_decode(
-            new_sequences, skip_special_tokens=False
-        )
-
-        prefix_len = len(prefix_text)
-        results = []
-        for new_text in new_texts:
-            if len(new_text) > prefix_len and not new_text.endswith("�"):
-                # utf-8 char at the end means it's a potential unfinished byte sequence
-                # from byte fallback tokenization.
-                # If it's in the middle, it's probably a real invalid id generated
-                # by the model
-                new_text = new_text[prefix_len:]
-                results.append((new_text, read_offset, len(input_ids) + 1))
-            else:
-                results.append(("", prefix_offset, read_offset))
-
-        return results
-
-    def decode_top_tokens(
-        self,
-        input_ids,
-        top_n_tokens,
-        top_token_ids,
-        top_token_logprobs,
-        prefix_offset,
-        read_offset,
-    ):
-        if top_n_tokens == 0:
-            return []
-
-        top_token_texts = self.decode_tokens(
-            input_ids=input_ids,
-            new_input_ids=top_token_ids,
-            prefix_offset=prefix_offset,
-            read_offset=read_offset,
-        )
-
-        top_tokens = []
-        for token_id, (top_token_text, _, _), token_logprob in zip(
-            top_token_ids, top_token_texts, top_token_logprobs
-        ):
-            tok_itm = token_id
-            top_tokens.append(
-                TopToken(
-                    token_id=token_id,
-                    token_logprob=token_logprob,
-                    token_text=top_token_text,
-                    token_is_special=tok_itm in self.all_special_ids,
-                )
-            )
-
-        return top_tokens
-
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():

server/text_generation_server/models/seq2seq_lm.py

@@ -12,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -130,7 +131,9 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -637,7 +640,9 @@ class Seq2SeqLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )
 
         # Finished requests
@@ -722,8 +727,7 @@ class Seq2SeqLM(Model):
                 generated_text = None
 
             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
@@ -732,15 +736,20 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill. How to implement in API?
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None

server/text_generation_server/models/types.py

@@ -72,28 +72,23 @@ class PrefillTokens:
         return len(self.token_ids)
 
 
-@dataclass(eq=True)
-@total_ordering
-class TopToken:
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
-
-    def __gt__(self, other):
-        # We tiebreak equal logprobs with the _lower_ token_id to align with
-        # greedy ordering (torch.argmax)
-        return self.token_logprob > other.token_logprob or (
-            self.token_logprob == other.token_logprob and self.token_id < other.token_id
-        )
-
-    def to_pb(self) -> generate_pb2.TopToken:
-        return generate_pb2.TopToken(
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
-        )
+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
+
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
+        )
+
+    def __len__(self):
+        return len(self.token_ids)
 
 
 @dataclass
@@ -106,7 +101,7 @@ class Generation:
     token_is_special: bool
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[List[TopToken]]
+    top_tokens: Optional[TopTokens]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -121,7 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
            else None,
-            top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
-            if self.top_tokens
-            else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )
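A short usage sketch (not part of the commit) of the new dataclass: the server builds one TopTokens per generated position, and to_pb() maps it one-to-one onto the proto message added above; the values below are made up.

# Usage sketch only: the columnar dataclass and its protobuf conversion.
from text_generation_server.models.types import TopTokens

top_tokens = TopTokens(
    token_ids=[42, 7],
    logprobs=[-0.12, -1.85],
    texts=["Hello", " Hi"],
    is_special=[False, False],
)
assert len(top_tokens) == 2   # __len__ counts candidates
pb = top_tokens.to_pb()       # generate_pb2.TopTokens with the same parallel fields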