diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 88c5debf..39f1f7c9 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1213,10 +1213,9 @@ class CausalLM(Model):
         #warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        batch_size = max_decode_batch_size
         self.limit_hpu_graph = True
         try:
-            while batch_size > 1:
+            for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
@@ -1231,8 +1230,6 @@ class CausalLM(Model):
                     batches.append(batch)
 
                 _, decode_batch, _ = self.generate_token(batches, is_warmup)
-                logger.info(f"DECODE_DIVISOR={BATCH_BUCKET_SIZE}")
-                batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
         except:
             DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
             self.model.clear_cache()
@@ -1257,26 +1254,21 @@ class CausalLM(Model):
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
         max_prefill_batch_size = batch.input_ids.shape[0]
-        batch_size = max_prefill_batch_size
-        while batch_size >= 1:
-            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-            batch_size = math.floor(batch_size / PREFILL_BATCH_BUCKET_SIZE)
+        seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
 
-        seq_len = max_input_length
-        while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF:
+        i = 0
+        while seq_len <= max_input_length:
             PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
-            seq_len = math.floor(seq_len/2)
+            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
+            i += 1
 
-        if PREFILL_WARMUP_SEQLEN_LIST[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF:
-            PREFILL_WARMUP_SEQLEN_LIST.append(PAD_SEQUENCE_TO_MULTIPLE_OF)
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
 
         #Prefill and decode warmup
-        prefill_batch = None
-        PREFILL_WARMUP_BATCH_SIZE_LIST.sort()
-        PREFILL_WARMUP_SEQLEN_LIST.sort()
-
         try:
-            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
+            for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE):
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                 for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
                     batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
                     _, prefill_batch, _ = self.generate_token([batch], is_warmup)
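
For reference, a minimal standalone sketch (not part of the patch) of the warmup bucketing this diff introduces. The constant values below (PAD_SEQUENCE_TO_MULTIPLE_OF, BATCH_BUCKET_SIZE, PREFILL_BATCH_BUCKET_SIZE) and the per-request maxima are assumed example values, not taken from the server configuration:

# Sketch of the new warmup schedules, with assumed example values.
PAD_SEQUENCE_TO_MULTIPLE_OF = 128
BATCH_BUCKET_SIZE = 8
PREFILL_BATCH_BUCKET_SIZE = 4

max_input_length = 1024
max_prefill_batch_size = 16
max_decode_batch_size = 64

# Prefill warmup sequence lengths now grow geometrically from the padding
# multiple up to max_input_length, instead of halving down from it.
prefill_warmup_seqlen_list = []
seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
i = 0
while seq_len <= max_input_length:
    prefill_warmup_seqlen_list.append(seq_len)
    seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
    i += 1
if prefill_warmup_seqlen_list[-1] < max_input_length:
    prefill_warmup_seqlen_list.append(max_input_length)

# Prefill batch sizes are walked up in PREFILL_BATCH_BUCKET_SIZE steps;
# decode batch sizes are walked down in BATCH_BUCKET_SIZE steps.
prefill_warmup_batch_sizes = list(
    range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)
)
decode_warmup_batch_sizes = list(
    range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE)
)

print(prefill_warmup_seqlen_list)   # [128, 256, 512, 1024]
print(prefill_warmup_batch_sizes)   # [4, 8, 12]
print(decode_warmup_batch_sizes)    # [64, 56, 48, 40, 32, 24, 16]

Note that both range() calls use an exclusive bound, mirroring the patch: the prefill loop stops below max_prefill_batch_size and the decode loop stops above BATCH_BUCKET_SIZE.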