Fix the warmup issue of llama2-7B.

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2024-12-09 07:20:48 +00:00
parent c6f023a06b
commit c922ef9534

@@ -1184,6 +1184,7 @@ class CausalLM(Model):
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
         batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
+        max_prefill_batch_size = batch.input_ids.shape[0]
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch])
@@ -1192,9 +1193,9 @@ class CausalLM(Model):
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
         del prefill_batch
-        max_prefill_batch_size = batch.input_ids.shape[0]
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
         prefill_batch_size_list = [batch for batch in range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE)]
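The two hunks above move the max_prefill_batch_size assignment so that batch.input_ids.shape[0] is read before the first generate_token call, presumably to capture the original request batch size before warmup touches the batch. The bucket list it feeds expands to every multiple of BATCH_BUCKET_SIZE strictly below that maximum; a minimal standalone sketch, assuming BATCH_BUCKET_SIZE = 8 (the real server takes this value from its configuration):

BATCH_BUCKET_SIZE = 8  # assumed value; the server reads this from its configuration

def prefill_bucket_sizes(max_prefill_batch_size: int) -> list[int]:
    # Mirrors the list comprehension in the diff: multiples of the bucket
    # size strictly below the maximum prefill batch size.
    return [b for b in range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE)]

print(prefill_bucket_sizes(20))  # [8, 16]
print(prefill_bucket_sizes(32))  # [8, 16, 24]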
@@ -1221,7 +1222,7 @@ class CausalLM(Model):
         prefill_batch_size_list.sort()
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing prefill and decode warmup successfully.\n"
+            f"\nFollowing prefill warmup successfully.\n"
             f"Prefill batch size list:{prefill_batch_size_list}\n"
             f"Prefill sequence length list:{prefill_seqlen_list}\n"
             f"Memory stats: {mem_stats} "
@@ -1235,7 +1236,6 @@ class CausalLM(Model):
         decode_batch_size_list.sort(reverse=True)
         try:
-            for i in range(2):
                 for batch_size in decode_batch_size_list:
                     batches= []
                     iters = math.floor(batch_size/max_prefill_batch_size)
@@ -1250,6 +1250,7 @@ class CausalLM(Model):
                     batches.append(prefill_batch)
                     _, decode_batch, _ = self.generate_token(batches)
+                    _, decode_batch, _ = self.generate_token([decode_batch])
                     del decode_batch
                     batches.clear()
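Read together, the last two hunks leave the decode warmup as a single pass (the outer for i in range(2) loop is gone) with one extra decode step per bucketed batch size. A minimal standalone sketch of that flow; make_prefill_batch and generate_token are hypothetical stand-ins for the CausalLM machinery, and the remainder handling is an assumption since those lines are elided from the hunk:

import math

def warmup_decode(decode_batch_size_list, max_prefill_batch_size,
                  make_prefill_batch, generate_token):
    # make_prefill_batch and generate_token are hypothetical stand-ins for
    # the CausalLM methods used by the real warmup path.
    for batch_size in sorted(decode_batch_size_list, reverse=True):
        batches = []
        # Assemble the target decode batch from full-size prefill batches
        # plus one remainder batch (assumed; those lines are elided above).
        iters = math.floor(batch_size / max_prefill_batch_size)
        for _ in range(iters):
            batches.append(make_prefill_batch(max_prefill_batch_size))
        remainder = batch_size % max_prefill_batch_size
        if remainder > 0:
            batches.append(make_prefill_batch(remainder))
        # The first call concatenates the prefill batches and runs a decode
        # step; the second call, added by this commit, runs one more decode
        # step on the concatenated batch.
        _, decode_batch, _ = generate_token(batches)
        _, decode_batch, _ = generate_token([decode_batch])
        del decode_batch
        batches.clear()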