Make tokenizer optional (#12)
Commit 0b96da89aa (parent 381ec38cad)
@@ -79,6 +79,7 @@ Environment Variables Added:

| Name | Value(s) | Default | Description | Usage |
|------|----------|---------|-------------|-------|
| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths (e.g. 300/212) | add -e in docker run command |
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
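Setting `-e SKIP_TOKENIZER_IN_TGI=true` on the `docker run` command makes the server accept and return raw token ids instead of text: the request's `inputs` field must be a comma-separated id string, and `generated_text` should come back in the same format (see `make_tokenizer_optional` and `test_pass_through_tokenizer` below). A minimal client-side sketch, assuming a TGI endpoint at `http://127.0.0.1:8080` and client-side access to the model's tokenizer (both assumptions, not part of this commit):

```python
# Sketch only: pre-tokenize on the client and send raw token ids to a server
# started with -e SKIP_TOKENIZER_IN_TGI=true. The endpoint URL and model name
# are illustrative assumptions.
import requests
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompt = "What is deep learning?"
input_ids = tokenizer(prompt)["input_ids"]

payload = {
    # The transparent tokenizer parses a comma-separated list of token ids.
    "inputs": ", ".join(str(i) for i in input_ids),
    "parameters": {"max_new_tokens": 20},
}
response = requests.post("http://127.0.0.1:8080/generate", json=payload)

# With the tokenizer skipped, generated_text is a comma-separated id string;
# decode it back to text on the client.
output_ids = [int(i) for i in response.json()["generated_text"].split(",") if i.strip()]
print(tokenizer.decode(output_ids, skip_special_tokens=True))
```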
```diff
@@ -7,6 +7,7 @@ use opentelemetry::sdk::trace::Sampler;
 use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
+use std::env;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
```
```diff
@@ -141,7 +142,10 @@ fn main() -> Result<(), RouterError> {
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
     let local_model = local_path.exists() && local_path.is_dir();
-    let tokenizer = if local_model {
+    let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI").ok().map_or(false, |value| value.to_lowercase() == "true");
+    let tokenizer = if skip_tokenizer_in_tgi {
+        None
+    } else if local_model {
         // Load local tokenizer
         Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
     } else {
```
```diff
@@ -162,7 +166,9 @@ fn main() -> Result<(), RouterError> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);

-            if tokenizer.is_none() {
+            if skip_tokenizer_in_tgi {
+                tracing::warn!("Rust input length validation disabled by environment variable");
+            } else if tokenizer.is_none() {
                 tracing::warn!(
                     "Could not find a fast tokenizer implementation for {tokenizer_name}"
                 );
```
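Both the router change above and the Python-side check in `tokens.py` (further down) treat the flag as enabled only when its value is the literal string `true`, compared case-insensitively. A tiny Python sketch of that rule, for illustration only (the helper name is made up, not from the commit):

```python
import os

def skip_tokenizer_enabled() -> bool:
    # Same truthiness rule as the router's
    # env::var("SKIP_TOKENIZER_IN_TGI").ok().map_or(false, |v| v.to_lowercase() == "true"):
    # only the literal string "true" (any letter case) enables the skip.
    return os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true"

print(skip_tokenizer_enabled())  # False unless the variable is "true"/"True"/"TRUE"; "1" or "yes" do not count
```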
```diff
@@ -4,8 +4,9 @@ from text_generation_server.utils.tokens import (
     StoppingCriteria,
     FinishReason,
     batch_top_tokens,
+    make_tokenizer_optional,
 )

+from transformers import AutoTokenizer

 def test_stop_sequence_criteria():
     criteria = StopSequenceCriteria("/test;")
```
```diff
@@ -66,3 +67,34 @@ def test_batch_top_tokens():
     assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
     assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
     assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
+
+
+def test_pass_through_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(
+        'meta-llama/Llama-2-7b-chat-hf',
+        revision=None,
+        padding_side="left",
+        truncation_side="left",
+    )
+    tokenizer.pad_token_id = 2
+    make_tokenizer_optional(tokenizer)
+
+    input = ["1, 1724, 338, 6483, 6509, 29973", "?"]
+    tokenized_inputs = tokenizer(
+        input,
+        return_tensors="pt",
+        padding="max_length",
+        return_token_type_ids=False,
+        truncation=True,
+        max_length=1024,
+    )
+    assert tokenized_inputs['input_ids'].size() == torch.Size([2, 1024])
+    assert torch.equal(tokenized_inputs['input_ids'][0][1018:], torch.tensor([1, 1724, 338, 6483, 6509, 29973]))
+    assert torch.equal(tokenized_inputs['input_ids'][1][1023:], torch.tensor([tokenizer.pad_token_id]))
+    decoded_tokens = tokenizer.decode(tokenized_inputs["input_ids"][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+    assert decoded_tokens.split(',')[1018:] == ['1', '1724', '338', '6483', '6509', '29973']
+
+
+if __name__ == "__main__":
+    test_pass_through_tokenizer()
```
```diff
@@ -32,7 +32,7 @@ from text_generation_server.models.types import (
     TopTokens,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
 from loguru import logger

 tracer = trace.get_tracer(__name__)
@@ -141,6 +141,7 @@ class CausalLMRequest:
         self.idx = new_idx
         return (new_idx, prev)

+
 @dataclass
 class CausalLMBatch(Batch):
     batch_id: int
@@ -446,6 +447,7 @@ class CausalLM(Model):
             padding_side="left",
             truncation_side="left",
         )
+        make_tokenizer_optional(tokenizer)

         model_kwargs = {
             "revision": revision,
@@ -562,6 +564,19 @@ class CausalLM(Model):
     def decode(self, generated_ids: List[int]) -> str:
         return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

+    def decode_token(
+        self,
+        all_input_ids: List[int],
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
+        if is_tokenizer_transparent(self.tokenizer):
+            new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
+            return new_text, read_offset, len(all_input_ids)
+        else:
+            return super().decode_token(all_input_ids, prefix_offset, read_offset)
+
     def forward(
         self,
         input_ids,
```
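The `decode_token` override above short-circuits incremental detokenization: for a transparent tokenizer it returns all not-yet-emitted ids as a comma-separated string and moves `read_offset` to the end of the sequence. A self-contained sketch of that contract, using a stand-in tokenizer (names here are illustrative, not from the commit):

```python
from typing import List, Tuple

class TransparentStub:
    # Stand-in with the same decode contract as the patched tokenizer; illustrative only.
    is_transparent = True

    def decode(self, token_ids: List[int], skip_special_tokens: bool = False) -> str:
        return ",".join(str(i) for i in token_ids)

def is_tokenizer_transparent(tokenizer) -> bool:
    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True

def decode_token(tokenizer, all_input_ids: List[int],
                 prefix_offset: int = 0, read_offset: int = 0) -> Tuple[str, int, int]:
    # Transparent path only: emit the raw ids after read_offset and jump read_offset to the end.
    if is_tokenizer_transparent(tokenizer):
        new_text = tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
        return new_text, read_offset, len(all_input_ids)
    raise NotImplementedError("the regular incremental-decoding path is omitted in this sketch")

text, new_prefix_offset, new_read_offset = decode_token(TransparentStub(), [1, 1724, 338, 6483], read_offset=2)
assert (text, new_prefix_offset, new_read_offset) == ("338,6483", 2, 4)
```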
```diff
@@ -18,6 +18,8 @@ from text_generation_server.utils.tokens import (
     FinishReason,
     Sampling,
     Greedy,
+    make_tokenizer_optional,
+    is_tokenizer_transparent
 )

 __all__ = [
```
```diff
@@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import (
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers.utils import to_py_obj


 class NextTokenChooser:
@@ -359,3 +360,48 @@ def batch_top_tokens(
         [idxs[:n] if req_n > 0 else [] for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)],
         [vals[:n] if req_n > 0 else [] for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)],
     )
+
+
+def make_tokenizer_optional(tokenizer):
+    class _(type(tokenizer)):
+        def __call__(
+            self,
+            text,
+            return_tensors,
+            padding,
+            return_token_type_ids,
+            truncation,
+            max_length
+        ):
+            assert return_tensors == "pt", "incorrect input arguments when calling TransparentTokenizer"
+            assert padding == "max_length", "incorrect input arguments when calling TransparentTokenizer"
+            assert return_token_type_ids == False, "incorrect input arguments when calling TransparentTokenizer"
+            assert truncation == True, "incorrect input arguments when calling TransparentTokenizer"
+
+            def str_token_to_int(i):
+                if i == '?':
+                    return tokenizer.pad_token_id
+                else:
+                    return int(i)
+            all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
+                          for inner_text in text]
+            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]),
+                    "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])}
+
+        def decode(
+            self,
+            token_ids,
+            skip_special_tokens: bool = False,
+            clean_up_tokenization_spaces: bool = None,
+            **kwargs,
+        ) -> str:
+            return ','.join(str(i) for i in to_py_obj(token_ids))
+
+    import os
+    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
+        tokenizer.__class__ = _
+        tokenizer.is_transparent = True
+
+
+def is_tokenizer_transparent(tokenizer):
+    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
```
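The patched `__call__` left-pads each comma-separated id list to `max_length` and builds a matching attention mask, while `decode` simply joins ids back into a comma-separated string. A minimal round-trip sketch, under the assumption that the `text_generation_server` package is importable and that the flag is set before `make_tokenizer_optional` is called (model name taken from the test above):

```python
# Minimal round-trip sketch for the transparent tokenizer.
import os
os.environ["SKIP_TOKENIZER_IN_TGI"] = "true"  # checked inside make_tokenizer_optional

from transformers import AutoTokenizer
from text_generation_server.utils import make_tokenizer_optional, is_tokenizer_transparent

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # model as in the test above
tokenizer.pad_token_id = 2
make_tokenizer_optional(tokenizer)
assert is_tokenizer_transparent(tokenizer)

batch = tokenizer(
    ["1, 1724, 338, 6483, 6509, 29973"],  # ids as text; a lone "?" maps to pad_token_id
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=8,
)
print(batch["input_ids"])       # tensor([[2, 2, 1, 1724, 338, 6483, 6509, 29973]]) -- left-padded
print(batch["attention_mask"])  # tensor([[0, 0, 1, 1, 1, 1, 1, 1]])
print(tokenizer.decode(batch["input_ids"][0]))  # "2,2,1,1724,338,6483,6509,29973"
```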