LLM warmup logic

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A    Date: 2025-03-30 20:20:09 -07:00
parent c55a8caea2
commit 9d85ac9485


@@ -71,6 +71,7 @@ import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.bucketing import HPUBucketingContext
tracer = trace.get_tracer(__name__)
@@ -89,7 +90,7 @@ def get_sliding_windows() -> int:
def prepare_for_decode(
dtype, use_contiguous_pa, device, slot, block_tables, batch_size
dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx
):
# Prepare values if we need to continue decoding
# needed for HPUPagedAttentionMetadata preparation
@@ -120,8 +121,10 @@ def prepare_for_decode(
assert len(block_list) == len(block_usage)
if use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
# block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
# block_bucket_size)
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
@@ -131,6 +134,10 @@ def prepare_for_decode(
block_usage = gather_list(block_usage, indices, 1)
else:
block_bucket_size = len(block_list)
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
block_list = pad_list(block_list, block_bucket_size, 0)
block_groups = pad_list(block_groups, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
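For readers unfamiliar with the bucketing helpers: get_padded_decode_num_blocks rounds the block count up to one of a small set of precomputed bucket sizes, so that decode should only see shapes that were warmed up ahead of time. A minimal sketch of that rounding idea, with a made-up helper name and bucket list (the real logic lives in vllm_hpu_extension.bucketing):

from typing import List

def pad_to_bucket(value: int, buckets: List[int]) -> int:
    """Return the smallest bucket >= value, falling back to the largest bucket."""
    for b in sorted(buckets):
        if b >= value:
            return b
    return max(buckets)

# Example: with block buckets [32, 64, 128, 256], a batch needing 70 blocks
# is padded to 128, matching one of the shapes warmed up in advance.
assert pad_to_bucket(70, [32, 64, 128, 256]) == 128
assert pad_to_bucket(300, [32, 64, 128, 256]) == 256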
@@ -835,15 +842,16 @@ class FlashCausalLMBatch(Batch):
)
if not prefilling:
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = (
batch.slot_indices + cumulative_slots
index = torch.tensor(
list(range(start_index, end_index)), device=batch.input_ids.device
)
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
# Copy over adapter indices
input_ids.index_copy_(0, index, batch.input_ids)
position_ids.index_copy_(0, index, batch.position_ids)
slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots
)
input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor)
cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor)
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
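The slice assignments in this hunk are replaced with index_copy_ over an explicit index tensor; for a contiguous range the two are equivalent, as the small check below illustrates (the motivation, presumably friendliness to HPU lazy/graph execution, is an assumption, not stated in the diff):

import torch

dst_a = torch.zeros(8, dtype=torch.int64)
dst_b = torch.zeros(8, dtype=torch.int64)
src = torch.arange(3, dtype=torch.int64)

start_index, end_index = 2, 5
dst_a[start_index:end_index] = src                        # old slice-assignment form
index = torch.tensor(list(range(start_index, end_index)))
dst_b.index_copy_(0, index, src)                          # new index_copy_ form

assert torch.equal(dst_a, dst_b)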
@@ -951,22 +959,34 @@ class FlashCausalLMBatch(Batch):
hpu_attn_meta=None,
)
def prepare_for_decode(self, dtype, use_contiguous_pa):
def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
block_tables = []
for i, bt in enumerate(self.block_tables):
block_tables.append(bt[0 : block_num[i]])
if bucketing_ctx is not None:
padded_bs = bucketing_ctx.get_padded_decode_batch_size(
self.input_ids.shape[0]
)
else:
padded_bs = self.input_ids.shape[0]
slots = self.slots[self.slot_indices]
extra_pad = padded_bs - self.input_ids.shape[0]
if extra_pad != 0:
slots = F.pad(slots, (0, extra_pad), value=0)
block_tables.extend([[0]] * extra_pad)
self.hpu_attn_meta = prepare_for_decode(
dtype,
use_contiguous_pa,
self.block_tables_tensor.device,
self.slots[self.slot_indices],
slots,
block_tables,
self.input_ids.size(0),
padded_bs,
bucketing_ctx,
)
def prepare_for_prefill(self):
def prepare_for_prefill(self, max_padded_input_len):
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
@@ -980,7 +1000,7 @@ class FlashCausalLMBatch(Batch):
# the right logit position
input_ids_padded_length = []
# extra pad needed to match the warmed-up (bucketed) sequence length
extra_pad = 0
extra_pad = max_padded_input_len - self.max_input_length
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = []
input_ids = []
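prepare_for_prefill now receives max_padded_input_len (the bucketed sequence length) and pads each prompt out to it via extra_pad. A rough, self-contained illustration of that padding, with made-up lengths and an assumed bucket value (the real value comes from bucketing_ctx.get_padded_prompt_seq_len):

import torch
import torch.nn.functional as F

prompt_lens = [5, 9]                                  # real prompt lengths in the batch
max_input_length = max(prompt_lens)                   # 9
max_padded_input_len = 16                             # assumed bucket for length 9
extra_pad = max_padded_input_len - max_input_length   # mirrors the diff line above
assert extra_pad == 7

padded = [
    F.pad(torch.ones(l, dtype=torch.int64), (0, max_padded_input_len - l), value=0)
    for l in prompt_lens
]
assert all(t.shape[0] == max_padded_input_len for t in padded)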
@@ -1355,9 +1375,9 @@ class FlashCausalLM(Model):
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
self.bucketing_ctx = None
if htorch.utils.internal.is_lazy():
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False)
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
environment.set_model_config(self.config)
self.use_contiguous_pa = (
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
@@ -1479,9 +1499,31 @@ class FlashCausalLM(Model):
self.kv_cache_dtype,
self.device,
)
self.bucketing_ctx = HPUBucketingContext(
os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO
os.getenv("PREFILL_MAX_BS", 16), # self.max_num_prefill_seqs, #TODO
BLOCK_SIZE,
num_blocks * BLOCK_SIZE,
)
self.bucketing_ctx.num_hpu_blocks = num_blocks
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
):
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size)
self.bucketing_ctx.generate_decode_buckets(num_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num)
synchronize(self.device)
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def warmup_prefill(self, prompt_len: int, bs: int):
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
input_ids = torch.zeros(
prompt_len, dtype=torch.int64, device=self.device
).repeat(bs)
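The warmup sweep added above walks every prompt and decode bucket, largest shapes first, and runs each shape warmup_times. A stripped-down sketch of that loop with hypothetical bucket lists and no-op stand-ins for the warmup calls:

warmup_times = 3
prompt_buckets = [(1, 128), (1, 256), (4, 128), (4, 256)]   # (batch_size, seq_len), hypothetical
decode_buckets = [(1, 64), (8, 64), (8, 128)]               # (batch_size, block_num), hypothetical

def warmup_prefill(seq_len, batch_size):    # stand-in for self.warmup_prefill
    pass

def warmup_decode(batch_size, block_num):   # stand-in for self.warmup_decode
    pass

for batch_size, seq_len in reversed(prompt_buckets):
    for _ in range(warmup_times):
        warmup_prefill(seq_len, batch_size)
for batch_size, block_num in reversed(decode_buckets):
    for _ in range(warmup_times):
        warmup_decode(batch_size, block_num)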
@@ -1527,25 +1569,32 @@ class FlashCausalLM(Model):
hpu_attention_meta=None,
)
def warmup_decode(self, bs: int, block_num: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
block_tables = torch.arange(
start=1, end=block_num + 1, dtype=torch.int32, device=self.device
).reshape(bs, -1)
def warmup_decode(self, batch_size: int, block_num: int):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
past_len = (
len(block_tables[0]) * BLOCK_SIZE - 1
) # for decode, we only need to pass the past token
start_idx = 0
# take the last slot of each sequence's final block so warmup covers block_num blocks
for i in range(bs):
slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
cache_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.zeros(
batch_size + 1, device=self.device, dtype=torch.int32
)
cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
@@ -1553,20 +1602,16 @@ class FlashCausalLM(Model):
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
block_num = cache_lengths_tensor // BLOCK_SIZE + 1
block_tables_valid = []
for i, bt in enumerate(block_tables.tolist()):
block_tables_valid.append(bt[0 : block_num[i]])
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables_valid,
bs,
block_tables,
batch_size,
bucketing_ctx=None,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
@@ -1651,19 +1696,69 @@ class FlashCausalLM(Model):
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
if self.bucketing_ctx is not None:
if batch.prefilling:
padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
input_lengths.shape[0]
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
value=0,
)
padded_cache_lengths_tensor = F.pad(
cache_lengths_tensor,
(0, padded_bs - orig_bs),
value=0,
)
if cu_seqlen_prefill is not None:
cu_seqlen_prefill = torch.zeros(
padded_bs + 1, device=self.device, dtype=torch.int32
)
torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=padded_input_lengths,
cache_lengths=padded_cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
input_seq = input_ids.view(orig_bs, -1)
input_ids = F.pad(
input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
position_ids = F.pad(
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
lm_head_indices, (0, padded_bs - orig_bs), value=0
)
else:
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -1677,7 +1772,9 @@ class FlashCausalLM(Model):
hpu_attention_meta=batch.hpu_attn_meta,
**kwargs,
)
return logits, speculative_logits
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)
@tracer.start_as_current_span("generate_token")
def generate_token(
@@ -1690,9 +1787,16 @@ class FlashCausalLM(Model):
start = time.time_ns()
prefill = batch.prefilling
if prefill:
batch.prepare_for_prefill()
if self.bucketing_ctx is not None:
batch.prepare_for_prefill(
self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length)
)
else:
batch.prepare_for_prefill(batch.max_input_length)
else:
batch.prepare_for_decode(self.dtype, self.use_contiguous_pa)
batch.prepare_for_decode(
self.dtype, self.use_contiguous_pa, self.bucketing_ctx
)
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta
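Taken together, forward() now pads batch-level tensors up to the bucketed batch size before calling the model and then slices the padded rows back off the logits. A minimal, self-contained sketch of that pad-then-slice pattern; the random tensor stands in for the real model output:

import torch
import torch.nn.functional as F

orig_bs, padded_bs, vocab = 3, 4, 16
input_lengths = torch.tensor([7, 5, 9], dtype=torch.int32)

# pad per-request tensors up to the bucketed batch size
padded_input_lengths = F.pad(input_lengths, (0, padded_bs - orig_bs), value=0)

logits = torch.randn(padded_bs, vocab)    # stand-in for the real model output
logits = logits[:orig_bs]                 # drop rows that correspond to padding
assert padded_input_lengths.shape[0] == padded_bs
assert logits.shape == (orig_bs, vocab)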