From 524e06066b4ead22f7ceb20e1f0323b425facd67 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 14 Feb 2024 09:22:38 +0000
Subject: [PATCH] Small cleanup.

Use a single `os.getenv` statement instead of multiple ones. This should
make truthy values easier to catch.

In the end I didn't move towards a full CLI option, because modifying
globals in Python is error-prone (it depends on code import order).

Added an error when Mamba is launched with TP.
---
 server/text_generation_server/models/flash_causal_lm.py | 4 ++--
 server/text_generation_server/models/globals.py         | 3 +++
 server/text_generation_server/models/mamba.py           | 8 +++++---
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e04a9719..886fe486 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index b0dca376..3b8a70bc 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -1,3 +1,6 @@
 import torch
+import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
+# This is overridden by the CLI
+ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 8f18e475..868db6aa 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
 from text_generation_server.models import Model
@@ -377,7 +377,9 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, _rank, _world_size = initialize_torch_distributed()
+        self.process_group, _rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
         self.cuda_graphs = {}
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -427,7 +429,7 @@ class Mamba(Model):
 
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
                     logger.info("Experimental support for Cuda Graphs is enabled")
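
Reviewer note: the snippet below is a minimal standalone sketch, not part of
the patch, illustrating why the new parse in globals.py catches more truthy
spellings than the old exact-match check. The helper names old_parse and
new_parse are hypothetical; only the set-membership expression mirrors the
patched code.

    def new_parse(value: str) -> bool:
        # Mirrors the line added to globals.py: case-insensitive,
        # accepts "1" or "true".
        return value.lower() in {"1", "true"}

    def old_parse(value: str) -> bool:
        # The previous check only matched the exact string "True".
        return value == "True"

    for value in ["True", "true", "TRUE", "1", "yes", "0"]:
        print(f"{value!r}: old={old_parse(value)} new={new_parse(value)}")

With the old comparison, ENABLE_CUDA_GRAPHS=true or ENABLE_CUDA_GRAPHS=1
silently left the feature disabled; the new check accepts both spellings,
while values such as "yes" remain rejected.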
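
A second sketch, also not part of the patch, shows the effect of the new
fail-fast guard in Mamba.__init__. The check_tp_support wrapper is
hypothetical; in the patch, world_size comes from
initialize_torch_distributed().

    def check_tp_support(world_size: int) -> None:
        # Mirrors the guard added in Mamba.__init__: refuse sharded launches.
        if world_size > 1:
            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")

    check_tp_support(1)      # single-shard launch: no-op
    try:
        check_tp_support(2)  # sharded launch: fails fast at model init
    except RuntimeError as e:
        print(f"caught: {e}")

Raising in __init__ surfaces the limitation immediately, instead of letting a
multi-shard Mamba deployment start with unsupported parallelism.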