From c922ef95348454a22eaa0a5fa2f8ff35a96d2f7a Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Mon, 9 Dec 2024 07:20:48 +0000
Subject: [PATCH] Fix the warmup issue of llama2-7B.

Signed-off-by: yuanwu
---
 .../models/causal_lm.py | 37 ++++++++++---------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 273cf3d44..120b140bf 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1184,6 +1184,7 @@ class CausalLM(Model):
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
         batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
+        max_prefill_batch_size = batch.input_ids.shape[0]
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch])
@@ -1192,9 +1193,9 @@
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
+
         del prefill_batch
 
-        max_prefill_batch_size = batch.input_ids.shape[0]
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
         prefill_batch_size_list = [batch for batch in range(BATCH_BUCKET_SIZE, max_prefill_batch_size, BATCH_BUCKET_SIZE)]
@@ -1221,7 +1222,7 @@
         prefill_batch_size_list.sort()
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing prefill and decode warmup successfully.\n"
+            f"\nFollowing prefill warmup successfully.\n"
             f"Prefill batch size list:{prefill_batch_size_list}\n"
             f"Prefill sequence length list:{prefill_seqlen_list}\n"
             f"Memory stats: {mem_stats} "
@@ -1235,23 +1236,23 @@
         decode_batch_size_list.sort(reverse=True)
 
         try:
-            for i in range(2):
-                for batch_size in decode_batch_size_list:
-                    batches= []
-                    iters = math.floor(batch_size/max_prefill_batch_size)
-                    for i in range(iters):
-                        batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
-                        _, prefill_batch, _ = self.generate_token([batch])
-                        batches.append(prefill_batch)
+            for batch_size in decode_batch_size_list:
+                batches= []
+                iters = math.floor(batch_size/max_prefill_batch_size)
+                for i in range(iters):
+                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch])
+                    batches.append(prefill_batch)
 
-                    if batch_size % max_prefill_batch_size != 0:
-                        batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
-                        _, prefill_batch, _ = self.generate_token([batch])
-                        batches.append(prefill_batch)
+                if batch_size % max_prefill_batch_size != 0:
+                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch])
+                    batches.append(prefill_batch)
 
-                    _, decode_batch, _ = self.generate_token(batches)
-                    del decode_batch
-                    batches.clear()
+                _, decode_batch, _ = self.generate_token(batches)
+                _, decode_batch, _ = self.generate_token([decode_batch])
+                del decode_batch
+                batches.clear()
 
         except:
             raise RuntimeError(
@@ -1268,4 +1269,4 @@
             f"Memory stats: {mem_stats} "
         )
 
-        return MAX_BATCH_TOTAL_TOKENS
+        return MAX_BATCH_TOTAL_TOKENS
\ No newline at end of file
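
Note (not part of the patch itself): the decode warmup above drops the outer `for i in range(2)` pass, builds each decode bucket from floor(batch_size / max_prefill_batch_size) prefill batches of size max_prefill_batch_size plus an optional remainder batch, and adds a second generate_token call on the already-concatenated decode batch. A minimal standalone sketch of the chunking arithmetic follows; plan_decode_warmup is an illustrative name and not part of the server code.

    import math

    def plan_decode_warmup(batch_size, max_prefill_batch_size):
        # Prefill chunk sizes whose concatenation covers one decode bucket:
        # full-size chunks first, then the remainder (if any).
        chunks = [max_prefill_batch_size] * math.floor(batch_size / max_prefill_batch_size)
        remainder = batch_size % max_prefill_batch_size
        if remainder != 0:
            chunks.append(remainder)
        return chunks

    # Example: a decode bucket of 12 with max_prefill_batch_size = 5 -> [5, 5, 2]
    print(plan_decode_warmup(12, 5))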