Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
Defer building top-token objects to Rust

parent 730d86f1d0
commit 8471e1862d
@@ -143,6 +143,17 @@ message PrefillTokens {
     repeated string texts = 3;
 }
 
+message TopTokens {
+    /// Top Token IDs
+    repeated uint32 ids = 1;
+    /// Top Logprobs
+    repeated float logprobs = 2;
+    /// Top Token Texts
+    repeated string texts = 3;
+    /// If the tokens are special
+    repeated bool is_special = 6;
+}
+
 message TopToken {
     /// Token ID
     uint32 token_id = 3;

@@ -170,7 +181,8 @@ message Generation {
     /// Complete generated text
     optional GeneratedText generated_text = 7;
     /// Top tokens
-    repeated TopToken top_tokens = 8;
+    // repeated TopToken top_tokens = 8;
+    TopTokens top_tokens = 8;
 }
 
 message FilterBatchRequest {
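The schema change above replaces the per-token `TopToken` message with a single `TopTokens` message holding parallel arrays (ids, logprobs, texts, is_special), so the Python server ships one flat payload per decode step and the Rust router rebuilds per-token objects. A minimal Python sketch of the two wire shapes and the packing direction, using plain dataclasses as stand-ins for the protobuf-generated classes (the field names follow the proto above; `flatten` is illustrative, not part of the codebase):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TopToken:
    """Old wire shape: one small message per candidate token."""
    token_id: int
    token_logprob: float
    token_text: str
    token_is_special: bool


@dataclass
class TopTokens:
    """New wire shape: one message per decode step, parallel arrays."""
    ids: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)
    texts: List[str] = field(default_factory=list)
    is_special: List[bool] = field(default_factory=list)


def flatten(per_token: List[TopToken]) -> TopTokens:
    # What the Python server now effectively sends: four flat lists instead of
    # a list of small objects, leaving per-token assembly to the Rust router.
    return TopTokens(
        ids=[t.token_id for t in per_token],
        logprobs=[t.token_logprob for t in per_token],
        texts=[t.token_text for t in per_token],
        is_special=[t.token_is_special for t in per_token],
    )


if __name__ == "__main__":
    old_style = [TopToken(13, -0.1, ".", False), TopToken(198, -2.3, "\n", False)]
    print(flatten(old_style))
```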
@@ -167,10 +167,7 @@ impl Infer {
                         .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate{
-                    token,
-                    top_tokens,
-                } => {
+                InferStreamResponse::Intermediate { token, top_tokens } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);
                 }

@@ -182,7 +179,6 @@ impl Infer {
                     start,
                     queued,
                     top_tokens,
-
                 } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);

@@ -203,7 +199,11 @@ impl Infer {
             generated_text,
             queued,
             start,
-            top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
+            top_tokens: if use_top_tokens {
+                result_top_tokens
+            } else {
+                Vec::new()
+            },
         })
     } else {
         let err = InferError::IncompleteGeneration;
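For the non-streaming path, the router keeps collecting one `top_tokens` entry per generated token and only attaches the collected list to the final response when the request actually asked for top-n tokens (`use_top_tokens`). A simplified Python sketch of that accumulation and gating, with a fake event stream standing in for the Rust response channel (names other than `use_top_tokens`, `result_tokens`, and `result_top_tokens` are illustrative):

```python
from typing import Iterable, List, Tuple

# Each stream event is (token_text, top_tokens_for_this_step); the real stream
# also carries ids, logprobs, timing and finish information.
Event = Tuple[str, List[str]]


def collect_generation(events: Iterable[Event], use_top_tokens: bool):
    result_tokens: List[str] = []
    result_top_tokens: List[List[str]] = []

    for token, top_tokens in events:
        # Mirrors the Intermediate / End arms above: every step contributes
        # both the chosen token and its top-token bundle.
        result_tokens.append(token)
        result_top_tokens.append(top_tokens)

    return {
        "tokens": result_tokens,
        # Only surface top tokens when the request asked for them.
        "top_tokens": result_top_tokens if use_top_tokens else [],
    }


if __name__ == "__main__":
    stream = [("Hello", ["Hello", "Hi"]), (" world", [" world", " there"])]
    print(collect_generation(stream, use_top_tokens=True))
```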
@@ -533,16 +533,24 @@ fn send_responses(
         special: generation.token_is_special,
     };
 
+    // generation.top_tokens
+
     let mut top_tokens = Vec::new();
-    for top_token in generation.top_tokens {
-        top_tokens.push(Token{
-            id: top_token.token_id,
-            text: top_token.token_text,
-            logprob: top_token.token_logprob,
-            special: top_token.token_is_special,
-        });
-    }
+    if let Some(top_tokens_) = generation.top_tokens {
+        top_tokens.extend(
+            top_tokens_
+                .ids
+                .into_iter()
+                .zip(top_tokens_.logprobs.into_iter())
+                .zip(top_tokens_.texts.into_iter())
+                .zip(top_tokens_.is_special.into_iter())
+                .map(|(((id, logprob), text), special)| Token {
+                    id,
+                    text,
+                    logprob,
+                    special,
+                })
+        )
+    }
 
     if let Some(generated_text) = generation.generated_text {

@@ -562,7 +570,7 @@ fn send_responses(
     } else {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Intermediate{token, top_tokens}),
+            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
             Duration::from_millis(10),
         )?;
     }
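The `send_responses` change above is where the deferred work lands: instead of receiving ready-made `TopToken` objects from Python, the router zips the four parallel arrays of the `TopTokens` message back into `Token` values. A rough Python rendering of that zip/map chain (the `Token` dataclass here is a stand-in for the router's own type):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Token:
    """Stand-in for the router's Token struct (id, text, logprob, special)."""
    id: int
    text: str
    logprob: float
    special: bool


def tokens_from_parallel_arrays(
    ids: List[int], logprobs: List[float], texts: List[str], is_special: List[bool]
) -> List[Token]:
    # Python equivalent of the Rust iterator chain:
    # ids.zip(logprobs).zip(texts).zip(is_special).map(|...| Token { .. })
    return [
        Token(id=i, text=t, logprob=lp, special=sp)
        for i, lp, t, sp in zip(ids, logprobs, texts, is_special)
    ]


if __name__ == "__main__":
    print(tokens_from_parallel_arrays([13, 198], [-0.1, -2.3], [".", "\n"], [False, False]))
```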
@@ -13,6 +13,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

@@ -126,7 +127,9 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -574,7 +577,9 @@ class CausalLM(Model):
             stopped = True
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )
 
         # Zipped iterator

@@ -652,8 +657,7 @@ class CausalLM(Model):
                 generated_text = None
 
             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1

@@ -672,15 +676,20 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
                 )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
             else:
                 top_tokens = None
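The decode step above routes the last-position probabilities through `batch_top_tokens(batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1))`, a helper whose body is not part of this diff. A rough sketch of the contract it has to satisfy, per-request top ids and logprobs, assuming a plain `torch.topk`; this is an illustration, not the project's actual implementation, and the tensor copy of `top_n_tokens` is omitted for brevity:

```python
from typing import List, Tuple

import torch


def batch_top_tokens_sketch(
    top_n_tokens: List[int], probs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
    """Return per-request top token ids and logprobs from a [batch, vocab] prob tensor."""
    max_n = max(top_n_tokens) if top_n_tokens else 0
    if max_n == 0:
        return [[] for _ in top_n_tokens], [[] for _ in top_n_tokens]

    # One topk over the whole batch, then trim each row to that request's top_n.
    top_probs, top_ids = torch.topk(probs, k=max_n, dim=-1)
    top_logprobs = torch.log(top_probs)

    batch_ids, batch_logprobs = [], []
    for n, ids, logprobs in zip(top_n_tokens, top_ids.tolist(), top_logprobs.tolist()):
        batch_ids.append(ids[:n])
        batch_logprobs.append(logprobs[:n])
    return batch_ids, batch_logprobs


if __name__ == "__main__":
    probs = torch.softmax(torch.randn(2, 10), dim=-1)
    print(batch_top_tokens_sketch([3, 0], probs))
```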
@@ -17,7 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
-    TopToken,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

@@ -358,7 +358,9 @@ class FlashCausalLMBatch(Batch):
         prefill_next_token_indices = torch.tensor(
             prefill_next_token_indices, dtype=torch.int64, device=device
         )
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         return cls(
             batch_id=pb.id,

@@ -1039,15 +1041,20 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1],
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
                 )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
             else:
                 top_tokens = None
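As in `CausalLM` above, `FlashCausalLM` now decodes all candidate ids for a request in one `batch_decode` call, flags which ids are special tokens, and wraps the parallel lists in `TopTokens`. A self-contained sketch of that per-request flow, with a toy tokenizer and a plain dataclass standing in for transformers and the real `TopTokens` type so it runs on its own:

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class TopTokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]


class ToyTokenizer:
    """Tiny stand-in for a HF tokenizer's batch_decode (ids decoded one per entry)."""
    vocab = {0: "<s>", 1: "</s>", 13: ".", 52: " the", 198: "\n"}

    def batch_decode(self, ids, clean_up_tokenization_spaces=False, skip_special_tokens=False):
        return [self.vocab.get(i, f"<unk:{i}>") for i in ids]


def build_top_tokens(
    tokenizer, all_special_ids: Set[int],
    top_token_ids: List[int], top_token_logprobs: List[float], top_n_tokens: int,
) -> Optional[TopTokens]:
    if top_n_tokens <= 0:
        return None
    toptoken_texts = tokenizer.batch_decode(
        top_token_ids,
        clean_up_tokenization_spaces=False,
        skip_special_tokens=False,  # keep special tokens visible; they are flagged instead
    )
    special_toptokens = [token_id in all_special_ids for token_id in top_token_ids]
    return TopTokens(top_token_ids, top_token_logprobs, toptoken_texts, special_toptokens)


if __name__ == "__main__":
    tok = ToyTokenizer()
    print(build_top_tokens(tok, {0, 1}, [52, 13, 1], [-0.2, -1.1, -3.0], top_n_tokens=3))
```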
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig
 
-from text_generation_server.models.types import Batch, GeneratedText, TopToken
+from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)

@@ -86,76 +86,6 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset
 
-    def decode_tokens(
-        self,
-        input_ids: List[int],
-        new_input_ids: List[int],
-        prefix_offset: int = 0,
-        read_offset: int = 0,
-    ) -> Tuple[str, int, int]:
-        """Version of decode_token that supports multiple new tokens for the same prefix."""
-
-        # The prefix text is necessary only to defeat cleanup algorithms in the decode
-        # which decide to add a space or not depending on the surrounding ids.
-        prefix_text = self.tokenizer.decode(
-            input_ids[prefix_offset:read_offset], skip_special_tokens=False
-        )
-
-        new_sequences = [
-            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
-        ]
-        new_texts = self.tokenizer.batch_decode(
-            new_sequences, skip_special_tokens=False
-        )
-
-        prefix_len = len(prefix_text)
-        results = []
-        for new_text in new_texts:
-            if len(new_text) > prefix_len and not new_text.endswith("�"):
-                # utf-8 char at the end means it's a potential unfinished byte sequence
-                # from byte fallback tokenization.
-                # If it's in the middle, it's probably a real invalid id generated
-                # by the model
-                new_text = new_text[prefix_len:]
-                results.append((new_text, read_offset, len(input_ids) + 1))
-            else:
-                results.append(("", prefix_offset, read_offset))
-        return results
-
-    def decode_top_tokens(
-        self,
-        input_ids,
-        top_n_tokens,
-        top_token_ids,
-        top_token_logprobs,
-        prefix_offset,
-        read_offset,
-    ):
-        if top_n_tokens == 0:
-            return []
-
-        top_token_texts = self.decode_tokens(
-            input_ids=input_ids,
-            new_input_ids=top_token_ids,
-            prefix_offset=prefix_offset,
-            read_offset=read_offset,
-        )
-
-        top_tokens = []
-        for token_id, (top_token_text, _, _), token_logprob in zip(
-            top_token_ids, top_token_texts, top_token_logprobs
-        ):
-            tok_itm = token_id
-            top_tokens.append(
-                TopToken(
-                    token_id=token_id,
-                    token_logprob=token_logprob,
-                    token_text=top_token_text,
-                    token_is_special=tok_itm in self.all_special_ids,
-                )
-            )
-        return top_tokens
-
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
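The removed `decode_tokens`/`decode_top_tokens` helpers re-decoded each candidate id together with its prefix (so tokenizer cleanup would not alter leading spaces) and then built `TopToken` objects in Python. With the texts now treated as display-only and the per-token assembly moved to Rust, a single direct `batch_decode` of the candidate ids is enough. A toy comparison of the two decode strategies; the fake tokenizer below is purely illustrative:

```python
from typing import List


class FakeTokenizer:
    """Illustrative only: maps ids to strings, roughly like a HF tokenizer would."""
    vocab = {7: "Hello", 11: " world", 13: "."}

    def decode(self, ids: List[int], skip_special_tokens=False) -> str:
        return "".join(self.vocab.get(i, "?") for i in ids)

    def batch_decode(self, batch, skip_special_tokens=False, clean_up_tokenization_spaces=False):
        return [self.decode(ids if isinstance(ids, list) else [ids]) for ids in batch]


tok = FakeTokenizer()
prefix_ids = [7]        # already-generated context
candidates = [11, 13]   # top-n candidate ids for the next position

# Old approach (removed here): decode prefix + candidate, then strip the prefix text,
# so surrounding ids influence spacing exactly as in incremental decoding.
prefix_text = tok.decode(prefix_ids)
old_texts = [tok.decode(prefix_ids + [c])[len(prefix_text):] for c in candidates]

# New approach: decode the candidate ids directly; good enough for display purposes.
new_texts = tok.batch_decode(candidates)

print(old_texts, new_texts)
```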
@@ -12,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

@@ -130,7 +131,9 @@ class Seq2SeqLMBatch(Batch):
         prefix_offsets.append(0)
         read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -637,7 +640,9 @@ class Seq2SeqLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )
 
         # Finished requests

@@ -722,8 +727,7 @@ class Seq2SeqLM(Model):
                 generated_text = None
 
             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],

@@ -732,15 +736,20 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill. How to implement in API?
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
                 )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
             else:
                 top_tokens = None

@@ -72,28 +72,23 @@ class PrefillTokens:
         return len(self.token_ids)
 
 
-@dataclass(eq=True)
-@total_ordering
-class TopToken:
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
 
-    def __gt__(self, other):
-        # We tiebreak equal logprobs with the _lower_ token_id to align with
-        # greedy ordering (torch.argmax)
-        return self.token_logprob > other.token_logprob or (
-            self.token_logprob == other.token_logprob and self.token_id < other.token_id
-        )
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
+        )
 
-    def to_pb(self) -> generate_pb2.TopToken:
-        return generate_pb2.TopToken(
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
-        )
+    def __len__(self):
+        return len(self.token_ids)
 
 
 @dataclass

@@ -106,7 +101,7 @@ class Generation:
     token_is_special: bool
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[List[TopToken]]
+    top_tokens: Optional[TopTokens]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(

@@ -121,7 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
-            if self.top_tokens
-            else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )
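End to end, the Python side now converts at most one `TopTokens` bundle per generation step instead of a list of `TopToken` objects. A small usage sketch of the new dataclass and its `to_pb` hook, with a plain dataclass standing in for the generated `generate_pb2.TopTokens` message:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PbTopTokens:
    """Stand-in for the generated generate_pb2.TopTokens message."""
    ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]


@dataclass
class TopTokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def to_pb(self) -> PbTopTokens:
        return PbTopTokens(
            ids=self.token_ids,
            logprobs=self.logprobs,
            texts=self.texts,
            is_special=self.is_special,
        )

    def __len__(self) -> int:
        return len(self.token_ids)


if __name__ == "__main__":
    top_tokens: Optional[TopTokens] = TopTokens([13, 198], [-0.1, -2.3], [".", "\n"], [False, False])
    # Same gating as Generation.to_pb above: serialize once, or send nothing.
    pb = top_tokens.to_pb() if top_tokens is not None else None
    print(len(top_tokens), pb)
```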