From d67633a0c883af855adc248d7954642d15b20290 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Apr 2024 13:01:27 +0000 Subject: [PATCH] Fix disabling. --- server/text_generation_server/models/globals.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 06394e3f..6f554049 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -3,11 +3,12 @@ import os MEM_POOL = torch.cuda.graph_pool_handle() # This is overridden by the cli -cuda_graphs = os.getenv("CUDA_GRAPHS", "1,2,4,8,16,32,64,96,128") -try: - cuda_graphs = [int(item) for item in cuda_graphs.split(",")] -except Exception as e: - raise RuntimeError( - f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" - ) +cuda_graphs = os.getenv("CUDA_GRAPHS") +if cuda_graphs is not None: + try: + cuda_graphs = [int(item) for item in cuda_graphs.split(",")] + except Exception as e: + raise RuntimeError( + f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" + ) CUDA_GRAPHS = cuda_graphs