From 9dbaa176fd5e70b3aab6b68b50b56aead8cd4621 Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Tue, 17 Jun 2025 21:13:13 +0000
Subject: [PATCH] Add log_master & VLM cases

---
 .../text_generation_server/models/flash_causal_lm.py  | 9 +++++----
 .../models/flash_vlm_causal_lm.py                      | 3 +++
 .../text_generation_server/models/mllama_causal_lm.py  | 10 ++++++++++
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 8de73aea..e380ed53 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1754,10 +1754,11 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
-        logger.info(
+        msg = (
             f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
             f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
         )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
@@ -1784,7 +1785,7 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
-        logger.info("Prefill warmup successful.\n")
+        log_master(logger.info, "Prefill warmup successful.\n")
 
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
@@ -1797,7 +1798,7 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
-        logger.info(f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n")
+        log_master(logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n")
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
@@ -1822,7 +1823,7 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
-        logger.info("Decode warmup successful.\n")
+        log_master(logger.info, "Decode warmup successful.\n")
 
         log_master(
             logger.info,
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 5bd2292e..fe28d067 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -822,6 +822,7 @@ class FlashVlmCausalLM(FlashCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = decode_available_memory
+        log_master(logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n")
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
@@ -847,6 +848,8 @@ class FlashVlmCausalLM(FlashCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 1be36d09..ec9e149b 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -398,6 +398,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        msg = (
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
@@ -424,6 +429,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
 
@@ -435,6 +442,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        log_master(logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n")
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
@@ -459,6 +467,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",