From 4586325a34e8d9bd78b7431c4d76bc783fab12b3 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Tue, 26 Nov 2024 08:55:42 +0000
Subject: [PATCH] Fix the StarCoder warmup issue
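
The warmup lists recorded at warmup time (PREFILL_WARMUP_BATCH_SIZE_LIST,
PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST) are replaced
with arithmetic bucket rounding, so any runtime batch size or sequence
length maps onto a warmed-up shape without depending on global state.

A minimal standalone sketch of the new rounding (the example values
assume the default BATCH_BUCKET_SIZE=8 and PREFILL_BATCH_BUCKET_SIZE=2):

    def round_up(number, k):
        # Round `number` up to the nearest multiple of `k`.
        return (number + k - 1) // k * k

    assert round_up(3, 8) == 8    # decode batch of 3 runs in the 8-bucket
    assert round_up(9, 8) == 16   # decode batch of 9 runs in the 16-bucket
    assert round_up(5, 2) == 6    # prefill batch of 5 runs in the 6-bucket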
List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id, is_warmup) + def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch": + return cls.recombine(batches, pad_token_id) def __len__(self): return len(self.requests) @@ -895,7 +878,7 @@ class CausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( - self, batches: List[CausalLMBatch], is_warmup: bool = False + self, batches: List[CausalLMBatch] ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: start = time.time_ns() # Results @@ -979,7 +962,7 @@ class CausalLM(Model): # Stage 2. Prepare new batch for speculative scheduling if len(batches) > 1: - batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup) + batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) else: batch = batches[0] @@ -987,12 +970,12 @@ class CausalLM(Model): # Check if we need to do any bookkeeping first if not prefill: - batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup) + batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) scenario = 'PREFILL' if prefill else 'GENERATE' - if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs: + if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs: self.model.clear_cache() - self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) + self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE) dbg_trace( scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') assert batch.right_padding > 0, 'No more room for next token!' @@ -1194,100 +1177,93 @@ class CausalLM(Model): for i in range(len(batch.requests) - batch_size): batch.requests.pop() - return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True) + return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) def warmup(self, request) -> None: - is_warmup = True MAX_TOTAL_TOKENS = request.max_total_tokens MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens - batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup) + batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device) try: # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch], is_warmup) + _, prefill_batch, _ = self.generate_token([batch]) except: raise RuntimeError( f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. 
" f"You need to decrease `--max-batch-prefill-tokens`" ) - + del prefill_batch #warmup decode batch size max_prefill_batch_size = batch.input_ids.shape[0] + del batch max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) + max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE) + decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)] + decode_batch_size_list.append(max_decode_batch_size) + decode_batch_size_list.sort(reverse=True) + self.limit_hpu_graph = True try: - for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE): + for batch_size in decode_batch_size_list: batches= [] iters = math.floor(batch_size/max_prefill_batch_size) - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) for i in range(iters): batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) + _, prefill_batch, _ = self.generate_token([batch]) batches.append(prefill_batch) if batch_size % max_prefill_batch_size != 0: batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(batch) + _, prefill_batch, _ = self.generate_token([batch]) + batches.append(prefill_batch) - _, decode_batch, _ = self.generate_token(batches, is_warmup) + _, decode_batch, _ = self.generate_token(batches) + del decode_batch + batches.clear() except: - DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1) - self.model.clear_cache() - if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0: - logger.warning( - f"Not enough memory to warmup all batch size of decode." - f"You need to decrease `--max-batch-total-tokens`" - ) - else: raise RuntimeError( - f"Not enough memory to warmup decode batch_size({max_decode_batch_size})." + f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})." 
f"You need to decrease `--max-batch-total-tokens`" ) - DECODE_WARMUP_BATCH_SIZE_LIST.sort() + decode_batch_size_list.sort() + MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] mem_stats = get_hpu_memory_stats(self.device) logger.info( f"\nFollowing decode warmup successfully.\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" + f"Decode batch size list:{decode_batch_size_list}\n" f"Memory stats: {mem_stats} " ) # Warmup prefill batch_size max_input_length = request.max_input_length - max_prefill_batch_size = batch.input_ids.shape[0] - seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF - - i = 0 - while seq_len <= max_input_length: - PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) - seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i) - i += 1 - - if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: - PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) - + prefill_batch_size_list = [] + prefill_seqlen_list = [] #Prefill and decode warmup try: - for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE): - PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) - for seq_len in PREFILL_WARMUP_SEQLEN_LIST : - batch = self.generate_warmup_batch(request, seq_len - 1, batch_size) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) + for batch_size in range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE): + prefill_batch_size_list.append(batch_size) + for seq_len in range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF): + prefill_seqlen_list.append(seq_len) + batch = self.generate_warmup_batch(request, seq_len, batch_size) + _, prefill_batch, _ = self.generate_token([batch]) + del batch + del prefill_batch except: raise RuntimeError( f"Not enough memory to run following prefill batch_size." - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" + f"Prefill batch size list:{prefill_batch_size_list}" + f"Prefill sequence length list:{prefill_seqlen_list}" f"You need to decrease `--max-batch-prefill-tokens`" ) - + prefill_batch_size_list.sort() + prefill_seqlen_list.sort() limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" if limit_hpu_graph == False: mem_stats = get_hpu_memory_stats(self.device) logger.info( f"\nFollowing prefill and decode warmup successfully.\n" - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" + f"Prefill batch size list:{prefill_batch_size_list}\n" + f"Prefill sequence length list:{prefill_seqlen_list}\n" f"Memory stats: {mem_stats} " )