From 9f356ce0452431c14229e7313cd4f190ae641988 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Sat, 7 Dec 2024 09:56:16 +0000
Subject: [PATCH] Refine the warmup process

Signed-off-by: yuanwu
---
 .../models/causal_lm.py | 109 +++++++++---------
 1 file changed, 56 insertions(+), 53 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ad40097e..273cf3d4 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -516,7 +516,7 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
+            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
@@ -1193,9 +1193,41 @@ class CausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
         del prefill_batch
-        #warmup decode batch size
+
         max_prefill_batch_size = batch.input_ids.shape[0]
-        del batch
+        # Warmup prefill batch_size
+        max_input_length = request.max_input_length
+        prefill_batch_size_list = [batch for batch in range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE)]
+        prefill_batch_size_list.append(max_prefill_batch_size)
+        prefill_seqlen_list = [seq for seq in range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)]
+        prefill_seqlen_list.append(max_input_length)
+        prefill_batch_size_list.sort(reverse=True)
+        prefill_seqlen_list.sort(reverse=True)
+        try:
+            for batch_size in prefill_batch_size_list:
+                for seq_len in prefill_seqlen_list:
+                    batch = self.generate_warmup_batch(request, seq_len-1, batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch])
+        except:
+            prefill_batch_size_list.sort()
+            prefill_seqlen_list.sort()
+            raise RuntimeError(
+                f"Not enough memory to run following prefill batch_size."
+                f"Prefill batch size list:{prefill_batch_size_list}"
+                f"Prefill sequence length list:{prefill_seqlen_list}"
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+        prefill_seqlen_list.sort()
+        prefill_batch_size_list.sort()
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nFollowing prefill and decode warmup successfully.\n"
+            f"Prefill batch size list:{prefill_batch_size_list}\n"
+            f"Prefill sequence length list:{prefill_seqlen_list}\n"
+            f"Memory stats: {mem_stats} "
+        )
+
+        #warmup decode batch size
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
         max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
         decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
@@ -1203,66 +1235,37 @@ class CausalLM(Model):
         decode_batch_size_list.sort(reverse=True)
 
         try:
-            for batch_size in decode_batch_size_list:
-                batches= []
-                iters = math.floor(batch_size/max_prefill_batch_size)
-                for i in range(iters):
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch])
-                    batches.append(prefill_batch)
+            for i in range(2):
+                for batch_size in decode_batch_size_list:
+                    batches= []
+                    iters = math.floor(batch_size/max_prefill_batch_size)
+                    for i in range(iters):
+                        batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
+                        _, prefill_batch, _ = self.generate_token([batch])
+                        batches.append(prefill_batch)
 
-                if batch_size % max_prefill_batch_size != 0:
-                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch])
-                    batches.append(prefill_batch)
+                    if batch_size % max_prefill_batch_size != 0:
+                        batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
+                        _, prefill_batch, _ = self.generate_token([batch])
+                        batches.append(prefill_batch)
+
+                    _, decode_batch, _ = self.generate_token(batches)
+                    del decode_batch
+                    batches.clear()
 
-                _, decode_batch, _ = self.generate_token(batches)
-                del decode_batch
-                batches.clear()
         except:
             raise RuntimeError(
                 f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
                 f"You need to decrease `--max-batch-total-tokens`"
             )
+
         decode_batch_size_list.sort()
         MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing decode warmup successfully.\n"
-            f"Decode batch size list:{decode_batch_size_list}\n"
-            f"Memory stats: {mem_stats} "
-        )
-
-        limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
-        if limit_hpu_graph == False:
-            # Warmup prefill batch_size
-            max_input_length = request.max_input_length
-            prefill_batch_size_list = []
-            prefill_seqlen_list = []
-            try:
-                for batch_size in range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE):
-                    prefill_batch_size_list.append(batch_size)
-                    for seq_len in range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF):
-                        prefill_seqlen_list.append(seq_len)
-                        batch = self.generate_warmup_batch(request, seq_len, batch_size)
-                        _, prefill_batch, _ = self.generate_token([batch])
-                        del batch
-                        del prefill_batch
-            except:
-                raise RuntimeError(
-                    f"Not enough memory to run following prefill batch_size."
-                    f"Prefill batch size list:{prefill_batch_size_list}"
-                    f"Prefill sequence length list:{prefill_seqlen_list}"
-                    f"You need to decrease `--max-batch-prefill-tokens`"
-                )
-            prefill_batch_size_list.sort()
-            prefill_seqlen_list.sort()
-            mem_stats = get_hpu_memory_stats(self.device)
-            logger.info(
-                f"\nFollowing prefill and decode warmup successfully.\n"
-                f"Prefill batch size list:{prefill_batch_size_list}\n"
-                f"Prefill sequence length list:{prefill_seqlen_list}\n"
-                f"Memory stats: {mem_stats} "
-            )
+            f"\nFollowing decode warmup successfully.\n"
+            f"Decode batch size list:{decode_batch_size_list}\n"
+            f"Memory stats: {mem_stats} "
+        )
 
         return MAX_BATCH_TOTAL_TOKENS
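
Illustration (not part of the commit): the sketch below mirrors the bucket enumeration
this patch moves the warmup to, so the warmed shapes can be inspected without a device.
It assumes example values for BATCH_BUCKET_SIZE and PAD_SEQUENCE_TO_MULTIPLE_OF, uses a
local round_up stand-in, and only reports the (batch_size, seq_len) prefill shapes and
the decode batch sizes; it does not call the server's generate_warmup_batch or
generate_token.

import math

# Assumed bucket parameters for illustration; the real values come from the
# serving environment configuration.
BATCH_BUCKET_SIZE = 8
PAD_SEQUENCE_TO_MULTIPLE_OF = 128


def round_up(value, multiple):
    # Local stand-in for the round_up helper used in causal_lm.py.
    return ((value + multiple - 1) // multiple) * multiple


def prefill_warmup_shapes(max_prefill_batch_size, max_input_length):
    # Same list construction as the patch: bucket boundaries below the maximum,
    # plus the maximum itself, swept from largest to smallest.
    batch_sizes = list(range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE))
    batch_sizes.append(max_prefill_batch_size)
    seq_lens = list(range(PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF))
    seq_lens.append(max_input_length)
    batch_sizes.sort(reverse=True)
    seq_lens.sort(reverse=True)
    # The patch warms each pair via generate_warmup_batch(request, seq_len - 1, batch_size);
    # here only the resulting shapes are reported.
    return [(bs, sl - 1) for bs in batch_sizes for sl in seq_lens]


def decode_warmup_batch_sizes(max_batch_total_tokens, max_total_tokens):
    # Decode buckets: multiples of BATCH_BUCKET_SIZE up to the rounded-up maximum,
    # swept from largest to smallest (the patch repeats that sweep twice).
    max_decode_batch_size = math.floor(max_batch_total_tokens / max_total_tokens)
    max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
    sizes = list(range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE))
    # Appending the rounded maximum mirrors the prefill list construction.
    sizes.append(max_decode_batch_size)
    sizes.sort(reverse=True)
    return sizes


if __name__ == "__main__":
    # Hypothetical limits, chosen only to keep the output short.
    print(prefill_warmup_shapes(max_prefill_batch_size=16, max_input_length=512))
    print(decode_warmup_batch_sizes(max_batch_total_tokens=65536, max_total_tokens=2048))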