Fix the StarCoder warmup issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2024-11-26 08:55:42 +00:00
parent b83419a769
commit 4586325a34


@@ -61,28 +61,13 @@ LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
 BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
 PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
-PREFILL_WARMUP_BATCH_SIZE_LIST = []
-PREFILL_WARMUP_SEQLEN_LIST = []
-DECODE_WARMUP_BATCH_SIZE_LIST = []


 def torch_compile_for_eager(func):
     if LAZY_MODE == 1:
         return func
     return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})


-def round_up(warmup_list:list, num) :
-    i = 0
-    if len(warmup_list) == 0:
-        return num
-    for i in warmup_list:
-        if num <= i :
-            break
-    return i
+def round_up(number, k):
+    return (number + k - 1) // k * k


 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
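
For reference, the two round_up variants swapped in the hunk above behave quite differently: the removed list-based lookup returns the first pre-recorded warmup bucket that can hold the value (and silently falls back to the last, smaller entry when nothing fits), while the new arithmetic version rounds any value up to the next multiple of k. A minimal sketch, with made-up sample values:

# Sketch only: the sample buckets and values below are illustrative, not from the commit.
def round_up_old(warmup_list: list, num):
    # Removed variant: look up the first pre-recorded warmup bucket >= num.
    i = 0
    if len(warmup_list) == 0:
        return num
    for i in warmup_list:
        if num <= i:
            break
    return i

def round_up_new(number, k):
    # New variant: round up to the nearest multiple of k.
    return (number + k - 1) // k * k

print(round_up_old([8, 16, 32], 11))  # 16, but only if the list was populated during warmup
print(round_up_old([], 11))           # 11, an empty list falls back to the raw value
print(round_up_new(11, 8))            # 16, independent of any warmup bookkeeping
print(round_up_new(11, 4))            # 12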
@@ -372,14 +357,13 @@ class CausalLMBatch(Batch):
         self.set_tensor_groups(dst_tensors)

     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "CausalLMBatch":
+    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
         if not all(b.past_key_values is not None for b in batches):
             raise ValueError("KV cache not allocated! Cannot recombine before prefill!")

         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        if is_warmup is False :
-            new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
+        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)

         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
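
The practical effect of the recombine() change above is that a merged decode batch is now padded up to the next BATCH_BUCKET_SIZE multiple instead of being looked up in DECODE_WARMUP_BATCH_SIZE_LIST. A rough illustration, assuming the default BATCH_BUCKET_SIZE of 8 and made-up request counts:

BATCH_BUCKET_SIZE = 8  # default from the environment variable in the first hunk

def round_up(number, k):
    return (number + k - 1) // k * k

# e.g. concatenating two in-flight batches holding 3 and 6 live requests
total_requests = 3 + 6
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
print(total_requests, new_bs)  # 9 16 -> the remaining 7 slots are padding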
@@ -481,7 +465,6 @@ class CausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_warmup: bool = False,
     ) -> "CausalLMBatch":
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -503,7 +486,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
+        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
@@ -533,7 +516,7 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
+            rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
@@ -593,8 +576,8 @@ class CausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch":
-        return cls.recombine(batches, pad_token_id, is_warmup)
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id)

     def __len__(self):
         return len(self.requests)
@@ -895,7 +878,7 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batches: List[CausalLMBatch], is_warmup: bool = False
+        self, batches: List[CausalLMBatch]
     ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
         # Results
@@ -979,7 +962,7 @@ class CausalLM(Model):
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
+            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
         else:
             batch = batches[0]
@@ -987,12 +970,12 @@ class CausalLM(Model):
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)

         scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
+        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
             self.model.clear_cache()
-            self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
+            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
@@ -1194,100 +1177,93 @@ class CausalLM(Model):
         for i in range(len(batch.requests) - batch_size):
             batch.requests.pop()
-        return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True)
+        return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)

     def warmup(self, request) -> None:
-        is_warmup = True
         MAX_TOTAL_TOKENS = request.max_total_tokens
         MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
-        batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup)
+        batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
         try:
             # max prefill batch size warmup
-            _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+            _, prefill_batch, _ = self.generate_token([batch])
         except:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
+        del prefill_batch

         #warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
+        del batch
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
+        decode_batch_size_list.append(max_decode_batch_size)
+        decode_batch_size_list.sort(reverse=True)
         self.limit_hpu_graph = True
         try:
-            for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
+            for batch_size in decode_batch_size_list:
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
-                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                 for i in range(iters):
                     batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch])
                     batches.append(prefill_batch)
                 if batch_size % max_prefill_batch_size != 0:
                     batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                    batches.append(batch)
+                    _, prefill_batch, _ = self.generate_token([batch])
+                    batches.append(prefill_batch)

-                _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                _, decode_batch, _ = self.generate_token(batches)
+                del decode_batch
+                batches.clear()
         except:
-            DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
-            self.model.clear_cache()
-            if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0:
-                logger.warning(
-                    f"Not enough memory to warmup all batch size of decode."
-                    f"You need to decrease `--max-batch-total-tokens`"
-                )
-            else:
-                raise RuntimeError(
-                    f"Not enough memory to warmup decode batch_size({max_decode_batch_size})."
-                    f"You need to decrease `--max-batch-total-tokens`"
-                )
+            raise RuntimeError(
+                f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
+                f"You need to decrease `--max-batch-total-tokens`"
+            )

-        DECODE_WARMUP_BATCH_SIZE_LIST.sort()
+        decode_batch_size_list.sort()
+        MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
             f"\nFollowing decode warmup successfully.\n"
-            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Decode batch size list:{decode_batch_size_list}\n"
             f"Memory stats: {mem_stats} "
         )

         # Warmup prefill batch_size
         max_input_length = request.max_input_length
-        max_prefill_batch_size = batch.input_ids.shape[0]
-        seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
-        i = 0
-        while seq_len <= max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
-            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
-            i += 1
-        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
+        prefill_batch_size_list = []
+        prefill_seqlen_list = []

         #Prefill and decode warmup
         try:
-            for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE):
-                PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-                for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
-                    batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+            for batch_size in range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE):
+                prefill_batch_size_list.append(batch_size)
+                for seq_len in range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF):
+                    prefill_seqlen_list.append(seq_len)
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch])
+                    del batch
+                    del prefill_batch
         except:
             raise RuntimeError(
                 f"Not enough memory to run following prefill batch_size."
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
+                f"Prefill batch size list:{prefill_batch_size_list}"
+                f"Prefill sequence length list:{prefill_seqlen_list}"
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )

+        prefill_batch_size_list.sort()
+        prefill_seqlen_list.sort()
         limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
         if limit_hpu_graph == False:
             mem_stats = get_hpu_memory_stats(self.device)
             logger.info(
                 f"\nFollowing prefill and decode warmup successfully.\n"
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
+                f"Prefill batch size list:{prefill_batch_size_list}\n"
+                f"Prefill sequence length list:{prefill_seqlen_list}\n"
                 f"Memory stats: {mem_stats} "
             )
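
As a reading aid, the warmup schedule the new code walks through can be approximated standalone. A minimal sketch, assuming the default bucket sizes from the first hunk and made-up request limits (none of the sample numbers come from the commit); the diff builds its lists by appending inside the warmup loops, which is collapsed into comprehensions here:

import math

BATCH_BUCKET_SIZE = 8              # defaults shown in the first hunk
PREFILL_BATCH_BUCKET_SIZE = 2
PAD_SEQUENCE_TO_MULTIPLE_OF = 128  # assumed default, not shown in the diff

MAX_TOTAL_TOKENS = 2048            # request.max_total_tokens (assumed)
MAX_BATCH_TOTAL_TOKENS = 65536     # request.max_batch_total_tokens (assumed)
max_input_length = 1024            # request.max_input_length (assumed)
max_prefill_batch_size = 4         # batch.input_ids.shape[0] (assumed)

def round_up(number, k):
    return (number + k - 1) // k * k

# Decode warmup: batch sizes from the largest bucket down to BATCH_BUCKET_SIZE.
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)            # [32, 24, 16, 8]

# MAX_BATCH_TOTAL_TOKENS is then re-derived from the largest warmed-up batch size.
print(MAX_TOTAL_TOKENS * max(decode_batch_size_list))  # 65536

# Prefill warmup: batch sizes and sequence lengths walked downwards in bucket steps.
prefill_batch_size_list = list(range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE))
prefill_seqlen_list = list(range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF))
print(prefill_batch_size_list)           # [4, 2]
print(prefill_seqlen_list)               # [1024, 896, 768, 640, 512, 384, 256, 128]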