diff --git a/proto/generate.proto b/proto/generate.proto
index 89ca6e08..45ba8da5 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -143,6 +143,17 @@ message PrefillTokens {
     repeated string texts = 3;
 }
 
+message TopTokens {
+    /// Top Token IDs
+    repeated uint32 ids = 1;
+    /// Top Logprobs
+    repeated float logprobs = 2;
+    /// Top Token Texts
+    repeated string texts = 3;
+    /// If the tokens are special
+    repeated bool is_special = 6;
+}
+
 message TopToken {
     /// Token ID
     uint32 token_id = 3;
@@ -170,7 +181,8 @@ message Generation {
     /// Complete generated text
     optional GeneratedText generated_text = 7;
     /// Top tokens
-    repeated TopToken top_tokens = 8;
+    // repeated TopToken top_tokens = 8;
+    TopTokens top_tokens = 8;
 }
 
 message FilterBatchRequest {
diff --git a/router/src/infer.rs b/router/src/infer.rs
index b88dfb2a..c723796d 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -167,10 +167,7 @@ impl Infer {
                         .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate{
-                    token,
-                    top_tokens,
-                } => {
+                InferStreamResponse::Intermediate { token, top_tokens } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);
                 }
@@ -182,7 +179,6 @@ impl Infer {
                     start,
                     queued,
                     top_tokens,
-
                 } => {
                     result_tokens.push(token);
                     result_top_tokens.push(top_tokens);
@@ -203,7 +199,11 @@ impl Infer {
                 generated_text,
                 queued,
                 start,
-                top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
+                top_tokens: if use_top_tokens {
+                    result_top_tokens
+                } else {
+                    Vec::new()
+                },
             })
         } else {
             let err = InferError::IncompleteGeneration;
@@ -533,16 +533,24 @@ fn send_responses(
         special: generation.token_is_special,
     };
 
-    // generation.top_tokens
+
 
     let mut top_tokens = Vec::new();
-    for top_token in generation.top_tokens {
-        top_tokens.push(Token{
-            id: top_token.token_id,
-            text: top_token.token_text,
-            logprob: top_token.token_logprob,
-            special: top_token.token_is_special,
-        })
+    if let Some(top_tokens_) = generation.top_tokens {
+        top_tokens.extend(
+            top_tokens_
+                .ids
+                .into_iter()
+                .zip(top_tokens_.logprobs.into_iter())
+                .zip(top_tokens_.texts.into_iter())
+                .zip(top_tokens_.is_special.into_iter())
+                .map(|(((id, logprob), text), special)| Token {
+                    id,
+                    text,
+                    logprob,
+                    special,
+                })
+        )
     }
 
     if let Some(generated_text) = generation.generated_text {
@@ -562,7 +570,7 @@
     } else {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Intermediate{token, top_tokens}),
+            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
             Duration::from_millis(10),
         )?;
     }
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 807c39de..4e338263 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -13,6 +13,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -126,7 +127,9 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -574,7 +577,9 @@ class CausalLM(Model):
                 stopped = True
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )
 
         # Zipped iterator
@@ -652,8 +657,7 @@ class CausalLM(Model):
                 generated_text = None
 
             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
@@ -672,15 +676,20 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3fea3e0e..742e8430 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -17,7 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
-    TopToken,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -358,7 +358,9 @@ class FlashCausalLMBatch(Batch):
         prefill_next_token_indices = torch.tensor(
             prefill_next_token_indices, dtype=torch.int64, device=device
         )
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         return cls(
             batch_id=pb.id,
@@ -1039,15 +1041,20 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_input_ids[:-1],
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index a0b7e96d..06229f35 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig
 
-from text_generation_server.models.types import Batch, Generation, TopToken
+from text_generation_server.models.types import Batch, Generation
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
@@ -86,76 +86,6 @@ class Model(ABC):
         else:
             return "", prefix_offset, read_offset
 
-    def decode_tokens(
-        self,
-        input_ids: List[int],
-        new_input_ids: List[int],
-        prefix_offset: int = 0,
-        read_offset: int = 0,
-    ) -> Tuple[str, int, int]:
-        """Version of decode_token that supports multiple new tokens for the same prefix."""
-
-        # The prefix text is necessary only to defeat cleanup algorithms in the decode
-        # which decide to add a space or not depending on the surrounding ids.
-        prefix_text = self.tokenizer.decode(
-            input_ids[prefix_offset:read_offset], skip_special_tokens=False
-        )
-
-        new_sequences = [
-            input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids
-        ]
-        new_texts = self.tokenizer.batch_decode(
-            new_sequences, skip_special_tokens=False
-        )
-
-        prefix_len = len(prefix_text)
-        results = []
-        for new_text in new_texts:
-            if len(new_text) > prefix_len and not new_text.endswith("�"):
-                # utf-8 char at the end means it's a potential unfinished byte sequence
-                # from byte fallback tokenization.
-                # If it's in the middle, it's probably a real invalid id generated
-                # by the model
-                new_text = new_text[prefix_len:]
-                results.append((new_text, read_offset, len(input_ids) + 1))
-            else:
-                results.append(("", prefix_offset, read_offset))
-        return results
-
-    def decode_top_tokens(
-        self,
-        input_ids,
-        top_n_tokens,
-        top_token_ids,
-        top_token_logprobs,
-        prefix_offset,
-        read_offset,
-    ):
-        if top_n_tokens == 0:
-            return []
-
-        top_token_texts = self.decode_tokens(
-            input_ids=input_ids,
-            new_input_ids=top_token_ids,
-            prefix_offset=prefix_offset,
-            read_offset=read_offset,
-        )
-
-        top_tokens = []
-        for token_id, (top_token_text, _, _), token_logprob in zip(
-            top_token_ids, top_token_texts, top_token_logprobs
-        ):
-            tok_itm = token_id
-            top_tokens.append(
-                TopToken(
-                    token_id=token_id,
-                    token_logprob=token_logprob,
-                    token_text=top_token_text,
-                    token_is_special=tok_itm in self.all_special_ids,
-                )
-            )
-        return top_tokens
-
     def check_initialized(self):
         uninitialized_parameters = []
         for n, p in self.model.named_parameters():
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 096ab82f..361453fb 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -12,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -130,7 +131,9 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -637,7 +640,9 @@ class Seq2SeqLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
         )
 
         # Finished requests
@@ -722,8 +727,7 @@ class Seq2SeqLM(Model):
                 generated_text = None
 
             # Prefill
-            prefill = stopping_criteria.current_tokens == 1
-            if prefill and request.prefill_logprobs:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
@@ -732,15 +736,20 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None
 
-            # Todo: Make optional for prefill. How to implement in API?
-            if not prefill and top_n_tokens > 0:
-                top_tokens = self.decode_top_tokens(
-                    input_ids=all_decoder_input_ids[:-1].view(-1).tolist(),
-                    top_n_tokens=top_n_tokens,
-                    top_token_ids=top_token_ids,
-                    top_token_logprobs=top_token_logprobs,
-                    prefix_offset=prefix_offset,
-                    read_offset=read_offset,
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
                 )
             else:
                 top_tokens = None
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 822335ff..0e27680d 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -72,28 +72,23 @@ class PrefillTokens:
         return len(self.token_ids)
 
 
-@dataclass(eq=True)
-@total_ordering
-class TopToken:
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
 
-    def __gt__(self, other):
-        # We tiebreak equal logprobs with the _lower_ token_id to align with
-        # greedy ordering (torch.argmax)
-        return self.token_logprob > other.token_logprob or (
-            self.token_logprob == other.token_logprob and self.token_id < other.token_id
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
         )
 
-    def to_pb(self) -> generate_pb2.TopToken:
-        return generate_pb2.TopToken(
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
-        )
+    def __len__(self):
+        return len(self.token_ids)
 
 
 @dataclass
@@ -106,7 +101,7 @@ class Generation:
     token_is_special: bool
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[List[TopToken]]
+    top_tokens: Optional[TopTokens]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -121,7 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
-            if self.top_tokens
-            else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
        )
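The net effect of the patch: the Python server now ships a step's top tokens as a single struct-of-arrays TopTokens message (parallel ids/logprobs/texts/is_special lists), and the Rust router zips those lists back into per-token Token values before responding. A minimal Python sketch of that round trip, using a plain dataclass and dicts purely for illustration; TopTokensPayload and to_per_token are invented names, not part of the codebase.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class TopTokensPayload:
    """Struct-of-arrays shape, mirroring the proto TopTokens message."""

    ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]


def to_per_token(payload: TopTokensPayload) -> List[Dict]:
    """Array-of-structs view, mirroring the zip/map chain in the Rust router."""
    return [
        {"id": i, "logprob": lp, "text": t, "special": s}
        for i, lp, t, s in zip(
            payload.ids, payload.logprobs, payload.texts, payload.is_special
        )
    ]


# Example with made-up values for a single decoding step.
payload = TopTokensPayload(
    ids=[42, 7],
    logprobs=[-0.11, -2.35],
    texts=["Hello", "Hi"],
    is_special=[False, False],
)
print(to_per_token(payload))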