Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-05-20 17:22:09 +00:00
Change HPU warmup logic: seq length should grow exponentially (#3217)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Parent: 56c8189467
Commit: c94f415af4
@@ -56,6 +56,7 @@ PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF",
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
 BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
+SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
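The added SEQ_LEN_EXPONENT_BASE knob follows the same environment-variable pattern as the existing BATCH_SIZE_EXPONENT_BASE directly above it: an integer base, defaulting to 2, that sets how fast the warmup sequence-length buckets grow. A minimal sketch of the pattern (the sanity check is illustrative and not part of the commit):

import os

# Same env-var idiom the file already uses for BATCH_SIZE_EXPONENT_BASE.
SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))

# Illustrative check, not in the commit: a base of 1 would keep the
# exponential ladder below from ever advancing past its start value.
assert SEQ_LEN_EXPONENT_BASE >= 2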
@@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
     )


-def round_up_seq(number, k):
-    return (number + k - 1) // k * k
+def round_up_seq(number, k, base):
+    exponent = math.ceil(math.log(number / k, base))
+    return int(k * (base**exponent))
+
+
+def iterate_powers_of_base(max_value, start, base):
+    current = start
+    result = []
+    assert (
+        max_value >= start
+    ), f"max_value {max_value} must be greater than or equal to start {start}"
+    while current < max_value:
+        result.append(current)
+        current *= base
+    return result


 def round_up_batch(number):
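Together, the reworked round_up_seq and the new iterate_powers_of_base swap fixed-step rounding for geometric rounding. A self-contained, hand-checkable sketch of their behavior, using assumed values (k = 256 standing in for PAD_SEQUENCE_TO_MULTIPLE_OF, base 2):

import math


def round_up_seq(number, k, base):
    # Smallest k * base**n (integer n) that is >= number.
    exponent = math.ceil(math.log(number / k, base))
    return int(k * (base**exponent))


def iterate_powers_of_base(max_value, start, base):
    # start, start*base, start*base**2, ... strictly below max_value.
    current, result = start, []
    while current < max_value:
        result.append(current)
        current *= base
    return result


print(round_up_seq(300, 256, 2))  # 512
print(round_up_seq(513, 256, 2))  # 1024
print(iterate_powers_of_base(4096, 256, 2))  # [256, 512, 1024, 2048]

One behavioral difference worth noting: for number < k the new round_up_seq rounds up to k * base**n with a negative exponent (round_up_seq(100, 256, 2) == 128), whereas the old multiple-based version always returned at least k.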
@@ -575,7 +589,9 @@ class CausalLMBatch(Batch):
         assert (
             PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
         ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-        rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        rounded_seq_len = round_up_seq(
+            input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+        )
         if rounded_seq_len <= max_input_length:
             bucket_size = rounded_seq_len - 1
         else:
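At this call site, each prefill batch now snaps to an exponentially spaced bucket instead of the next multiple of PAD_SEQUENCE_TO_MULTIPLE_OF. A worked example with assumed values (input_len = 700, PAD_SEQUENCE_TO_MULTIPLE_OF = 256, SEQ_LEN_EXPONENT_BASE = 2, max_input_length = 4096; the else branch is truncated in this hunk, so the fallback below is a guess for illustration only):

import math

input_len, pad_to, base, max_input_length = 700, 256, 2, 4096  # assumed

exponent = math.ceil(math.log((input_len + 1) / pad_to, base))
rounded_seq_len = int(pad_to * base**exponent)

# Mirrors the visible branch: the bucket is used only when it fits.
if rounded_seq_len <= max_input_length:
    bucket_size = rounded_seq_len - 1
else:
    bucket_size = max_input_length  # assumed fallback; the hunk cuts off here

print(rounded_seq_len, bucket_size)  # 1024 1023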
@@ -1345,14 +1361,9 @@ class CausalLM(Model):
                 max_exp + 1,
             )
         ]
-        prefill_seqlen_list = [
-            seq
-            for seq in range(
-                PAD_SEQUENCE_TO_MULTIPLE_OF,
-                max_input_tokens,
-                PAD_SEQUENCE_TO_MULTIPLE_OF,
-            )
-        ]
+        prefill_seqlen_list = iterate_powers_of_base(
+            max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+        )
         prefill_seqlen_list.append(max_input_tokens)
         prefill_batch_size_list.sort(reverse=True)
         prefill_seqlen_list.sort(reverse=True)
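The effect on warmup cost is easy to quantify. Under assumed values of max_input_tokens = 8192 and PAD_SEQUENCE_TO_MULTIPLE_OF = 256 (illustrative, not from the commit), the removed linear list warmed up 32 prefill sequence lengths, while the exponential list plus the appended maximum needs only 6:

max_input_tokens, pad_to, base = 8192, 256, 2  # assumed values

# Old: every multiple of pad_to below the maximum, plus the maximum itself.
old = list(range(pad_to, max_input_tokens, pad_to))
old.append(max_input_tokens)

# New: powers of the base starting at pad_to, plus the maximum itself.
new, cur = [], pad_to
while cur < max_input_tokens:
    new.append(cur)
    cur *= base
new.append(max_input_tokens)

print(len(old), len(new))  # 32 6

Fewer warmup shapes means fewer HPU graph compilations before serving starts, at the price of coarser padding buckets at longer sequence lengths.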