From cd3c28cfe77a052eb403096cae0c8309ac49a923 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 17 May 2024 16:03:15 +0000
Subject: [PATCH] Fix warmup when speculative decoding is disabled

Run the ROCm TunableOp warmup only when speculative decoding is
disabled, and warm up CUDA graphs when speculate == 0 in addition to
when it is unset. Also add the missing trailing newline in the
monitoring documentation.
---
 docs/source/basic_tutorials/monitoring.md                | 2 +-
 server/text_generation_server/models/flash_causal_lm.py | 8 ++++++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/docs/source/basic_tutorials/monitoring.md b/docs/source/basic_tutorials/monitoring.md
index a24cf902..d6e50cfd 100644
--- a/docs/source/basic_tutorials/monitoring.md
+++ b/docs/source/basic_tutorials/monitoring.md
@@ -72,4 +72,4 @@ Once Prometheus data source is configured, we can finally create our dashboard!
 
 Community contributed dashboard templates are also available, for example [here](https://grafana.com/grafana/dashboards/19831-text-generation-inference-dashboard/) or [here](https://grafana.com/grafana/dashboards/20246-text-generation-inference/).
 
-Load your dashboard configuration, and your TGI dashboard should be ready to go!
\ No newline at end of file
+Load your dashboard configuration, and your TGI dashboard should be ready to go!
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 333efe33..e18b885f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -827,7 +827,7 @@ class FlashCausalLM(Model):
             self.device,
         )
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and (self.speculate is None or self.speculate == 0):
             if (
                 os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                 or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
             ):
@@ -875,7 +875,11 @@ class FlashCausalLM(Model):
                 logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
-                    if self.speculate is None or self.speculate + 1 <= bs:
+                    if (
+                        self.speculate is None
+                        or self.speculate == 0
+                        or self.speculate + 1 <= bs
+                    ):
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
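
A note on the TunableOp guard in the first flash_causal_lm.py hunk:
Python parses "a and b or c" as "(a and b) or c", so the parentheses
around the speculation check are what keep the branch ROCm-only. A
minimal sketch of the difference, using made-up values rather than
anything from text-generation-inference:

    # Hypothetical values: a CUDA machine with speculation disabled.
    SYSTEM = "cuda"
    speculate = 0

    # Unparenthesized: parses as (SYSTEM == "rocm" and speculate is None)
    # or (speculate == 0), which is True here despite SYSTEM != "rocm".
    buggy = SYSTEM == "rocm" and speculate is None or speculate == 0

    # Parenthesized, as in the patch: ROCm-only, speculation disabled.
    fixed = SYSTEM == "rocm" and (speculate is None or speculate == 0)

    print(buggy)  # True  -> TunableOp warmup would wrongly run on CUDA
    print(fixed)  # False -> branch correctly skipped on non-ROCm systems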
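
The second hunk decides which batch sizes get a CUDA graph captured
during warmup. A small standalone sketch of its effect; the
should_warmup helper and the CUDA_GRAPHS values below are illustrative,
not part of text-generation-inference:

    CUDA_GRAPHS = [1, 2, 4, 8]

    def should_warmup(speculate, bs):
        # Mirrors the patched condition: warm up when speculation is not
        # configured, is explicitly disabled, or leaves room for
        # speculate + 1 tokens in a batch of size bs.
        return speculate is None or speculate == 0 or speculate + 1 <= bs

    print([bs for bs in CUDA_GRAPHS if should_warmup(None, bs)])  # [1, 2, 4, 8]
    print([bs for bs in CUDA_GRAPHS if should_warmup(0, bs)])     # [1, 2, 4, 8]
    print([bs for bs in CUDA_GRAPHS if should_warmup(2, bs)])     # [4, 8]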