From a490847702d8141668e53dbbacbce5d2ff3c7bee Mon Sep 17 00:00:00 2001
From: jkaniecki <153085639+jkaniecki@users.noreply.github.com>
Date: Fri, 23 Feb 2024 01:52:14 +0100
Subject: [PATCH] Sequence bucketing for prefill (#39) (#67)
Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
---
README.md | 30 +++++++++---------
.../models/causal_lm.py | 31 ++++++++++++-------
server/text_generation_server/utils/tokens.py | 10 +++---
3 files changed, 42 insertions(+), 29 deletions(-)
diff --git a/README.md b/README.md
index ca9d9a41..0e2b32db 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,7 @@ New changes are added for the current release:
- Sharded feature with support for DeepSpeed-inference auto tensor parallelism. Also, use HPU graphs for performance improvement.
- Torch profile.
- Batch size bucketing for decode and prefill.
+- Sequence bucketing for prefill.
@@ -69,20 +70,21 @@ Environment Variables Added:
-| Name | Value(s) | Default | Description | Usage |
-|------------------ |:---------------|:------------|:-------------------- |:---------------------------------
-| MAX_TOTAL_TOKENS | integer | 0 | Control the padding of input | add -e in docker run, such |
-| ENABLE_HPU_GRAPH | true/false | true | Enable hpu graph or not | add -e in docker run command |
-| PROF_WARMUPSTEP | integer | 0 | Enable/disable profile, control profile warmup step, 0 means disable profile | add -e in docker run command |
-| PROF_STEP | integer | 5 | Control profile step | add -e in docker run command |
-| PROF_PATH | string | /tmp/hpu_profile | Define profile folder | add -e in docker run command |
-| PROF_RANKS | string | 0 | Comma-separated list of ranks to profile | add -e in docker run command |
-| PROF_RECORD_SHAPES | true/false | false | Control record_shapes option in the profiler | add -e in docker run command |
-| LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command |
-| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
-| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
-| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
-| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
+| Name | Value(s) | Default | Description | Usage |
+| --------------------------- | :--------- | :--------------- | :------------------------------------------------------------------------------------------------------------------------------- | :--------------------------- |
+| MAX_TOTAL_TOKENS | integer | 0 | Control the padding of input | add -e in docker run, such |
+| ENABLE_HPU_GRAPH | true/false | true | Enable hpu graph or not | add -e in docker run command |
+| PROF_WARMUPSTEP | integer | 0 | Enable/disable profile, control profile warmup step, 0 means disable profile | add -e in docker run command |
+| PROF_STEP | integer | 5 | Control profile step | add -e in docker run command |
+| PROF_PATH | string | /tmp/hpu_profile | Define profile folder | add -e in docker run command |
+| PROF_RANKS | string | 0 | Comma-separated list of ranks to profile | add -e in docker run command |
+| PROF_RECORD_SHAPES | true/false | false | Control record_shapes option in the profiler | add -e in docker run command |
+| LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command |
+| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
+| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
+| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index dbe7d616..a0b1bd82 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -44,6 +44,7 @@ if 'GRAPH_VISUALIZATION' in os.environ:
os.remove(f)
BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
START_TS = None
@@ -376,21 +377,24 @@ class CausalLMBatch(Batch):
tokenized_inputs = tokenizer(
[r.data.inputs for r in requests] + dummy_inputs,
return_tensors="pt",
- padding="max_length",
+ padding="longest",
return_token_type_ids=False,
truncation=True,
max_length=max_input_length,
)
input_len = tokenized_inputs["input_ids"].shape[1]
+
+ bucket_size = max_input_length
+ left_padding = max_input_length - input_len
+ if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
+ assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
+ bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
+ left_padding = bucket_size - input_len
+
extra_padding = 0
if is_optimized_for_gaudi and max_total_tokens > 0:
- extra_padding = max(extra_padding, max_total_tokens - max_input_length - max_new_tokens)
-
- for r in requests:
- r.input_length = input_len
- r.prefix_offset = input_len - 5
- r.read_offset = input_len
+ extra_padding = max(extra_padding, max_total_tokens - (bucket_size + 1) - max_new_tokens)
input_ids = tokenized_inputs["input_ids"]
attention_mask = tokenized_inputs["attention_mask"]
@@ -398,18 +402,23 @@ class CausalLMBatch(Batch):
if is_optimized_for_gaudi:
# Allocate space for first token
input_ids = torch.nn.functional.pad(
- input_ids, (0, 1), value=tokenizer.pad_token_id
+ input_ids, (left_padding, 1), value=tokenizer.pad_token_id
)
attention_mask = torch.nn.functional.pad(
- attention_mask, (0, 1), value=0
+ attention_mask, (left_padding, 1), value=0
)
all_input_ids = torch.nn.functional.pad(
- input_ids, (0, max_new_tokens + extra_padding - 1), value=tokenizer.pad_token_id
+ input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
).T.split(1, dim=1)
else:
all_input_ids = input_ids.clone().T.split(1, dim=1)
+ # New input length after left padding
+ input_len = bucket_size
for r in requests:
+ r.input_length = input_len
+ r.prefix_offset = input_len - 5
+ r.read_offset = input_len
r.all_input_ids = all_input_ids[r.idx]
input_ids = input_ids.to(device)
@@ -429,7 +438,7 @@ class CausalLMBatch(Batch):
next_token_chooser=next_token_chooser,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
- input_length=max_input_length,
+ input_length=input_len,
right_padding=max_new_tokens + extra_padding if is_optimized_for_gaudi else 0
)
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 6d95af94..2814ced8 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -374,7 +374,7 @@ def make_tokenizer_optional(tokenizer):
max_length
):
assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer"
- assert padding == "max_length", "inccorrect input arguments when calling TransparentTokenizer"
+ assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer"
assert return_token_type_ids == False, "inccorrect input arguments when calling TransparentTokenizer"
assert truncation == True, "inccorrect input arguments when calling TransparentTokenizer"
@@ -385,8 +385,10 @@ def make_tokenizer_optional(tokenizer):
return int(i)
all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
for inner_text in text]
- return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length-len(tokens)) + tokens for tokens in all_tokens], dtype=torch.int32),
- "attention_mask": torch.tensor([[0] * (max_length-len(tokens)) + [1]*len(tokens) for tokens in all_tokens], dtype=torch.int32)}
+ if padding == "longest":
+ max_length = max(len(tokens) for tokens in all_tokens)
+ return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens], dtype=torch.int32),
+ "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens], dtype=torch.int32)}
def decode(
self,
@@ -404,4 +406,4 @@ def make_tokenizer_optional(tokenizer):
def is_tokenizer_transparent(tokenizer):
- return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
\ No newline at end of file
+ return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True