diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 9883a73f..8de73aea 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1702,6 +1702,7 @@ class FlashCausalLM(Model):
             f"{dim}:{seq_len} "
             f"bypass:{bypass} "
             f"free_mem:{free_mem}"
+            ", this may take a while..."
         )
         log_master(logger.info, msg)
 
@@ -1753,6 +1754,10 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        logger.info(
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
@@ -1779,6 +1784,8 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        logger.info("Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
 
@@ -1790,6 +1797,7 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        logger.info(f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n")
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
@@ -1814,6 +1822,8 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        logger.info("Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
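
For reference, a minimal standalone sketch of the bucket-summary logging pattern this diff adds. It uses a hypothetical `buckets` list of `(batch_size, seq_len)` tuples and the standard `logging` module in place of TGI's `logger`/`log_master` helpers; the values and the `warmup-demo` logger name are made up for illustration, and only the list-comprehension summary and completion messages are shown, not the actual HPU warmup flow.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("warmup-demo")

# Hypothetical warmup buckets: (batch_size, seq_len) pairs, standing in for
# what the bucketing context produces in flash_causal_lm.py.
buckets = [(1, 128), (2, 256), (4, 512), (8, 1024)]

# Same pattern as the added prefill logging: summarize every bucket shape up
# front so the log shows what the "this may take a while..." loop will cover.
logger.info("Prefill batch size list:%s", [bsz[0] for bsz in buckets])
logger.info("Prefill sequence length list:%s", [seq[1] for seq in buckets])

for i, (batch_size, seq_len) in enumerate(buckets):
    # ... per-bucket warmup would run here ...
    pass

# Mirrors the completion message emitted after the warmup loop finishes.
logger.info("Prefill warmup successful.")
```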