From 0b96da89aad5a178598b648a8a6b02dad7370bd3 Mon Sep 17 00:00:00 2001 From: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com> Date: Fri, 19 Jan 2024 15:12:04 +0100 Subject: [PATCH] Make tokenizer optional (#12) --- README.md | 1 + router/src/main.rs | 10 +++- server/tests/utils/test_tokens.py | 34 +++++++++++++- .../models/causal_lm.py | 17 ++++++- .../text_generation_server/utils/__init__.py | 2 + server/text_generation_server/utils/tokens.py | 46 +++++++++++++++++++ 6 files changed, 106 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e1bf94ff..f63cd319 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,7 @@ Environment Variables Added: | LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command | | BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command | | PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command | +| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command | | TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command | diff --git a/router/src/main.rs b/router/src/main.rs index d90632ef..c4406ab6 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -7,6 +7,7 @@ use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::Resource; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; +use std::env; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; use std::time::Duration; @@ -141,7 +142,10 @@ fn main() -> Result<(), RouterError> { // This will only be used to validate payloads let local_path = Path::new(&tokenizer_name); let local_model = local_path.exists() && local_path.is_dir(); - let tokenizer = if local_model { + let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI").ok().map_or(false, |value| value.to_lowercase() == "true"); + let tokenizer = if skip_tokenizer_in_tgi { + None + } else if local_model { // Load local tokenizer Tokenizer::from_file(local_path.join("tokenizer.json")).ok() } else { @@ -162,7 +166,9 @@ fn main() -> Result<(), RouterError> { .block_on(async { init_logging(otlp_endpoint, json_output); - if tokenizer.is_none() { + if skip_tokenizer_in_tgi { + tracing::warn!("Rust input length validation disabled by environment variable"); + } else if tokenizer.is_none() { tracing::warn!( "Could not find a fast tokenizer implementation for {tokenizer_name}" ); diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 0585f1fb..8ba775a7 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -4,8 +4,9 @@ from text_generation_server.utils.tokens import ( StoppingCriteria, FinishReason, batch_top_tokens, + make_tokenizer_optional, ) - +from transformers import AutoTokenizer def test_stop_sequence_criteria(): criteria = StopSequenceCriteria("/test;") @@ -66,3 +67,34 @@ def test_batch_top_tokens(): assert topn_tok_logprobs[2] == [-1, -2, -3, -3] assert topn_tok_logprobs[3] == [-1, -2, -3, -3] assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4] + + + +def 
test_pass_through_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained( + 'meta-llama/Llama-2-7b-chat-hf', + revision=None, + padding_side="left", + truncation_side="left", + ) + tokenizer.pad_token_id = 2 + make_tokenizer_optional(tokenizer) + + input = ["1, 1724, 338, 6483, 6509, 29973", "?"] + tokenized_inputs = tokenizer( + input, + return_tensors="pt", + padding="max_length", + return_token_type_ids=False, + truncation=True, + max_length=1024, + ) + assert tokenized_inputs['input_ids'].size() == torch.Size([2, 1024]) + assert torch.equal(tokenized_inputs['input_ids'][0][1018:], torch.tensor([1, 1724, 338, 6483, 6509, 29973])) + assert torch.equal(tokenized_inputs['input_ids'][1][1023:], torch.tensor([tokenizer.pad_token_id])) + decoded_tokens = tokenizer.decode(tokenized_inputs["input_ids"][0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + assert decoded_tokens.split(',')[1018:] == ['1', '1724', '338', '6483', '6509', '29973'] + + +if __name__ == "__main__": + test_pass_through_tokenizer() \ No newline at end of file diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ec0793f7..bc4a0366 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -32,7 +32,7 @@ from text_generation_server.models.types import ( TopTokens, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent from loguru import logger tracer = trace.get_tracer(__name__) @@ -141,6 +141,7 @@ class CausalLMRequest: self.idx = new_idx return (new_idx, prev) + @dataclass class CausalLMBatch(Batch): batch_id: int @@ -446,6 +447,7 @@ class CausalLM(Model): padding_side="left", truncation_side="left", ) + make_tokenizer_optional(tokenizer) model_kwargs = { "revision": revision, @@ -562,6 +564,19 @@ class CausalLM(Model): def decode(self, generated_ids: List[int]) -> str: return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + def decode_token( + self, + all_input_ids: List[int], + prefix_offset: int = 0, + read_offset: int = 0, + ) -> Tuple[str, int, int]: + if is_tokenizer_transparent(self.tokenizer): + new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) + return new_text, read_offset, len(all_input_ids) + else: + return super().decode_token(all_input_ids, prefix_offset, read_offset) + + def forward( self, input_ids, diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 08ba808d..6acbe22d 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -18,6 +18,8 @@ from text_generation_server.utils.tokens import ( FinishReason, Sampling, Greedy, + make_tokenizer_optional, + is_tokenizer_transparent ) __all__ = [ diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index f9797195..d82a7f80 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import ( ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor from transformers import 
PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers.utils import to_py_obj
 
 
 class NextTokenChooser:
@@ -359,3 +360,48 @@ def batch_top_tokens(
         [idxs[:n] if req_n > 0 else [] for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)],
         [vals[:n] if req_n > 0 else [] for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)],
     )
+
+
+def make_tokenizer_optional(tokenizer):
+    class _(type(tokenizer)):
+        def __call__(
+            self,
+            text,
+            return_tensors,
+            padding,
+            return_token_type_ids,
+            truncation,
+            max_length
+        ):
+            assert return_tensors == "pt", "incorrect input arguments when calling TransparentTokenizer"
+            assert padding == "max_length", "incorrect input arguments when calling TransparentTokenizer"
+            assert return_token_type_ids is False, "incorrect input arguments when calling TransparentTokenizer"
+            assert truncation is True, "incorrect input arguments when calling TransparentTokenizer"
+
+            def str_token_to_int(i):
+                if i == '?':
+                    return tokenizer.pad_token_id
+                else:
+                    return int(i)
+            all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
+                          for inner_text in text]
+            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]),
+                    "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])}
+
+        def decode(
+            self,
+            token_ids,
+            skip_special_tokens: bool = False,
+            clean_up_tokenization_spaces: bool = None,
+            **kwargs,
+        ) -> str:
+            return ','.join(str(i) for i in to_py_obj(token_ids))
+
+    import os
+    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
+        tokenizer.__class__ = _
+        tokenizer.is_transparent = True
+
+
+def is_tokenizer_transparent(tokenizer):
+    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
\ No newline at end of file
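-- 
With SKIP_TOKENIZER_IN_TGI=true the router skips tokenizer-based validation and the server's transparent tokenizer expects request inputs as comma-separated token IDs ('?' stands for the pad token) and returns generated token IDs in the same comma-separated form, so tokenization and detokenization move to the client. Below is a minimal client-side sketch of that flow against the non-streaming /generate endpoint; the endpoint URL, model name, and generation parameters are illustrative assumptions, not part of this patch.

    # Client-side sketch for SKIP_TOKENIZER_IN_TGI=true (assumed: local endpoint
    # URL, model name, and generation parameters; not part of this patch).
    import requests
    from transformers import AutoTokenizer

    TGI_URL = "http://127.0.0.1:8080/generate"   # assumed local TGI endpoint
    MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"   # assumed to match the served model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Tokenize on the client and send the IDs as a comma-separated string,
    # which the server-side transparent tokenizer parses back into tensors.
    prompt = "What is deep learning?"
    input_ids = tokenizer(prompt)["input_ids"]
    payload = {
        "inputs": ", ".join(str(i) for i in input_ids),
        "parameters": {"max_new_tokens": 32},
    }

    response = requests.post(TGI_URL, json=payload, timeout=60)
    response.raise_for_status()

    # With the tokenizer skipped, generated_text is also a comma-separated ID
    # string, so the client decodes it back to text itself.
    generated_ids = [int(i) for i in response.json()["generated_text"].split(",") if i.strip()]
    print(tokenizer.decode(generated_ids, skip_special_tokens=True))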