Small cleanup.

Using a single `os.getenv` statement instead of multiple.
Should make truthy values easier to catch.
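
As a rough sketch of the difference (the variable name `old_enabled` is just for illustration): the previous per-call check only matched the literal string `"True"`, while the single centralized parse in `globals.py` also accepts other common truthy spellings.

```python
import os

# Old pattern, repeated at each call site: only the literal "True" enables the feature.
old_enabled = os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True"

# New pattern, evaluated once in globals.py: "1", "true", "True", "TRUE" all enable it.
ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
```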

In the end we didn't move towards a full CLI option, because modifying globals in
Python is error prone (it depends on code import order).
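
A minimal, contrived single-file sketch of that pitfall (the module name `globals_mod` is hypothetical): a consumer that imports the flag by value before a CLI mutates the module attribute never sees the update.

```python
import sys
import types

# Simulate a "globals" module whose flag a CLI would try to overwrite later.
globals_mod = types.ModuleType("globals_mod")
globals_mod.ENABLE_CUDA_GRAPHS = False
sys.modules["globals_mod"] = globals_mod

# A consumer imports the flag by value, as model code typically does.
from globals_mod import ENABLE_CUDA_GRAPHS

# A CLI flips the module attribute afterwards...
sys.modules["globals_mod"].ENABLE_CUDA_GRAPHS = True

# ...but the consumer's name was bound at import time and still reads False.
print(ENABLE_CUDA_GRAPHS)  # False
```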

Added an error when Mamba is launched with TP (tensor parallelism).
Nicolas Patry 2024-02-14 09:22:38 +00:00
parent d6b0fb9e25
commit 524e06066b
3 changed files with 10 additions and 5 deletions


@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
BLOCK_SIZE,
)
from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
self.device,
)
-if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+if ENABLE_CUDA_GRAPHS:
try:
logger.info("Experimental support for Cuda Graphs is enabled")
# Warmup cuda graphs


@@ -1,3 +1,6 @@
import torch
+import os
MEM_POOL = torch.cuda.graph_pool_handle()
+# This is overridden by the cli
+ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}


@@ -13,7 +13,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
from text_generation_server.models import Model
@@ -377,7 +377,9 @@ class Mamba(Model):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
-self.process_group, _rank, _world_size = initialize_torch_distributed()
+self.process_group, _rank, world_size = initialize_torch_distributed()
+if world_size > 1:
+    raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
self.cuda_graphs = {}
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -427,7 +429,7 @@
def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed
-if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+if ENABLE_CUDA_GRAPHS:
if self.speculate is None or self.speculate == 0:
try:
logger.info("Experimental support for Cuda Graphs is enabled")