mirror of https://github.com/huggingface/text-generation-inference.git
Simplify the warmup
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent 8686a0fc6d
commit 8ebe77b3be
server/text_generation_server/models/causal_lm.py

@@ -1213,10 +1213,9 @@ class CausalLM(Model):
         # warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        batch_size = max_decode_batch_size
         self.limit_hpu_graph = True
         try:
-            while batch_size > 1:
+            for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
                 batches = []
                 iters = math.floor(batch_size / max_prefill_batch_size)
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
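For intuition, here is a minimal sketch (not part of the commit) of what the old and new decode-warmup loops enumerate. The concrete values below are illustrative assumptions, not taken from the diff:

import math

# Illustrative values only; in the server these come from configuration.
MAX_BATCH_TOTAL_TOKENS = 65536
MAX_TOTAL_TOKENS = 2048
BATCH_BUCKET_SIZE = 8

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)  # 32

# Old loop: repeatedly divide by the bucket size, skipping most buckets.
old_sizes, batch_size = [], max_decode_batch_size
while batch_size > 1:
    old_sizes.append(batch_size)
    batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
print(old_sizes)  # [32, 4]

# New loop: step down one bucket at a time, warming every bucket boundary.
new_sizes = list(range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE))
print(new_sizes)  # [32, 24, 16]

Because the new list is generated directly in descending bucket order, the manual floor-division bookkeeping (and the DECODE_DIVISOR log line removed in the next hunk) is no longer needed.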
@@ -1231,8 +1230,6 @@ class CausalLM(Model):
                     batches.append(batch)

                 _, decode_batch, _ = self.generate_token(batches, is_warmup)
-                logger.info(f"DECODE_DIVISOR={BATCH_BUCKET_SIZE}")
-                batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
         except:
             DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
             self.model.clear_cache()
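The surrounding try/except implements a simple back-off: batch sizes are attempted from largest to smallest, and if graph compilation or allocation fails, the size that was just recorded is popped from DECODE_WARMUP_BATCH_SIZE_LIST and the model cache is cleared. A minimal sketch of that pattern, where warm_decode and clear_cache are hypothetical stand-ins for the generate_token(...) call and self.model.clear_cache() in the diff:

def warmup_sizes(sizes, warm_decode, clear_cache):
    # Record each size before attempting it, mirroring the diff's
    # append-then-generate ordering; on failure, un-record the size
    # that did not complete and release cached memory.
    warmed = []
    try:
        for batch_size in sizes:
            warmed.append(batch_size)
            warm_decode(batch_size)  # may raise, e.g. on out-of-memory
    except Exception:
        warmed.pop(-1)
        clear_cache()
    return warmed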
@@ -1257,26 +1254,21 @@ class CausalLM(Model):
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
         max_prefill_batch_size = batch.input_ids.shape[0]
-        batch_size = max_prefill_batch_size
-        while batch_size >= 1:
-            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-            batch_size = math.floor(batch_size / PREFILL_BATCH_BUCKET_SIZE)
-
-        seq_len = max_input_length
-        while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF:
+        seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
+        i = 0
+        while seq_len <= max_input_length:
             PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
-            seq_len = math.floor(seq_len / 2)
-
-        if PREFILL_WARMUP_SEQLEN_LIST[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF:
-            PREFILL_WARMUP_SEQLEN_LIST.append(PAD_SEQUENCE_TO_MULTIPLE_OF)
+            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
+            i += 1
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)

         # Prefill and decode warmup
         prefill_batch = None
-        PREFILL_WARMUP_BATCH_SIZE_LIST.sort()
-        PREFILL_WARMUP_SEQLEN_LIST.sort()

         try:
-            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
+            for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE):
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                 for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
                     batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
                     _, prefill_batch, _ = self.generate_token([batch], is_warmup)
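The rewritten sequence-length schedule grows upward by doubling increments of PAD_SEQUENCE_TO_MULTIPLE_OF instead of halving down from max_input_length, so the list is produced already in ascending order and the removed sort() calls become unnecessary. A small sketch of the new schedule; PAD_SEQUENCE_TO_MULTIPLE_OF = 128 and max_input_length = 1536 are illustrative assumptions, not values from the diff:

PAD_SEQUENCE_TO_MULTIPLE_OF = 128  # illustrative assumption
max_input_length = 1536            # illustrative assumption

# New schedule: start at the padding multiple, then add 128, 256, 512, ...
seqlens = []
seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
i = 0
while seq_len <= max_input_length:
    seqlens.append(seq_len)
    seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
    i += 1
if seqlens[-1] < max_input_length:
    seqlens.append(max_input_length)  # always cover the true maximum

print(seqlens)  # [128, 256, 512, 1024, 1536]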