Make tokenizer optional (#12)

Adam Stachowicz 2024-01-19 15:12:04 +01:00 committed by GitHub
parent 381ec38cad
commit 0b96da89aa
6 changed files with 106 additions and 4 deletions

View File

@ -79,6 +79,7 @@ Environment Variables Added:
| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory; set to `True` for large sequence/decoding lengths (e.g. 300/212) | add -e in docker run command |
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing; inputs and outputs are passed as comma-separated token ids | add -e in docker run command |
| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
</div>
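
With `SKIP_TOKENIZER_IN_TGI` enabled, the client takes over tokenization and sends token ids instead of text. A minimal request sketch, assuming a server started with `-e SKIP_TOKENIZER_IN_TGI=true` and reachable at `localhost:8080` (the host, port, and generation parameters here are assumptions, not part of this change):

```python
# Minimal sketch: pre-tokenized request against TGI's /generate endpoint when
# the server was launched with -e SKIP_TOKENIZER_IN_TGI=true.
import requests

# With the tokenizer skipped, "inputs" is a comma-separated list of token ids
# rather than raw text (see test_pass_through_tokenizer below).
payload = {
    "inputs": "1, 1724, 338, 6483, 6509, 29973",
    "parameters": {"max_new_tokens": 20},
}

resp = requests.post("http://localhost:8080/generate", json=payload, timeout=60)
resp.raise_for_status()

# The pass-through decode joins generated token ids with commas, so the
# returned text is itself a comma-separated id list to be decoded client-side.
print(resp.json()["generated_text"])
```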

View File

@ -7,6 +7,7 @@ use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::env;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
@ -141,7 +142,10 @@ fn main() -> Result<(), RouterError> {
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
    let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI").ok().map_or(false, |value| value.to_lowercase() == "true");
    let tokenizer = if skip_tokenizer_in_tgi {
        None
    } else if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
@ -162,7 +166,9 @@ fn main() -> Result<(), RouterError> {
    .block_on(async {
        init_logging(otlp_endpoint, json_output);
        if tokenizer.is_none() {
        if skip_tokenizer_in_tgi {
            tracing::warn!("Rust input length validation disabled by environment variable");
        } else if tokenizer.is_none() {
            tracing::warn!(
                "Could not find a fast tokenizer implementation for {tokenizer_name}"
            );
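
Both the Rust router above and the Python server read the flag the same way: only the literal string `true` (case-insensitive) enables it, so values such as `1` or `yes` leave the tokenizer in place. A small Python sketch of that rule (the helper name is illustrative, not part of the change):

```python
# Sketch of the truthiness rule applied to SKIP_TOKENIZER_IN_TGI: only the
# literal string "true" (any case) enables the pass-through tokenizer.
import os

def skip_tokenizer_in_tgi() -> bool:
    return os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true"

print(skip_tokenizer_in_tgi())  # False unless SKIP_TOKENIZER_IN_TGI=true is exported
```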

View File

@ -4,8 +4,9 @@ from text_generation_server.utils.tokens import (
    StoppingCriteria,
    FinishReason,
    batch_top_tokens,
    make_tokenizer_optional,
)
from transformers import AutoTokenizer

def test_stop_sequence_criteria():
    criteria = StopSequenceCriteria("/test;")
@ -66,3 +67,34 @@ def test_batch_top_tokens():
    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]


def test_pass_through_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        'meta-llama/Llama-2-7b-chat-hf',
        revision=None,
        padding_side="left",
        truncation_side="left",
    )
    tokenizer.pad_token_id = 2
    make_tokenizer_optional(tokenizer)
    input = ["1, 1724, 338, 6483, 6509, 29973", "?"]
    tokenized_inputs = tokenizer(
        input,
        return_tensors="pt",
        padding="max_length",
        return_token_type_ids=False,
        truncation=True,
        max_length=1024,
    )
    assert tokenized_inputs['input_ids'].size() == torch.Size([2, 1024])
    assert torch.equal(tokenized_inputs['input_ids'][0][1018:], torch.tensor([1, 1724, 338, 6483, 6509, 29973]))
    assert torch.equal(tokenized_inputs['input_ids'][1][1023:], torch.tensor([tokenizer.pad_token_id]))
    decoded_tokens = tokenizer.decode(tokenized_inputs["input_ids"][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    assert decoded_tokens.split(',')[1018:] == ['1', '1724', '338', '6483', '6509', '29973']


if __name__ == "__main__":
    test_pass_through_tokenizer()

View File

@ -32,7 +32,7 @@ from text_generation_server.models.types import (
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
from loguru import logger
tracer = trace.get_tracer(__name__)
@ -141,6 +141,7 @@ class CausalLMRequest:
        self.idx = new_idx
        return (new_idx, prev)


@dataclass
class CausalLMBatch(Batch):
    batch_id: int
@ -446,6 +447,7 @@ class CausalLM(Model):
padding_side="left",
truncation_side="left",
)
make_tokenizer_optional(tokenizer)
model_kwargs = {
"revision": revision,
@ -562,6 +564,19 @@ class CausalLM(Model):
    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        if is_tokenizer_transparent(self.tokenizer):
            new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
            return new_text, read_offset, len(all_input_ids)
        else:
            return super().decode_token(all_input_ids, prefix_offset, read_offset)

    def forward(
        self,
        input_ids,
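
For context on the `decode_token` override above: with a transparent tokenizer it simply joins the ids generated since `read_offset` and advances the offset past them, skipping the usual incremental detokenization. A self-contained sketch of that behaviour (the `TransparentTokenizer` stub is hypothetical, standing in for the class produced by `make_tokenizer_optional`):

```python
# Sketch of the transparent decode_token path, with a stand-in tokenizer.
from typing import List, Tuple

class TransparentTokenizer:
    is_transparent = True

    def decode(self, token_ids, skip_special_tokens: bool = False) -> str:
        # Pass-through decode: just join the raw ids with commas.
        return ",".join(str(i) for i in token_ids)

def decode_token(tokenizer, all_input_ids: List[int],
                 prefix_offset: int = 0, read_offset: int = 0) -> Tuple[str, int, int]:
    # Mirrors the transparent branch in CausalLM.decode_token: emit the ids
    # produced since read_offset and move the read offset to the end.
    new_text = tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
    return new_text, read_offset, len(all_input_ids)

ids = [1, 1724, 338, 6483, 6509, 29973]
print(decode_token(TransparentTokenizer(), ids, read_offset=4))  # ('6509,29973', 4, 6)
```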

View File

@ -18,6 +18,8 @@ from text_generation_server.utils.tokens import (
    FinishReason,
    Sampling,
    Greedy,
    make_tokenizer_optional,
    is_tokenizer_transparent
)

__all__ = [

View File

@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import (
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from transformers.utils import to_py_obj
class NextTokenChooser:
@ -359,3 +360,48 @@ def batch_top_tokens(
        [idxs[:n] if req_n > 0 else [] for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)],
        [vals[:n] if req_n > 0 else [] for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)],
    )
def make_tokenizer_optional(tokenizer):
    """If SKIP_TOKENIZER_IN_TGI is set to "true", swap the tokenizer's class for a
    pass-through variant that expects comma-separated token ids as input and
    returns comma-separated token ids on decode."""

    class _(type(tokenizer)):
        def __call__(
            self,
            text,
            return_tensors,
            padding,
            return_token_type_ids,
            truncation,
            max_length
        ):
            assert return_tensors == "pt", "incorrect input arguments when calling TransparentTokenizer"
            assert padding == "max_length", "incorrect input arguments when calling TransparentTokenizer"
            assert return_token_type_ids == False, "incorrect input arguments when calling TransparentTokenizer"
            assert truncation == True, "incorrect input arguments when calling TransparentTokenizer"

            def str_token_to_int(i):
                if i == '?':
                    return tokenizer.pad_token_id
                else:
                    return int(i)
            all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
                          for inner_text in text]
            # Left-pad every sequence to max_length with the pad token id.
            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]),
                    "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])}

        def decode(
            self,
            token_ids,
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: bool = None,
            **kwargs,
        ) -> str:
            # Pass-through decode: return the raw ids joined by commas.
            return ','.join(str(i) for i in to_py_obj(token_ids))

    import os
    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
        tokenizer.__class__ = _
        tokenizer.is_transparent = True


def is_tokenizer_transparent(tokenizer):
    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
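
Taken together, the wrapping is gated entirely by the environment variable: `make_tokenizer_optional` is a no-op unless `SKIP_TOKENIZER_IN_TGI` is `true`, and `is_tokenizer_transparent` reports which mode is active. A short sketch of the on/off behaviour (the `gpt2` checkpoint is only an illustrative choice, not part of the change):

```python
# Sketch: the tokenizer class is only swapped when SKIP_TOKENIZER_IN_TGI=true.
import os
from transformers import AutoTokenizer
from text_generation_server.utils.tokens import (
    make_tokenizer_optional,
    is_tokenizer_transparent,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")

os.environ.pop("SKIP_TOKENIZER_IN_TGI", None)
make_tokenizer_optional(tokenizer)
assert not is_tokenizer_transparent(tokenizer)  # normal tokenizer behaviour kept

os.environ["SKIP_TOKENIZER_IN_TGI"] = "true"    # normally passed via `docker run -e`
make_tokenizer_optional(tokenizer)
assert is_tokenizer_transparent(tokenizer)      # inputs/outputs are now raw token ids
```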