Flash causal LM case

This commit is contained in:
regisss 2025-06-16 21:07:44 +00:00
parent 2ba396c4c1
commit 564c9e1cc0

View File

@ -1702,6 +1702,7 @@ class FlashCausalLM(Model):
f"{dim}:{seq_len} " f"{dim}:{seq_len} "
f"bypass:{bypass} " f"bypass:{bypass} "
f"free_mem:{free_mem}" f"free_mem:{free_mem}"
", this may take a while..."
) )
log_master(logger.info, msg) log_master(logger.info, msg)
@ -1753,6 +1754,10 @@ class FlashCausalLM(Model):
total_batch_seq = 0.001 total_batch_seq = 0.001
total_mem = 0 total_mem = 0
available_mem = prompt_available_memory available_mem = prompt_available_memory
logger.info(
f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
)
for i, (batch_size, seq_len) in enumerate(buckets): for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens: if batch_size * seq_len > self.max_batch_prefill_tokens:
continue continue
@ -1779,6 +1784,8 @@ class FlashCausalLM(Model):
total_mem += used_mem total_mem += used_mem
total_batch_seq += batch_seq total_batch_seq += batch_seq
logger.info("Prefill warmup successful.\n")
def ordering_function_max_bs(b): def ordering_function_max_bs(b):
return (-b[0], b[1]) return (-b[0], b[1])
@ -1790,6 +1797,7 @@ class FlashCausalLM(Model):
total_batch_seq = 0.001 total_batch_seq = 0.001
total_mem = 0 total_mem = 0
available_mem = free_mem - self.mem_reserved available_mem = free_mem - self.mem_reserved
logger.info(f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n")
for i, (batch_size, block_num) in enumerate(buckets): for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num: if batch_size > block_num:
continue continue
@ -1814,6 +1822,8 @@ class FlashCausalLM(Model):
total_mem += used_mem total_mem += used_mem
total_batch_seq += batch_seq total_batch_seq += batch_seq
logger.info("Decode warmup successful.\n")
log_master( log_master(
logger.info, logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",