Mirror of https://github.com/huggingface/text-generation-inference.git
Refine the warmup process
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent 253a992447 · commit 9f356ce045
@@ -516,7 +516,7 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
+            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
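The changed line above switches the rounding unit for prefill sequence lengths from PREFILL_BATCH_BUCKET_SIZE (a batch-size bucket) to PAD_SEQUENCE_TO_MULTIPLE_OF (a sequence-length bucket). As a rough illustration only (not the repository's code), here is a minimal sketch of that bucketing arithmetic, assuming `round_up(value, multiple)` rounds up to the nearest multiple and that the truncated `else:` branch falls back to `max_input_length`; the constants are placeholders:

```python
# Illustrative sketch only: assumed round_up semantics and placeholder constants.
def round_up(value: int, multiple: int) -> int:
    # Round `value` up to the nearest multiple of `multiple`.
    return (value + multiple - 1) // multiple * multiple

PAD_SEQUENCE_TO_MULTIPLE_OF = 128  # placeholder value for the example
max_input_length = 1024            # placeholder value for the example

for input_len in (100, 127, 128, 1000):
    rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
    if rounded_seq_len <= max_input_length:
        bucket_size = rounded_seq_len - 1
    else:
        # Assumed behaviour of the truncated else-branch: cap at max_input_length.
        bucket_size = max_input_length
    print(f"input_len={input_len:4d} -> bucket_size={bucket_size}")
    # 100 -> 127, 127 -> 127, 128 -> 255, 1000 -> 1023
```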
@@ -1193,9 +1193,41 @@ class CausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
         del prefill_batch
-        #warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
-        del batch
+
+        # Warmup prefill batch_size
+        max_input_length = request.max_input_length
+        prefill_batch_size_list = [batch for batch in range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE)]
+        prefill_batch_size_list.append(max_prefill_batch_size)
+        prefill_seqlen_list = [seq for seq in range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)]
+        prefill_seqlen_list.append(max_input_length)
+        prefill_batch_size_list.sort(reverse=True)
+        prefill_seqlen_list.sort(reverse=True)
+        try:
+            for batch_size in prefill_batch_size_list:
+                for seq_len in prefill_seqlen_list:
+                    batch = self.generate_warmup_batch(request, seq_len-1, batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch])
+        except:
+            prefill_batch_size_list.sort()
+            prefill_seqlen_list.sort()
+            raise RuntimeError(
+                f"Not enough memory to run following prefill batch_size."
+                f"Prefill batch size list:{prefill_batch_size_list}"
+                f"Prefill sequence length list:{prefill_seqlen_list}"
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+        prefill_seqlen_list.sort()
+        prefill_batch_size_list.sort()
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nFollowing prefill and decode warmup successfully.\n"
+            f"Prefill batch size list:{prefill_batch_size_list}\n"
+            f"Prefill sequence length list:{prefill_seqlen_list}\n"
+            f"Memory stats: {mem_stats} "
+        )
+
+        #warmup decode batch size
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
         max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
         decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
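The added block enumerates every (batch size, sequence length) bucket pair and prefills each one once, largest shapes first, so an out-of-memory failure surfaces as early as possible. A standalone sketch of how that warmup grid is built, with placeholder bucket constants and sizes (the real values come from the request and the server configuration):

```python
# Standalone sketch of the prefill warmup grid; all constants are placeholders.
BATCH_BUCKET_SIZE = 8
PAD_SEQUENCE_TO_MULTIPLE_OF = 128

max_prefill_batch_size = 20  # batch.input_ids.shape[0] in the diff
max_input_length = 512       # request.max_input_length in the diff

prefill_batch_size_list = list(range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE))
prefill_batch_size_list.append(max_prefill_batch_size)
prefill_seqlen_list = list(range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF))
prefill_seqlen_list.append(max_input_length)

# Largest shapes first, so the most memory-hungry prefill runs first.
prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True)

print(prefill_batch_size_list)  # [20, 16, 8]
print(prefill_seqlen_list)      # [512, 384, 256, 128]
```

In the server each pair is then prefilled once via `generate_warmup_batch(request, seq_len - 1, batch_size)` followed by `generate_token`, so every prefill bucket is compiled before real traffic arrives.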
@@ -1203,66 +1235,37 @@ class CausalLM(Model):
         decode_batch_size_list.sort(reverse=True)

         try:
-            for batch_size in decode_batch_size_list:
-                batches= []
-                iters = math.floor(batch_size/max_prefill_batch_size)
-                for i in range(iters):
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch])
-                    batches.append(prefill_batch)
+            for i in range(2):
+                for batch_size in decode_batch_size_list:
+                    batches= []
+                    iters = math.floor(batch_size/max_prefill_batch_size)
+                    for i in range(iters):
+                        batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
+                        _, prefill_batch, _ = self.generate_token([batch])
+                        batches.append(prefill_batch)

-                if batch_size % max_prefill_batch_size != 0:
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch])
-                    batches.append(prefill_batch)
+                    if batch_size % max_prefill_batch_size != 0:
+                        batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
+                        _, prefill_batch, _ = self.generate_token([batch])
+                        batches.append(prefill_batch)

-                _, decode_batch, _ = self.generate_token(batches)
-                del decode_batch
-                batches.clear()
+                    _, decode_batch, _ = self.generate_token(batches)
+                    del decode_batch
+                    batches.clear()

         except:
             raise RuntimeError(
                 f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
                 f"You need to decrease `--max-batch-total-tokens`"
             )

         decode_batch_size_list.sort()
         MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
             f"\nFollowing decode warmup successfully.\n"
             f"Decode batch size list:{decode_batch_size_list}\n"
             f"Memory stats: {mem_stats} "
         )

-        limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
-        if limit_hpu_graph == False:
-            # Warmup prefill batch_size
-            max_input_length = request.max_input_length
-            prefill_batch_size_list = []
-            prefill_seqlen_list = []
-            try:
-                for batch_size in range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE):
-                    prefill_batch_size_list.append(batch_size)
-                    for seq_len in range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF):
-                        prefill_seqlen_list.append(seq_len)
-                        batch = self.generate_warmup_batch(request, seq_len, batch_size)
-                        _, prefill_batch, _ = self.generate_token([batch])
-                        del batch
-                        del prefill_batch
-            except:
-                raise RuntimeError(
-                    f"Not enough memory to run following prefill batch_size."
-                    f"Prefill batch size list:{prefill_batch_size_list}"
-                    f"Prefill sequence length list:{prefill_seqlen_list}"
-                    f"You need to decrease `--max-batch-prefill-tokens`"
-                )
-            prefill_batch_size_list.sort()
-            prefill_seqlen_list.sort()
-            mem_stats = get_hpu_memory_stats(self.device)
-            logger.info(
-                f"\nFollowing prefill and decode warmup successfully.\n"
-                f"Prefill batch size list:{prefill_batch_size_list}\n"
-                f"Prefill sequence length list:{prefill_seqlen_list}\n"
-                f"Memory stats: {mem_stats} "
-            )

         return MAX_BATCH_TOTAL_TOKENS
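For the decode warmup, each target decode batch size is assembled from `iters` full prefill batches plus one remainder batch, a single decode step then runs over all of them, and the new outer loop repeats the whole sweep twice. A rough sketch of that assembly with placeholder numbers (for brevity, the maximum decode batch size is folded into the range instead of being appended separately, and the actual prefill/decode calls are only noted in comments):

```python
# Sketch of the decode warmup assembly; all numbers are placeholders.
import math

BATCH_BUCKET_SIZE = 8
MAX_TOTAL_TOKENS = 2048
MAX_BATCH_TOTAL_TOKENS = 65536
max_prefill_batch_size = 20

def round_up(value: int, multiple: int) -> int:
    return (value + multiple - 1) // multiple * multiple

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)  # 32
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)     # 32
# Simplification for this sketch: include max_decode_batch_size in the range.
decode_batch_size_list = list(range(BATCH_BUCKET_SIZE, max_decode_batch_size + 1, BATCH_BUCKET_SIZE))
decode_batch_size_list.sort(reverse=True)

for batch_size in decode_batch_size_list:
    iters = math.floor(batch_size / max_prefill_batch_size)
    remainder = batch_size % max_prefill_batch_size
    pieces = [max_prefill_batch_size] * iters + ([remainder] if remainder else [])
    # In the server, each piece is prefilled with generate_warmup_batch/generate_token,
    # and the resulting batches are passed together to generate_token() for one decode step.
    print(f"decode batch_size {batch_size:2d} <- prefill pieces {pieces}")
    # 32 <- [20, 12], 24 <- [20, 4], 16 <- [16], 8 <- [8]
```

After the sweep, the diff recomputes MAX_BATCH_TOTAL_TOKENS as MAX_TOTAL_TOKENS times the largest decode batch size that survived warmup and returns it.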