From 686b56a0c009b4fbd09a4254eaafe003856f88d1 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 14 Feb 2024 15:30:07 +0100
Subject: [PATCH] Small cleanup. (#1560)

Use a single `os.getenv` statement instead of multiple ones; this should
make truthy values easier to catch.

In the end this didn't move towards a full CLI option because modifying
globals in Python is error-prone (it depends on code import order).

Added an error when Mamba is launched with TP.

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed.
Feel free to tag members/contributors who may be interested in your PR.
---
 server/text_generation_server/models/flash_causal_lm.py | 4 ++--
 server/text_generation_server/models/globals.py         | 3 +++
 server/text_generation_server/models/mamba.py           | 8 +++++---
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e04a97198..886fe4869 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index b0dca3769..3b8a70bca 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -1,3 +1,6 @@
 import torch
+import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
+# This is overridden by the cli
+ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 8f18e4752..868db6aaa 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
 from text_generation_server.models import Model
@@ -377,7 +377,9 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, _rank, _world_size = initialize_torch_distributed()
+        self.process_group, _rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
         self.cuda_graphs = {}
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -427,7 +429,7 @@ class Mamba(Model):
 
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
                     logger.info("Experimental support for Cuda Graphs is enabled")
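
Below is a small standalone sketch, not part of the patch, contrasting the old and new environment-variable checks; the helper names `old_check` and `new_check` are illustrative only. The patch itself simply moves the new expression into `server/text_generation_server/models/globals.py` as the module-level constant `ENABLE_CUDA_GRAPHS` and imports it at the call sites.

```python
import os


def old_check() -> bool:
    # Pre-patch behaviour: only the exact string "True" enabled CUDA graphs.
    return os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True"


def new_check() -> bool:
    # Post-patch behaviour: common truthy spellings ("1", "true", "TRUE", ...) are accepted.
    return os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}


if __name__ == "__main__":
    for value in ("True", "true", "TRUE", "1", "0", ""):
        os.environ["ENABLE_CUDA_GRAPHS"] = value
        print(f"ENABLE_CUDA_GRAPHS={value!r:7} old={old_check()!s:5} new={new_check()}")
```

Because the flag is parsed once at import time in `globals.py`, `flash_causal_lm.py` and `mamba.py` now read the same pre-parsed boolean instead of each re-reading and re-comparing the raw environment string.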