From 564c9e1cc04139461dc4cbc387eec52059d2b793 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Mon, 16 Jun 2025 21:07:44 +0000 Subject: [PATCH] Flash causal LM case --- .../text_generation_server/models/flash_causal_lm.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 9883a73f..8de73aea 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1702,6 +1702,7 @@ class FlashCausalLM(Model): f"{dim}:{seq_len} " f"bypass:{bypass} " f"free_mem:{free_mem}" + ", this may take a while..." ) log_master(logger.info, msg) @@ -1753,6 +1754,10 @@ class FlashCausalLM(Model): total_batch_seq = 0.001 total_mem = 0 available_mem = prompt_available_memory + logger.info( + f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n" + f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n" + ) for i, (batch_size, seq_len) in enumerate(buckets): if batch_size * seq_len > self.max_batch_prefill_tokens: continue @@ -1779,6 +1784,8 @@ class FlashCausalLM(Model): total_mem += used_mem total_batch_seq += batch_seq + logger.info("Prefill warmup successful.\n") + def ordering_function_max_bs(b): return (-b[0], b[1]) @@ -1790,6 +1797,7 @@ class FlashCausalLM(Model): total_batch_seq = 0.001 total_mem = 0 available_mem = free_mem - self.mem_reserved + logger.info(f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n") for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue @@ -1814,6 +1822,8 @@ class FlashCausalLM(Model): total_mem += used_mem total_batch_seq += batch_seq + logger.info("Decode warmup successful.\n") + log_master( logger.info, f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",