diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 69b3fe67..efeda08d 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -319,7 +319,7 @@ def launcher(event_loop):

         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
-            "ENABLE_CUDA_GRAPHS": "True",
+            "ENABLE_CUDA_GRAPHS": "true",
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 21ed4f6c..e68a2100 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -777,7 +777,7 @@ class FlashCausalLM(Model):
             self.device,
         )

-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if os.getenv("ENABLE_CUDA_GRAPHS", "false") == "true":
             try:
                 # Warmup cuda graphs for all power of twos until 64
                 for i in range(6):
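
Note: the patched check is still an exact string comparison, only with lowercase spellings on
both sides, so a value such as "True" or "1" would still silently disable CUDA graphs. Below is
a minimal sketch of a case-insensitive boolean parser for flags like this; the parse_bool_env
helper is hypothetical and not part of this diff.

import os


def parse_bool_env(name: str, default: bool = False) -> bool:
    """Treat common truthy spellings ("1", "true", "yes", "on") as True, case-insensitively."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


# Example usage mirroring the check in flash_causal_lm.py above:
if parse_bool_env("ENABLE_CUDA_GRAPHS"):
    pass  # warm up CUDA graphs here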