Small cleanup.

Using a single `os.getenv` call instead of several scattered ones.
This should make truthy values (e.g. "1", "true") easier to catch.
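
For context, a rough before/after sketch of the check, with the expressions taken from the diff below; previously only the exact string "True" enabled the feature:

    import os

    # Old check: only the literal string "True" turned the feature on.
    enabled_old = os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True"

    # New check (now in models/globals.py): common truthy spellings are accepted.
    enabled_new = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}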

In the end this didn't move to a full CLI flag, because modifying globals in
Python is error-prone (the result depends on import order).
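
A hypothetical sketch of that import-order pitfall; the module and flag names are invented for illustration and are not from this repo:

    # flags.py (hypothetical): a module-level global a CLI might try to mutate.
    VERBOSE = False

    # consumer.py (hypothetical): binds the value at import time.
    from flags import VERBOSE

    def run():
        if VERBOSE:  # still the False captured when consumer was imported
            print("verbose output")

    # cli.py (hypothetical): mutates the global *after* consumer imported it.
    import flags
    import consumer

    flags.VERBOSE = True   # rebinds flags.VERBOSE only
    consumer.run()         # prints nothing; consumer kept its own binding

Reading the environment variable once when the module is imported sidesteps that ordering problem.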

Added an error when Mamba is launched with TP (tensor parallelism).
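
A self-contained sketch of that kind of guard, assuming plain `torch.distributed` instead of the repo's `initialize_torch_distributed` helper; the function name below is made up for the example:

    import torch.distributed as dist

    def assert_single_shard(model_name: str = "Mamba") -> None:
        # Hypothetical helper: fail fast if the server was sharded across ranks.
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        if world_size > 1:
            raise RuntimeError(f"{model_name} does not support Tensor Parallelism (TP)")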
Nicolas Patry 2024-02-14 09:22:38 +00:00
parent d6b0fb9e25
commit 524e06066b
3 changed files with 10 additions and 5 deletions

@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
             self.device,
         )
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs

@@ -1,3 +1,6 @@
 import torch
+import os

 MEM_POOL = torch.cuda.graph_pool_handle()
+# This is overridden by the cli
+ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}

@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
 from text_generation_server.models import Model
@@ -377,7 +377,9 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, _rank, _world_size = initialize_torch_distributed()
+        self.process_group, _rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
         self.cuda_graphs = {}
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -427,7 +429,7 @@ class Mamba(Model):
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
                     logger.info("Experimental support for Cuda Graphs is enabled")