Mirror of https://github.com/huggingface/text-generation-inference.git
Defer building top-token objects to Rust

commit 8471e1862d (parent 730d86f1d0)
proto/generate.proto

@@ -143,6 +143,17 @@ message PrefillTokens {
     repeated string texts = 3;
 }

+message TopTokens {
+    /// Top Token IDs
+    repeated uint32 ids = 1;
+    /// Top Logprobs
+    repeated float logprobs = 2;
+    /// Top Token Texts
+    repeated string texts = 3;
+    /// If the tokens are special
+    repeated bool is_special = 6;
+}
+
 message TopToken {
     /// Token ID
     uint32 token_id = 3;

@@ -170,7 +181,8 @@ message Generation {
     /// Complete generated text
     optional GeneratedText generated_text = 7;
     /// Top tokens
-    repeated TopToken top_tokens = 8;
+    // repeated TopToken top_tokens = 8;
+    TopTokens top_tokens = 8;
 }

 message FilterBatchRequest {
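The schema change batches what was previously one TopToken message per candidate into four parallel arrays, so each generation carries a single TopTokens message regardless of how many candidates were requested. A minimal sketch of what the Python server now serializes, using the field names from the message above; the concrete ids, logprobs and texts are made-up illustrative values:

    from text_generation_server.pb import generate_pb2

    # One TopTokens message per generated position: the n best candidates
    # are carried as parallel lists instead of n separate TopToken messages.
    top_tokens = generate_pb2.TopTokens(
        ids=[42, 7, 1024],
        logprobs=[-0.11, -2.35, -4.02],
        texts=[" the", " a", " an"],
        is_special=[False, False, False],
    )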
router/src/infer.rs

@@ -167,10 +167,7 @@ impl Infer {
                         .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate{
-                    token,
-                    top_tokens,
-                } => {
+                InferStreamResponse::Intermediate { token, top_tokens } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);
                 }

@@ -182,7 +179,6 @@ impl Infer {
                     start,
                     queued,
                     top_tokens,
-
                 } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);

@@ -203,7 +199,11 @@ impl Infer {
                 generated_text,
                 queued,
                 start,
-                top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
+                top_tokens: if use_top_tokens {
+                    result_top_tokens
+                } else {
+                    Vec::new()
+                },
             })
         } else {
             let err = InferError::IncompleteGeneration;

@@ -533,16 +533,24 @@ fn send_responses(
         special: generation.token_is_special,
     };

     // generation.top_tokens

     let mut top_tokens = Vec::new();
-    for top_token in generation.top_tokens {
-        top_tokens.push(Token {
-            id: top_token.token_id,
-            text: top_token.token_text,
-            logprob: top_token.token_logprob,
-            special: top_token.token_is_special,
-        })
+    if let Some(top_tokens_) = generation.top_tokens {
+        top_tokens.extend(
+            top_tokens_
+                .ids
+                .into_iter()
+                .zip(top_tokens_.logprobs.into_iter())
+                .zip(top_tokens_.texts.into_iter())
+                .zip(top_tokens_.is_special.into_iter())
+                .map(|(((id, logprob), text), special)| Token {
+                    id,
+                    text,
+                    logprob,
+                    special,
+                })
+        )
     }

     if let Some(generated_text) = generation.generated_text {

@@ -562,7 +570,7 @@ fn send_responses(
     } else {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Intermediate{token, top_tokens}),
+            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
             Duration::from_millis(10),
         )?;
     }
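The router-side zip is where per-token objects are now built, hence the commit title. For illustration only, here is a Python analogue of that quadruple zip (the rest of this diff is Python); the concrete values are hypothetical:

    # Rebuild one record per candidate from the four parallel lists carried
    # by a TopTokens message, mirroring the Rust .zip chain above.
    ids = [42, 7, 1024]
    logprobs = [-0.11, -2.35, -4.02]
    texts = [" the", " a", " an"]
    is_special = [False, False, False]

    top_tokens = [
        {"id": i, "text": t, "logprob": lp, "special": sp}
        for i, lp, t, sp in zip(ids, logprobs, texts, is_special)
    ]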
server/text_generation_server/models/causal_lm.py

@@ -13,6 +13,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

@@ -126,7 +127,9 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -574,7 +577,9 @@ class CausalLM(Model):
             stopped = True

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )

         # Zipped iterator

@@ -652,8 +657,7 @@ class CausalLM(Model):
                 generated_text = None

             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1

@@ -672,15 +676,20 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None

-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None
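batch_top_tokens itself is not part of this commit, so its implementation is not shown here; as a rough, assumed sketch of the shape of its output (one list of candidate ids and one list of their logprobs per request, driven by the per-request top_n_tokens values), a toy stand-in compatible with the call sites above might look like:

    import torch

    def toy_batch_top_tokens(top_n_tokens, probs):
        """Toy stand-in only (an assumption, not TGI's implementation):
        return the n most probable token ids and their logprobs per request."""
        batch_ids, batch_logprobs = [], []
        for n, row in zip(top_n_tokens, probs):
            if n == 0:
                batch_ids.append([])
                batch_logprobs.append([])
                continue
            values, ids = torch.topk(row, k=n)
            batch_ids.append(ids.tolist())
            batch_logprobs.append(torch.log(values).tolist())
        return batch_ids, batch_logprobs

    # The callers above pass torch.softmax(logits[:, -1], -1) as `probs`.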
server/text_generation_server/models/flash_causal_lm.py

@@ -17,7 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
-    TopToken,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

@@ -358,7 +358,9 @@ class FlashCausalLMBatch(Batch):
         prefill_next_token_indices = torch.tensor(
             prefill_next_token_indices, dtype=torch.int64, device=device
         )
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         return cls(
             batch_id=pb.id,

@@ -1039,15 +1041,20 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None

-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1],
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None
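The same construction appears in the causal, flash and seq2seq paths: decode each candidate id on its own, flag specials, and wrap the four parallel lists. A self-contained sketch of that path, where the tokenizer checkpoint and the candidate values are assumptions chosen only for illustration (the server uses its own self.tokenizer and self.all_special_ids):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only
    top_token_ids = [464, 317, 1052]                   # hypothetical candidate ids
    top_token_logprobs = [-0.11, -2.35, -4.02]         # hypothetical logprobs

    # batch_decode receives one id per entry here, so each candidate is
    # decoded independently of the sequence it would extend.
    toptoken_texts = tokenizer.batch_decode(
        top_token_ids,
        clean_up_tokenization_spaces=False,
        skip_special_tokens=False,
    )
    special_toptokens = [
        token_id in tokenizer.all_special_ids for token_id in top_token_ids
    ]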
server/text_generation_server/models/model.py

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

-from text_generation_server.models.types import Batch, GeneratedText, TopToken
+from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)

@@ -86,76 +86,6 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset

-    def decode_tokens(
-        self,
-        input_ids: List[int],
-        new_input_ids: List[int],
-        prefix_offset: int = 0,
-        read_offset: int = 0,
-    ) -> Tuple[str, int, int]:
-        """Version of decode_token that supports multiple new tokens for the same prefix."""
-
-        # The prefix text is necessary only to defeat cleanup algorithms in the decode
-        # which decide to add a space or not depending on the surrounding ids.
-        prefix_text = self.tokenizer.decode(
-            input_ids[prefix_offset:read_offset], skip_special_tokens=False
-        )
-
-        new_sequences = [
-            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
-        ]
-        new_texts = self.tokenizer.batch_decode(
-            new_sequences, skip_special_tokens=False
-        )
-
-        prefix_len = len(prefix_text)
-        results = []
-        for new_text in new_texts:
-            if len(new_text) > prefix_len and not new_text.endswith("�"):
-                # utf-8 char at the end means it's a potential unfinished byte sequence
-                # from byte fallback tokenization.
-                # If it's in the middle, it's probably a real invalid id generated
-                # by the model
-                new_text = new_text[prefix_len:]
-                results.append((new_text, read_offset, len(input_ids) + 1))
-            else:
-                results.append(("", prefix_offset, read_offset))
-        return results
-
-    def decode_top_tokens(
-        self,
-        input_ids,
-        top_n_tokens,
-        top_token_ids,
-        top_token_logprobs,
-        prefix_offset,
-        read_offset,
-    ):
-        if top_n_tokens == 0:
-            return []
-
-        top_token_texts = self.decode_tokens(
-            input_ids=input_ids,
-            new_input_ids=top_token_ids,
-            prefix_offset=prefix_offset,
-            read_offset=read_offset,
-        )
-
-        top_tokens = []
-        for token_id, (top_token_text, _, _), token_logprob in zip(
-            top_token_ids, top_token_texts, top_token_logprobs
-        ):
-            tok_itm = token_id
-            top_tokens.append(
-                TopToken(
-                    token_id=token_id,
-                    token_logprob=token_logprob,
-                    token_text=top_token_text,
-                    token_is_special=tok_itm in self.all_special_ids,
-                )
-            )
-        return top_tokens
-
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
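The removed decode_tokens helper decoded each candidate together with its prefix precisely to defeat tokenizer cleanup that adds or drops a leading space depending on the surrounding ids; the new path decodes each candidate id on its own. A small illustration of the two decoding strategies side by side; the checkpoint and strings are arbitrary examples, and whether the outputs actually differ depends on the tokenizer:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only

    prefix_ids = tokenizer.encode("Hello world")
    candidate_id = tokenizer.encode(" again")[0]

    # Old strategy: decode prefix + candidate, then strip the prefix text,
    # so the candidate keeps the surface form it has in context.
    with_prefix = tokenizer.decode(prefix_ids + [candidate_id])
    in_context = with_prefix[len(tokenizer.decode(prefix_ids)):]

    # New strategy: decode the candidate id in isolation.
    alone = tokenizer.decode([candidate_id])

    print(repr(in_context), repr(alone))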
server/text_generation_server/models/seq2seq_lm.py

@@ -12,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

@@ -130,7 +131,9 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -637,7 +640,9 @@ class Seq2SeqLM(Model):
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )

         # Finished requests

@@ -722,8 +727,7 @@ class Seq2SeqLM(Model):
                 generated_text = None

             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],

@@ -732,15 +736,20 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None

-            # Todo: Make optional for prefill. How to implement in API?
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None
server/text_generation_server/models/types.py

@@ -72,28 +72,23 @@ class PrefillTokens:
         return len(self.token_ids)


-@dataclass(eq=True)
-@total_ordering
-class TopToken:
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
-
-    def __gt__(self, other):
-        # We tiebreak equal logprobs with the _lower_ token_id to align with
-        # greedy ordering (torch.argmax)
-        return self.token_logprob > other.token_logprob or (
-            self.token_logprob == other.token_logprob and self.token_id < other.token_id
-        )
-
-    def to_pb(self) -> generate_pb2.TopToken:
-        return generate_pb2.TopToken(
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
-        )
+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
+
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
+        )
+
+    def __len__(self):
+        return len(self.token_ids)


 @dataclass

@@ -106,7 +101,7 @@ class Generation:
     token_is_special: bool
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[List[TopToken]]
+    top_tokens: Optional[TopTokens]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(

@@ -121,7 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
-            if self.top_tokens
-            else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )
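Taken together, the server-side type is now a thin batch container that serializes once per generated position. A short usage sketch; the field values are hypothetical, and the import path is the one used throughout this diff:

    from text_generation_server.models.types import TopTokens

    top_tokens = TopTokens(
        token_ids=[42, 7, 1024],          # hypothetical candidate ids
        logprobs=[-0.11, -2.35, -4.02],
        texts=[" the", " a", " an"],
        is_special=[False, False, False],
    )

    assert len(top_tokens) == 3            # __len__ delegates to token_ids
    pb = top_tokens.to_pb()                # one generate_pb2.TopTokens message
    # Generation.to_pb() then attaches it (or None) once per generation:
    #   top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None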