Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-21 01:32:08 +00:00)
Change HPU warmup logic: seq length should be with exponential growth (#3217)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
parent 56c8189467
commit c94f415af4
@@ -56,6 +56,7 @@ PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF",
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
 BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
+SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
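
Note: a minimal sketch (not part of the commit) of how the new knob is picked up; setting SEQ_LEN_EXPONENT_BASE in the environment controls how quickly the warmup sequence-length buckets grow, analogous to the existing BATCH_SIZE_EXPONENT_BASE:

    import os

    # Mirrors the added line above: base for the exponential sequence-length buckets, default 2.
    SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
    # With an assumed PAD_SEQUENCE_TO_MULTIPLE_OF of 128, a base of 2 gives buckets
    # 128, 256, 512, ..., while a base of 4 would give 128, 512, 2048, ...
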
@@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
     )


-def round_up_seq(number, k):
-    return (number + k - 1) // k * k
+def round_up_seq(number, k, base):
+    exponent = math.ceil(math.log(number / k, base))
+    return int(k * (base**exponent))
+
+
+def iterate_powers_of_base(max_value, start, base):
+    current = start
+    result = []
+    assert (
+        max_value >= start
+    ), f"max_value {max_value} must be greater than start {start}"
+    while current < max_value:
+        result.append(current)
+        current *= base
+    return result


 def round_up_batch(number):
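
Note: a standalone sketch of how the two new helpers behave, assuming PAD_SEQUENCE_TO_MULTIPLE_OF = 128 and SEQ_LEN_EXPONENT_BASE = 2 (values chosen for illustration, not taken from the commit):

    import math

    def round_up_seq(number, k, base):
        # Round number up to the next exponential bucket k * base**n.
        exponent = math.ceil(math.log(number / k, base))
        return int(k * (base**exponent))

    def iterate_powers_of_base(max_value, start, base):
        # Enumerate start, start * base, start * base**2, ... strictly below max_value.
        current = start
        result = []
        assert max_value >= start, f"max_value {max_value} must be greater than start {start}"
        while current < max_value:
            result.append(current)
            current *= base
        return result

    print(round_up_seq(300, 128, 2))             # 512: the next 128 * 2**n bucket above 300
    print(iterate_powers_of_base(2048, 128, 2))  # [128, 256, 512, 1024]
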
@@ -575,7 +589,9 @@ class CausalLMBatch(Batch):
         assert (
             PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
         ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-        rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        rounded_seq_len = round_up_seq(
+            input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+        )
         if rounded_seq_len <= max_input_length:
             bucket_size = rounded_seq_len - 1
         else:
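
Note: a worked example of the new bucket computation with hypothetical values (input_len = 300, PAD_SEQUENCE_TO_MULTIPLE_OF = 128, SEQ_LEN_EXPONENT_BASE = 2, max_input_length = 1024), none of which come from the commit itself:

    import math

    input_len, k, base, max_input_length = 300, 128, 2, 1024

    # Same arithmetic as the round_up_seq call above.
    rounded_seq_len = int(k * base ** math.ceil(math.log((input_len + 1) / k, base)))
    print(rounded_seq_len)  # 512: the next 128 * 2**n bucket above 301
    if rounded_seq_len <= max_input_length:
        bucket_size = rounded_seq_len - 1  # 511; the old multiple-of-128 rounding gave 383
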
@@ -1345,14 +1361,9 @@ class CausalLM(Model):
                 max_exp + 1,
             )
         ]
-        prefill_seqlen_list = [
-            seq
-            for seq in range(
-                PAD_SEQUENCE_TO_MULTIPLE_OF,
-                max_input_tokens,
-                PAD_SEQUENCE_TO_MULTIPLE_OF,
-            )
-        ]
+        prefill_seqlen_list = iterate_powers_of_base(
+            max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+        )
         prefill_seqlen_list.append(max_input_tokens)
         prefill_batch_size_list.sort(reverse=True)
         prefill_seqlen_list.sort(reverse=True)
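
Note: a sketch of the effect on warmup, assuming max_input_tokens = 2048, PAD_SEQUENCE_TO_MULTIPLE_OF = 128 and SEQ_LEN_EXPONENT_BASE = 2 (illustrative values only); far fewer distinct prefill sequence lengths need to be warmed up:

    PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE, max_input_tokens = 128, 2, 2048

    # Old behaviour: one bucket per multiple of 128 below max_input_tokens.
    old_list = list(range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF))
    old_list.append(max_input_tokens)  # 16 buckets: 128, 256, ..., 1920, 2048

    # New behaviour: exponentially growing buckets, mirroring iterate_powers_of_base.
    new_list, current = [], PAD_SEQUENCE_TO_MULTIPLE_OF
    while current < max_input_tokens:
        new_list.append(current)
        current *= SEQ_LEN_EXPONENT_BASE
    new_list.append(max_input_tokens)  # 5 buckets: 128, 256, 512, 1024, 2048
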