Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 00:12:08 +00:00)
commit 9d85ac9485
parent c55a8caea2

LLM warmup logic

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -71,6 +71,7 @@ import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
 from vllm_hpu_extension.ops import batch2block, block2batch
+from vllm_hpu_extension.bucketing import HPUBucketingContext

 tracer = trace.get_tracer(__name__)

@@ -89,7 +90,7 @@ def get_sliding_windows() -> int:


 def prepare_for_decode(
-    dtype, use_contiguous_pa, device, slot, block_tables, batch_size
+    dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx
 ):
     # Prepare values if we need to continue decoding
     # need for HPUPagedAttentionMetadata preparation
@@ -120,8 +121,10 @@ def prepare_for_decode(
     assert len(block_list) == len(block_usage)
     if use_contiguous_pa:
         block_bucket_size = max(max(block_list) + 1, len(block_list))
-        # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
-        #     block_bucket_size)
+        if bucketing_ctx is not None:
+            block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
+                block_bucket_size
+            )
         indices: List[Any]
         indices = [None] * block_bucket_size
         for i, bid in enumerate(block_list):
@@ -131,6 +134,10 @@ def prepare_for_decode(
         block_usage = gather_list(block_usage, indices, 1)
     else:
         block_bucket_size = len(block_list)
+        if bucketing_ctx is not None:
+            block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
+                block_bucket_size
+            )
         block_list = pad_list(block_list, block_bucket_size, 0)
         block_groups = pad_list(block_groups, block_bucket_size, -1)
         block_usage = pad_list(block_usage, block_bucket_size, 1)
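Note: the bucketing_ctx lookups added above pad the active block count up to a precomputed bucket so that decode runs with a small, fixed set of tensor shapes (which matters for HPU graph reuse). A minimal sketch of the idea in plain Python, using a hypothetical get_padded_decode_num_blocks that simply rounds up to a multiple of a fixed step; the real rule lives in vllm_hpu_extension's HPUBucketingContext:

import math

def get_padded_decode_num_blocks(n: int, step: int = 128) -> int:
    # Hypothetical bucketing rule: round the active block count up to the
    # next multiple of `step` so every decode step reuses one of a few shapes.
    return max(step, math.ceil(n / step) * step)

block_list = [3, 7, 12]  # blocks actually in use by the batch
block_bucket_size = max(max(block_list) + 1, len(block_list))        # 13
block_bucket_size = get_padded_decode_num_blocks(block_bucket_size)  # 128
padded_block_list = block_list + [0] * (block_bucket_size - len(block_list))
print(block_bucket_size, len(padded_block_list))  # 128 128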
@@ -835,15 +842,16 @@ class FlashCausalLMBatch(Batch):
             )

             if not prefilling:
-                input_ids[start_index:end_index] = batch.input_ids
-                position_ids[start_index:end_index] = batch.position_ids
-                slot_indices[start_index:end_index] = (
-                    batch.slot_indices + cumulative_slots
+                index = torch.tensor(
+                    list(range(start_index, end_index)), device=batch.input_ids.device
                 )
-                input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
-                cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
-
-                # Copy over adapter indices
+                input_ids.index_copy_(0, index, batch.input_ids)
+                position_ids.index_copy_(0, index, batch.position_ids)
+                slot_indices.index_copy_(
+                    0, index, batch.slot_indices + cumulative_slots
+                )
+                input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor)
+                cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor)
                 adapter_start_index = cumulative_adapter_indices_size
                 adapter_end_index = (
                     cumulative_adapter_indices_size
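Note: the concatenate path above now builds an explicit index tensor and copies with index_copy_ instead of assigning into a Python slice, which keeps the update a single indexed-copy op when the graph is captured lazily. A small self-contained illustration of the equivalence (plain PyTorch, stand-in tensors):

import torch

dst = torch.zeros(8, dtype=torch.int64)
src = torch.tensor([5, 6, 7], dtype=torch.int64)
start_index, end_index = 2, 5

# Slice assignment (previous behaviour)
a = dst.clone()
a[start_index:end_index] = src

# index_copy_ with an explicit index tensor (new behaviour)
index = torch.tensor(list(range(start_index, end_index)), device=src.device)
b = dst.clone()
b.index_copy_(0, index, src)

assert torch.equal(a, b)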
@@ -951,22 +959,34 @@ class FlashCausalLMBatch(Batch):
             hpu_attn_meta=None,
         )

-    def prepare_for_decode(self, dtype, use_contiguous_pa):
+    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
         block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
         block_tables = []
         for i, bt in enumerate(self.block_tables):
             block_tables.append(bt[0 : block_num[i]])
+        if bucketing_ctx is not None:
+            padded_bs = bucketing_ctx.get_padded_decode_batch_size(
+                self.input_ids.shape[0]
+            )
+        else:
+            padded_bs = self.input_ids.shape[0]
+        slots = self.slots[self.slot_indices]
+        extra_pad = padded_bs - self.input_ids.shape[0]
+        if extra_pad != 0:
+            slots = F.pad(slots, (0, extra_pad), value=0)
+            block_tables.extend([[0]] * extra_pad)

         self.hpu_attn_meta = prepare_for_decode(
             dtype,
             use_contiguous_pa,
             self.block_tables_tensor.device,
-            self.slots[self.slot_indices],
+            slots,
             block_tables,
-            self.input_ids.size(0),
+            padded_bs,
+            bucketing_ctx,
         )

-    def prepare_for_prefill(self):
+    def prepare_for_prefill(self, max_padded_input_len):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
         # it simplifies everything
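Note: prepare_for_decode now pads slots and block_tables out to the bucketed batch size before building the HPU attention metadata, so a live batch of any size maps onto a warmed-up shape. A rough sketch of that padding step, assuming a hypothetical get_padded_decode_batch_size that rounds up to the next power of two (the real buckets come from HPUBucketingContext):

import torch
import torch.nn.functional as F

def get_padded_decode_batch_size(bs: int) -> int:
    # Hypothetical bucketing rule: next power of two, so warmed-up graph
    # shapes can be reused for any smaller live batch.
    return 1 << (bs - 1).bit_length()

slots = torch.tensor([11, 42, 97])       # one slot per live request
block_tables = [[0, 1], [2], [3, 4, 5]]  # per-request block lists

padded_bs = get_padded_decode_batch_size(slots.shape[0])  # 4
extra_pad = padded_bs - slots.shape[0]
if extra_pad != 0:
    slots = F.pad(slots, (0, extra_pad), value=0)  # dummy slot 0
    block_tables.extend([[0]] * extra_pad)         # dummy single-block table

print(padded_bs, slots.tolist(), block_tables)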
@@ -980,7 +1000,7 @@ class FlashCausalLMBatch(Batch):
         # the right logit position
         input_ids_padded_length = []
         # need extra pad to match warmup seq
-        extra_pad = 0
+        extra_pad = max_padded_input_len - self.max_input_length
         if isinstance(self.input_ids, list) and len(self) > 1:
             input_ids_padded_length = []
             input_ids = []
@@ -1355,9 +1375,9 @@ class FlashCausalLM(Model):
         self.cuda_graphs = {}
         self.kv_cache = []
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
-
+        self.bucketing_ctx = None
         if htorch.utils.internal.is_lazy():
-            htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False)
+            htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
         environment.set_model_config(self.config)
         self.use_contiguous_pa = (
             os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
@@ -1479,9 +1499,31 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+        self.bucketing_ctx = HPUBucketingContext(
+            os.getenv("DECODE_MAX_BS", 128),  # self.max_num_seqs, #TODO
+            os.getenv("PREFILL_MAX_BS", 16),  # self.max_num_prefill_seqs, #TODO
+            BLOCK_SIZE,
+            num_blocks * BLOCK_SIZE,
+        )
+        self.bucketing_ctx.num_hpu_blocks = num_blocks
+        warmup_times = 3
+        self.bucketing_ctx.generate_prompt_buckets()
+        for i, (batch_size, seq_len) in enumerate(
+            reversed(self.bucketing_ctx.prompt_buckets)
+        ):
+            for index in range(warmup_times):
+                self.warmup_prefill(seq_len, batch_size)
+        self.bucketing_ctx.generate_decode_buckets(num_blocks)
+        for i, (batch_size, block_num) in enumerate(
+            reversed(self.bucketing_ctx.decode_buckets)
+        ):
+            for index in range(warmup_times):
+                self.warmup_decode(batch_size, block_num)
+        synchronize(self.device)
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

     def warmup_prefill(self, prompt_len: int, bs: int):
+        logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
         input_ids = torch.zeros(
             prompt_len, dtype=torch.int64, device=self.device
         ).repeat(bs)
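Note: the warmup pass above walks every (batch_size, seq_len) prompt bucket and every (batch_size, block_num) decode bucket a few times so lazy-mode HPU graphs get compiled before serving. The loop structure in isolation, with made-up bucket lists and stand-in warmup functions (in the diff these come from HPUBucketingContext and the methods below):

WARMUP_TIMES = 3

# Stand-in bucket lists; in the diff they are produced by
# bucketing_ctx.prompt_buckets and bucketing_ctx.decode_buckets.
prompt_buckets = [(1, 128), (1, 256), (4, 128), (8, 512)]
decode_buckets = [(1, 16), (4, 64), (8, 128)]

def warmup_prefill(seq_len: int, batch_size: int) -> None:
    print(f"warmup prefill seq {seq_len} bs {batch_size}")

def warmup_decode(batch_size: int, block_num: int) -> None:
    print(f"warmup decode bs {batch_size} block_num {block_num}")

# Largest shapes first (reversed), each repeated a few times so graph
# compilation and allocator growth happen before real traffic arrives.
for batch_size, seq_len in reversed(prompt_buckets):
    for _ in range(WARMUP_TIMES):
        warmup_prefill(seq_len, batch_size)
for batch_size, block_num in reversed(decode_buckets):
    for _ in range(WARMUP_TIMES):
        warmup_decode(batch_size, block_num)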
@@ -1527,25 +1569,32 @@ class FlashCausalLM(Model):
             hpu_attention_meta=None,
         )

-    def warmup_decode(self, bs: int, block_num: int):
-        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-        position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
-        block_tables = torch.arange(
-            start=1, end=block_num + 1, dtype=torch.int32, device=self.device
-        ).reshape(bs, -1)
+    def warmup_decode(self, batch_size: int, block_num: int):
+        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
+        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+        blocks = [block_num // batch_size for _ in range(batch_size)]
+        blocks[0] += block_num % batch_size
+        past_len = []
+        block_tables = []
         slots = []
-        past_len = (
-            len(block_tables[0]) * BLOCK_SIZE - 1
-        )  # for decode, we only need to pass the past token
+        start_idx = 0
+
         # fetch the last blocked to warmup block num
-        for i in range(bs):
-            slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
+        for i in range(batch_size):
+            block_array = list(range(start_idx, start_idx + blocks[i]))
+            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+            block_tables.append(block_array)
+            past_len.append(blocks[i] * BLOCK_SIZE - 1)
+            start_idx += blocks[i]
         slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
-        cache_lengths_tensor = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
+        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.tensor(
+            past_len, dtype=torch.int32, device=self.device
         )
-        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

         seqlen = Seqlen(
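Note: warmup_decode now spreads block_num blocks across the fake batch instead of giving every request an identical rectangular table, and derives each request's slot and past length from its last block. The arithmetic in isolation (the BLOCK_SIZE value here is only illustrative):

BLOCK_SIZE = 128  # illustrative; the server defines its own constant
batch_size, block_num = 3, 8

blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size  # remainder goes to request 0 -> [4, 2, 2]

block_tables, slots, past_len = [], [], []
start_idx = 0
for i in range(batch_size):
    block_array = list(range(start_idx, start_idx + blocks[i]))
    block_tables.append(block_array)
    # last slot of the request's last block
    slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
    # for decode only the already-cached tokens matter
    past_len.append(blocks[i] * BLOCK_SIZE - 1)
    start_idx += blocks[i]

print(block_tables)  # [[0, 1, 2, 3], [4, 5], [6, 7]]
print(slots)         # [511, 767, 1023]
print(past_len)      # [511, 255, 255]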
@@ -1553,20 +1602,16 @@ class FlashCausalLM(Model):
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
         )
-        block_num = cache_lengths_tensor // BLOCK_SIZE + 1
-        block_tables_valid = []
-        for i, bt in enumerate(block_tables.tolist()):
-            block_tables_valid.append(bt[0 : block_num[i]])
-
         hpu_attention_meta = prepare_for_decode(
             self.dtype,
             self.use_contiguous_pa,
             self.device,
             slots,
-            block_tables_valid,
-            bs,
+            block_tables,
+            batch_size,
+            bucketing_ctx=None,
         )

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -1651,19 +1696,69 @@ class FlashCausalLM(Model):
             # in a circular buffer mode.
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-        )
-        kwargs = {}
-        if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
+        if self.bucketing_ctx is not None:
+            if batch.prefilling:
+                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
+                    input_lengths.shape[0]
+                )
+            else:
+                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    input_lengths.shape[0]
+                )
+        else:
+            padded_bs = input_lengths.shape[0]
+        orig_bs = input_lengths.shape[0]
+        if padded_bs != input_lengths.shape[0]:
+            padded_input_lengths = F.pad(
+                input_lengths,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            padded_cache_lengths_tensor = F.pad(
+                cache_lengths_tensor,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            if cu_seqlen_prefill is not None:
+                cu_seqlen_prefill = torch.zeros(
+                    padded_bs + 1, device=self.device, dtype=torch.int32
+                )
+                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
+            seqlen = Seqlen(
+                input_lengths=padded_input_lengths,
+                cache_lengths=padded_cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+            input_seq = input_ids.view(orig_bs, -1)
+            input_ids = F.pad(
+                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+            )
+            position_ids = F.pad(
+                position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
+            )
+            slots = F.pad(
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+            )
+            if lm_head_indices is not None:
+                lm_head_indices = F.pad(
+                    lm_head_indices, (0, padded_bs - orig_bs), value=0
+                )
+        else:
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+
+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = False
+
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -1677,7 +1772,9 @@ class FlashCausalLM(Model):
             hpu_attention_meta=batch.hpu_attn_meta,
             **kwargs,
         )
-        return logits, speculative_logits
+        return logits[:orig_bs], (
+            speculative_logits[:orig_bs] if speculative_logits is not None else None
+        )

     @tracer.start_as_current_span("generate_token")
     def generate_token(
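Note: because the inputs are padded up to the bucketed batch size before model.forward, the returned logits contain rows for the padding requests; only the first orig_bs rows are kept. A minimal stand-alone version of the pad-then-slice pattern, using a toy embedding plus linear head in place of the real model:

import torch
import torch.nn.functional as F

orig_bs, padded_bs, hidden = 3, 4, 8
input_ids = torch.randint(0, 100, (orig_bs,))

# Pad the flat input_ids with dummy token 0 up to the bucketed batch size.
input_seq = input_ids.view(orig_bs, -1)
input_ids = F.pad(input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0)

embed = torch.nn.Embedding(100, hidden)
lm_head = torch.nn.Linear(hidden, 100)
logits = lm_head(embed(input_ids))  # shape: (padded_bs, vocab)

# Drop the rows produced by padding before post-processing.
logits = logits[:orig_bs]
print(logits.shape)  # torch.Size([3, 100])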
@@ -1690,9 +1787,16 @@ class FlashCausalLM(Model):
         start = time.time_ns()
         prefill = batch.prefilling
         if prefill:
-            batch.prepare_for_prefill()
+            if self.bucketing_ctx is not None:
+                batch.prepare_for_prefill(
+                    self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length)
+                )
+            else:
+                batch.prepare_for_prefill(batch.max_input_length)
         else:
-            batch.prepare_for_decode(self.dtype, self.use_contiguous_pa)
+            batch.prepare_for_decode(
+                self.dtype, self.use_contiguous_pa, self.bucketing_ctx
+            )
         prefill_logprobs = batch.prefill_next_token_indices is not None
         # Update adapter indices for speculative tokens (if present)
         adapter_meta = batch.adapter_meta
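Note: generate_token now asks the bucketing context for a padded prompt length so prefill always runs at a warmed-up sequence length, and prepare_for_prefill turns the difference into extra_pad tokens. Sketch of the idea with a hypothetical get_padded_prompt_seq_len; the real bucket table is built by HPUBucketingContext:

def get_padded_prompt_seq_len(n: int, buckets=(128, 256, 512, 1024, 2048)) -> int:
    # Hypothetical rule: smallest warmed-up prompt length that fits the input.
    for b in buckets:
        if n <= b:
            return b
    return buckets[-1]

max_input_length = 300
max_padded_input_len = get_padded_prompt_seq_len(max_input_length)  # 512
extra_pad = max_padded_input_len - max_input_length                 # 212 pad tokens
print(max_padded_input_len, extra_pad)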