Mirror of https://github.com/huggingface/text-generation-inference.git
Small cleanup.
Using a single `os.getenv` statement instead of multiple. This should make truthy values easier to catch. In the end I didn't move towards a full CLI flag because modifying globals in Python is error prone (it depends on code import order). Added an error when Mamba is launched with TP.
parent d6b0fb9e25
commit 524e06066b
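The "error prone (depends on code import order)" point above is the usual Python pitfall with module-level flags: a consumer that binds a name with `from module import FLAG` copies the value at import time, so a CLI that mutated the global afterwards would not be seen by that consumer. Below is a minimal, self-contained sketch of that behavior, using a hypothetical stand-in module rather than the real `text_generation_server.models.globals`:

```python
# Hypothetical sketch of why patching module-level globals from a CLI is fragile:
# "from module import NAME" copies the binding at import time.
import types

globals_mod = types.ModuleType("globals_mod")
globals_mod.ENABLE_CUDA_GRAPHS = False  # in the real code this comes from os.getenv

# A consumer doing `from globals_mod import ENABLE_CUDA_GRAPHS` effectively does this:
ENABLE_CUDA_GRAPHS = globals_mod.ENABLE_CUDA_GRAPHS

# A CLI entry point that later flips the module attribute...
globals_mod.ENABLE_CUDA_GRAPHS = True

# ...does not change the consumer's copy; the outcome depends on import/mutation order.
print(ENABLE_CUDA_GRAPHS)              # False
print(globals_mod.ENABLE_CUDA_GRAPHS)  # True
```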
@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs
@@ -1,3 +1,6 @@
 import torch
+import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
+# This is overridden by the cli
+ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
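The parsing added to globals.py above is what makes truthy values easier to catch: the checks it replaces only recognized the exact string "True". A small comparison of the old and new behavior (the helper names here are hypothetical, used only for illustration):

```python
from typing import Optional


def old_flag(value: Optional[str]) -> bool:
    # Previous per-file check: only the exact string "True" enabled the feature.
    return value == "True"


def new_flag(value: Optional[str]) -> bool:
    # Mirrors the parsing now centralized in globals.py: case-insensitive "1"/"true".
    return (value or "false").lower() in {"1", "true"}


for raw in ["True", "true", "TRUE", "1", "0", "false", None]:
    print(f"{raw!r:>8}  old={old_flag(raw)!s:<5}  new={new_flag(raw)!s:<5}")
```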
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
 from text_generation_server.models import Model
@@ -377,7 +377,9 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, _rank, _world_size = initialize_torch_distributed()
+        self.process_group, _rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
         self.cuda_graphs = {}
         if torch.cuda.is_available():
             device = torch.device("cuda")
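The new guard above fails fast instead of letting a sharded Mamba launch proceed into code paths that have no tensor-parallel implementation. A minimal sketch of the same check in isolation (the helper name is hypothetical):

```python
def check_no_tensor_parallel(world_size: int) -> None:
    # Mamba has no tensor-parallel implementation, so refuse to run with more than one shard.
    if world_size > 1:
        raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")


check_no_tensor_parallel(1)  # single shard: fine
try:
    check_no_tensor_parallel(2)
except RuntimeError as err:
    print(err)  # Mamba does not support Tensor Parallelism (TP)
```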
@@ -427,7 +429,7 @@ class Mamba(Model):
 
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
                     logger.info("Experimental support for Cuda Graphs is enabled")