From a490847702d8141668e53dbbacbce5d2ff3c7bee Mon Sep 17 00:00:00 2001
From: jkaniecki <153085639+jkaniecki@users.noreply.github.com>
Date: Fri, 23 Feb 2024 01:52:14 +0100
Subject: [PATCH] Sequence bucketing for prefill (#39) (#67)

Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
---
 README.md                                      | 30 +++++++++---------
 .../models/causal_lm.py                        | 31 ++++++++++++-------
 server/text_generation_server/utils/tokens.py  | 10 +++---
 3 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index ca9d9a41..0e2b32db 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,7 @@ New changes are added for the current release:
 - Sharded feature with support for DeepSpeed-inference auto tensor parallelism. Also, use HPU graphs for performance improvement.
 - Torch profile.
 - Batch size bucketing for decode and prefill.
+- Sequence bucketing for prefill.
@@ -69,20 +70,21 @@ Environment Variables Added:
-| Name | Value(s) | Default | Description | Usage |
-|------------------ |:---------------|:------------|:-------------------- |:---------------------------------
-| MAX_TOTAL_TOKENS | integer | 0 | Control the padding of input | add -e in docker run, such |
-| ENABLE_HPU_GRAPH | true/false | true | Enable hpu graph or not | add -e in docker run command |
-| PROF_WARMUPSTEP | integer | 0 | Enable/disable profile, control profile warmup step, 0 means disable profile | add -e in docker run command |
-| PROF_STEP | integer | 5 | Control profile step | add -e in docker run command |
-| PROF_PATH | string | /tmp/hpu_profile | Define profile folder | add -e in docker run command |
-| PROF_RANKS | string | 0 | Comma-separated list of ranks to profile | add -e in docker run command |
-| PROF_RECORD_SHAPES | true/false | false | Control record_shapes option in the profiler | add -e in docker run command |
-| LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command |
-| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
-| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
-| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
-| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
+| Name                        | Value(s)   | Default          | Description | Usage |
+| --------------------------- | :--------- | :--------------- | :---------- | :---- |
+| MAX_TOTAL_TOKENS            | integer    | 0                | Control the padding of input | add -e in docker run, such |
+| ENABLE_HPU_GRAPH            | true/false | true             | Enable hpu graph or not | add -e in docker run command |
+| PROF_WARMUPSTEP             | integer    | 0                | Enable/disable profile, control profile warmup step, 0 means disable profile | add -e in docker run command |
+| PROF_STEP                   | integer    | 5                | Control profile step | add -e in docker run command |
+| PROF_PATH                   | string     | /tmp/hpu_profile | Define profile folder | add -e in docker run command |
+| PROF_RANKS                  | string     | 0                | Comma-separated list of ranks to profile | add -e in docker run command |
+| PROF_RECORD_SHAPES          | true/false | false            | Control record_shapes option in the profiler | add -e in docker run command |
+| LIMIT_HPU_GRAPH             | True/False | False            | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command |
+| BATCH_BUCKET_SIZE           | integer    | 8                | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PREFILL_BATCH_BUCKET_SIZE   | integer    | 4                | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PAD_SEQUENCE_TO_MULTIPLE_OF | integer    | 128              | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
+| SKIP_TOKENIZER_IN_TGI       | True/False | False            | Skip tokenizer for input/output processing | add -e in docker run command |
+| TGI_PROFILER_ENABLED        | True/False | False            | Collect high-level server tracing events | add -e in docker run command |
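
Reviewer note (not part of the patch): a minimal, standalone sketch of how the new PAD_SEQUENCE_TO_MULTIPLE_OF bucketing behaves. The `round_up` helper below is a local stand-in for the one referenced in `causal_lm.py`; the arithmetic mirrors the hunk that computes `bucket_size` and `left_padding`.

```python
import math

PAD_SEQUENCE_TO_MULTIPLE_OF = 128  # default from the table above


def round_up(value: int, multiple: int) -> int:
    # Stand-in for the round_up helper used in causal_lm.py.
    return math.ceil(value / multiple) * multiple


def prefill_bucket(input_len: int, max_input_length: int):
    # Mirrors the new bucket_size / left_padding computation: one slot is
    # reserved for the first generated token, hence the `+ 1` and `- 1`.
    bucket_size = max_input_length
    left_padding = max_input_length - input_len
    if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
        bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
        left_padding = bucket_size - input_len
    return bucket_size, left_padding


print(prefill_bucket(70, 1024))   # (127, 57)  -> 127 + 1 reserved slot = one 128-token bucket
print(prefill_bucket(300, 1024))  # (383, 83)  -> rounds up to the 384-token bucket
```

Padding prompts to a small set of bucket lengths, instead of always padding to max_input_length, keeps the number of distinct prefill shapes bounded, which is the same graph-caching rationale as the batch-size bucketing variables above.
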
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index dbe7d616..a0b1bd82 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -44,6 +44,7 @@ if 'GRAPH_VISUALIZATION' in os.environ:
         os.remove(f)
 
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
 DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
 START_TS = None
@@ -376,21 +377,24 @@ class CausalLMBatch(Batch):
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",
-            padding="max_length",
+            padding="longest",
             return_token_type_ids=False,
             truncation=True,
             max_length=max_input_length,
         )
         input_len = tokenized_inputs["input_ids"].shape[1]
+
+        bucket_size = max_input_length
+        left_padding = max_input_length - input_len
+        if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
+            assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
+            bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
+            left_padding = bucket_size - input_len
+
         extra_padding = 0
         if is_optimized_for_gaudi and max_total_tokens > 0:
-            extra_padding = max(extra_padding, max_total_tokens - max_input_length - max_new_tokens)
-
-        for r in requests:
-            r.input_length = input_len
-            r.prefix_offset = input_len - 5
-            r.read_offset = input_len
+            extra_padding = max(extra_padding, max_total_tokens - (bucket_size + 1) - max_new_tokens)
 
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
 
@@ -398,18 +402,23 @@ class CausalLMBatch(Batch):
         if is_optimized_for_gaudi:
             # Allocate space for first token
             input_ids = torch.nn.functional.pad(
-                input_ids, (0, 1), value=tokenizer.pad_token_id
+                input_ids, (left_padding, 1), value=tokenizer.pad_token_id
             )
             attention_mask = torch.nn.functional.pad(
-                attention_mask, (0, 1), value=0
+                attention_mask, (left_padding, 1), value=0
             )
             all_input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens + extra_padding - 1), value=tokenizer.pad_token_id
+                input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
             ).T.split(1, dim=1)
         else:
             all_input_ids = input_ids.clone().T.split(1, dim=1)
 
+        # New input length after left padding
+        input_len = bucket_size
         for r in requests:
+            r.input_length = input_len
+            r.prefix_offset = input_len - 5
+            r.read_offset = input_len
             r.all_input_ids = all_input_ids[r.idx]
 
         input_ids = input_ids.to(device)
@@ -429,7 +438,7 @@ class CausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            input_length=max_input_length,
+            input_length=input_len,
             right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
         )
 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 6d95af94..2814ced8 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -374,7 +374,7 @@ def make_tokenizer_optional(tokenizer):
             max_length
         ):
             assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer"
-            assert padding == "max_length", "inccorrect input arguments when calling TransparentTokenizer"
+            assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer"
             assert return_token_type_ids == False, "inccorrect input arguments when calling TransparentTokenizer"
             assert truncation == True, "inccorrect input arguments when calling TransparentTokenizer"
 
@@ -385,8 +385,10 @@ def make_tokenizer_optional(tokenizer):
                     return int(i)
 
             all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')] for inner_text in text]
-            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length-len(tokens)) + tokens for tokens in all_tokens], dtype=torch.int32),
-                    "attention_mask": torch.tensor([[0] * (max_length-len(tokens)) + [1]*len(tokens) for tokens in all_tokens], dtype=torch.int32)}
+            if padding == "longest":
+                max_length = max(len(tokens) for tokens in all_tokens)
+            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens], dtype=torch.int32),
+                    "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens], dtype=torch.int32)}
 
         def decode(
             self,
@@ -404,4 +406,4 @@ def make_tokenizer_optional(tokenizer):
 
 
 def is_tokenizer_transparent(tokenizer):
-    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
\ No newline at end of file
+    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True