From 9d85ac948549e1d19cb0a9705c5a066fd1ca8918 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 30 Mar 2025 20:20:09 -0700 Subject: [PATCH] LLM warmup logic Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 208 +++++++++++++----- 1 file changed, 156 insertions(+), 52 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index a4d58596..ed8b658a 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -71,6 +71,7 @@ import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch import itertools from vllm_hpu_extension.ops import batch2block, block2batch +from vllm_hpu_extension.bucketing import HPUBucketingContext tracer = trace.get_tracer(__name__) @@ -89,7 +90,7 @@ def get_sliding_windows() -> int: def prepare_for_decode( - dtype, use_contiguous_pa, device, slot, block_tables, batch_size + dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx ): # Prepare values if we need to continue decoding # need for HPUPagedAttentionMetadata preparation @@ -120,8 +121,10 @@ def prepare_for_decode( assert len(block_list) == len(block_usage) if use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) - # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( - # block_bucket_size) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): @@ -131,6 +134,10 @@ def prepare_for_decode( block_usage = gather_list(block_usage, indices, 1) else: block_bucket_size = len(block_list) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) block_list = pad_list(block_list, block_bucket_size, 0) block_groups = pad_list(block_groups, block_bucket_size, -1) block_usage = pad_list(block_usage, block_bucket_size, 1) @@ -835,15 +842,16 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = ( - batch.slot_indices + cumulative_slots + index = torch.tensor( + list(range(start_index, end_index)), device=batch.input_ids.device ) - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor - cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor - - # Copy over adapter indices + input_ids.index_copy_(0, index, batch.input_ids) + position_ids.index_copy_(0, index, batch.position_ids) + slot_indices.index_copy_( + 0, index, batch.slot_indices + cumulative_slots + ) + input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor) + cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor) adapter_start_index = cumulative_adapter_indices_size adapter_end_index = ( cumulative_adapter_indices_size @@ -951,22 +959,34 @@ class FlashCausalLMBatch(Batch): hpu_attn_meta=None, ) - def prepare_for_decode(self, dtype, use_contiguous_pa): + def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx): block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 block_tables = [] for i, bt in enumerate(self.block_tables): block_tables.append(bt[0 : block_num[i]]) + if bucketing_ctx is 
not None: + padded_bs = bucketing_ctx.get_padded_decode_batch_size( + self.input_ids.shape[0] + ) + else: + padded_bs = self.input_ids.shape[0] + slots = self.slots[self.slot_indices] + extra_pad = padded_bs - self.input_ids.shape[0] + if extra_pad != 0: + slots = F.pad(slots, (0, extra_pad), value=0) + block_tables.extend([[0]] * extra_pad) self.hpu_attn_meta = prepare_for_decode( dtype, use_contiguous_pa, self.block_tables_tensor.device, - self.slots[self.slot_indices], + slots, block_tables, - self.input_ids.size(0), + padded_bs, + bucketing_ctx, ) - def prepare_for_prefill(self): + def prepare_for_prefill(self, max_padded_input_len): # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything @@ -980,7 +1000,7 @@ class FlashCausalLMBatch(Batch): # the right logit position input_ids_padded_length = [] # need extra pad to match warmup seq - extra_pad = 0 + extra_pad = max_padded_input_len - self.max_input_length if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1355,9 +1375,9 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - + self.bucketing_ctx = None if htorch.utils.internal.is_lazy(): - htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False) + htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) environment.set_model_config(self.config) self.use_contiguous_pa = ( os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" @@ -1479,9 +1499,31 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) + self.bucketing_ctx = HPUBucketingContext( + os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO + os.getenv("PREFILL_MAX_BS", 16), # self.max_num_prefill_seqs, #TODO + BLOCK_SIZE, + num_blocks * BLOCK_SIZE, + ) + self.bucketing_ctx.num_hpu_blocks = num_blocks + warmup_times = 3 + self.bucketing_ctx.generate_prompt_buckets() + for i, (batch_size, seq_len) in enumerate( + reversed(self.bucketing_ctx.prompt_buckets) + ): + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size) + self.bucketing_ctx.generate_decode_buckets(num_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num) + synchronize(self.device) return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens def warmup_prefill(self, prompt_len: int, bs: int): + logger.info(f"warmup prefill seq {prompt_len} bs {bs}") input_ids = torch.zeros( prompt_len, dtype=torch.int64, device=self.device ).repeat(bs) @@ -1527,25 +1569,32 @@ class FlashCausalLM(Model): hpu_attention_meta=None, ) - def warmup_decode(self, bs: int, block_num: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.arange(bs, dtype=torch.int32, device=self.device) - block_tables = torch.arange( - start=1, end=block_num + 1, dtype=torch.int32, device=self.device - ).reshape(bs, -1) + def warmup_decode(self, batch_size: int, block_num: int): + logger.info(f"warmup decode bs {batch_size} block_num {block_num}") + input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) + position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) + blocks = [block_num // batch_size for _ in range(batch_size)] + blocks[0] += block_num % batch_size + past_len = [] + block_tables = [] 
slots = [] - past_len = ( - len(block_tables[0]) * BLOCK_SIZE - 1 - ) # for decode, we only need to pass the past token + start_idx = 0 + # fetch the last blocked to warmup block num - for i in range(bs): - slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) + for i in range(batch_size): + block_array = list(range(start_idx, start_idx + blocks[i])) + slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + block_tables.append(block_array) + past_len.append(blocks[i] * BLOCK_SIZE - 1) + start_idx += blocks[i] slots = torch.tensor(slots, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) - cache_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * past_len + input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.tensor( + past_len, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 ) - cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( @@ -1553,20 +1602,16 @@ class FlashCausalLM(Model): cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) - block_num = cache_lengths_tensor // BLOCK_SIZE + 1 - block_tables_valid = [] - for i, bt in enumerate(block_tables.tolist()): - block_tables_valid.append(bt[0 : block_num[i]]) hpu_attention_meta = prepare_for_decode( self.dtype, self.use_contiguous_pa, self.device, slots, - block_tables_valid, - bs, + block_tables, + batch_size, + bucketing_ctx=None, ) - # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( input_ids=input_ids, @@ -1651,19 +1696,69 @@ class FlashCausalLM(Model): # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. 
max_s = min(self.max_past(), max_s) - - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - kwargs = {} - if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + orig_bs = input_lengths.shape[0] + if padded_bs != input_lengths.shape[0]: + orig_bs = input_lengths.shape[0] + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + position_ids = F.pad( + position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 + ) + slots = F.pad( + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + ) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1677,7 +1772,9 @@ class FlashCausalLM(Model): hpu_attention_meta=batch.hpu_attn_meta, **kwargs, ) - return logits, speculative_logits + return logits[:orig_bs], ( + speculative_logits[:orig_bs] if speculative_logits is not None else None + ) @tracer.start_as_current_span("generate_token") def generate_token( @@ -1690,9 +1787,16 @@ class FlashCausalLM(Model): start = time.time_ns() prefill = batch.prefilling if prefill: - batch.prepare_for_prefill() + if self.bucketing_ctx is not None: + batch.prepare_for_prefill( + self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length) + ) + else: + batch.prepare_for_prefill(batch.max_input_length) else: - batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) + batch.prepare_for_decode( + self.dtype, self.use_contiguous_pa, self.bucketing_ctx + ) prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta
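
Note: the decode-path changes above pad each live batch up to a pre-warmed bucket before building the HPU attention metadata: the batch size is rounded up via bucketing_ctx.get_padded_decode_batch_size(), the extra slot entries are filled with zeros, and dummy single-block tables are appended. The standalone sketch below illustrates just that padding step; round_up_to_step() is a hypothetical stand-in for the bucketing context's rounding rule, not the real HPUBucketingContext API.

import math
from typing import List

import torch
import torch.nn.functional as F

BLOCK_SIZE = 128  # page size used by the paged-attention KV cache


def round_up_to_step(value: int, step: int = 8) -> int:
    # Hypothetical stand-in for bucketing_ctx.get_padded_decode_batch_size():
    # round the live batch size up to the nearest warmed-up bucket.
    return int(math.ceil(value / step) * step)


def pad_decode_batch(slots: torch.Tensor, block_tables: List[List[int]]):
    # Pad a decode batch to its bucket size, mirroring prepare_for_decode above.
    real_bs = slots.shape[0]
    padded_bs = round_up_to_step(real_bs)
    extra_pad = padded_bs - real_bs
    if extra_pad != 0:
        # Dummy slots point at slot 0; dummy sequences get a single block 0.
        slots = F.pad(slots, (0, extra_pad), value=0)
        block_tables = block_tables + [[0]] * extra_pad
    return slots, block_tables, padded_bs


if __name__ == "__main__":
    slots = torch.tensor([5, 261, 517], dtype=torch.int64)  # 3 live sequences
    block_tables = [[0], [2, 3], [4]]
    slots, block_tables, padded_bs = pad_decode_batch(slots, block_tables)
    print(padded_bs, slots.shape[0], len(block_tables))  # 8 8 8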
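
Note: warmup now walks every prompt bucket (batch_size, seq_len) and every decode bucket (batch_size, block_num) from largest to smallest and replays each shape a few times so the corresponding HPU graphs are compiled before serving. A minimal sketch of that control flow follows; the hard-coded bucket lists and print-only warmup functions are placeholders for bucketing_ctx.generate_prompt_buckets()/generate_decode_buckets() and the real warmup_prefill()/warmup_decode().

from typing import List, Tuple

# Hypothetical bucket lists; in the patch these come from
# bucketing_ctx.generate_prompt_buckets() and generate_decode_buckets(num_blocks).
prompt_buckets: List[Tuple[int, int]] = [(1, 128), (4, 256), (16, 1024)]   # (batch_size, seq_len)
decode_buckets: List[Tuple[int, int]] = [(1, 8), (32, 256), (128, 1024)]   # (batch_size, block_num)

WARMUP_TIMES = 3  # replay each shape a few times so lazy-mode graphs are cached


def warmup_prefill(seq_len: int, batch_size: int) -> None:
    print(f"warmup prefill seq {seq_len} bs {batch_size}")  # stub for the real forward pass


def warmup_decode(batch_size: int, block_num: int) -> None:
    print(f"warmup decode bs {batch_size} block_num {block_num}")  # stub for the real forward pass


def warmup_all_buckets() -> None:
    # Largest shapes first, matching the reversed() iteration in the patch.
    for batch_size, seq_len in reversed(prompt_buckets):
        for _ in range(WARMUP_TIMES):
            warmup_prefill(seq_len, batch_size)
    for batch_size, block_num in reversed(decode_buckets):
        for _ in range(WARMUP_TIMES):
            warmup_decode(batch_size, block_num)


if __name__ == "__main__":
    warmup_all_buckets()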
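
Note: in forward(), whenever the live batch is smaller than its bucket, the per-sequence tensors are padded out to padded_bs before the model call and the returned logits are sliced back to the original batch size, so padding rows never reach the caller. The sketch below shows that pad-run-slice pattern with a toy model; the linear layer and tensor names are illustrative and not the real TGI signatures.

import torch
import torch.nn.functional as F


class ToyModel(torch.nn.Module):
    # Stand-in for the real language model: one logit vector per input row.
    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.proj = torch.nn.Linear(8, vocab_size)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden)


def padded_forward(model: torch.nn.Module, hidden: torch.Tensor, padded_bs: int) -> torch.Tensor:
    orig_bs = hidden.shape[0]
    if padded_bs != orig_bs:
        # Pad the batch dimension with zero rows so the shape matches a warmed-up bucket.
        hidden = F.pad(hidden, (0, 0, 0, padded_bs - orig_bs), value=0.0)
    logits = model(hidden)
    # Drop the padding rows again before anything downstream sees them.
    return logits[:orig_bs]


if __name__ == "__main__":
    model = ToyModel()
    hidden = torch.randn(3, 8)          # 3 live sequences in a decode step
    logits = padded_forward(model, hidden, padded_bs=8)
    print(logits.shape)                 # torch.Size([3, 32])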