mirror of https://github.com/huggingface/text-generation-inference.git
refactor model_id, make tunableop default

commit a040a59068
parent df0a453693
@@ -46,7 +46,6 @@ class BLOOMSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -31,6 +31,7 @@ from text_generation_server.models.cache_manager import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
+import text_generation_server.models.globals as tgi_globals
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -827,11 +828,14 @@ class FlashCausalLM(Model):
         )
 
         if SYSTEM == "rocm":
-            if os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
-                if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
+            if (
+                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
+                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
+            ):
+                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                     torch.cuda.tunable.tuning_enable(True)
 
-                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
+                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                     tuning_sequences = [
                         int(val)
                         for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
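The hunk above is what makes TunableOp the default on ROCm. The old check only entered the block when the variable was set (any set value, including "0", is a non-empty string and therefore truthy), so the feature was effectively opt-in; the new check treats an unset variable as enabled and only an explicit "0" disables it. A minimal standalone sketch of the new semantics, with hypothetical helper names that are not part of TGI:

import os


def tunableop_enabled() -> bool:
    # Enabled by default: unset or "1" means on, anything else (notably "0") means off.
    value = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
    return value is None or value == "1"


def tuning_enabled() -> bool:
    # Tuning stays on unless explicitly disabled with PYTORCH_TUNABLEOP_TUNING=0.
    return os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0"


if __name__ == "__main__":
    print(tunableop_enabled(), tuning_enabled())  # True True when nothing is set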
@@ -841,11 +845,11 @@ class FlashCausalLM(Model):
 
                 tunableop_filepath = os.path.join(
                     HUGGINGFACE_HUB_CACHE,
-                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                    f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                 )
 
                 logger.info(
-                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}."
+                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
                 )
 
                 if os.path.isfile(tunableop_filepath):
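The tuned-GEMM cache file is now named from the process-wide model id (tgi_globals.MODEL_ID) rather than a per-model attribute. An illustrative, standalone sketch of how that filename is built; HUGGINGFACE_HUB_CACHE normally comes from huggingface_hub, a plain path is used here only so the snippet runs on its own:

import os

HUGGINGFACE_HUB_CACHE = os.path.expanduser("~/.cache/huggingface/hub")  # stand-in


def tunableop_path(model_id: str, world_size: int, rank: int) -> str:
    # '/' in a Hub model id (e.g. "bigscience/bloom-560m") is not valid in a
    # filename, so it is replaced by '-'.
    return os.path.join(
        HUGGINGFACE_HUB_CACHE,
        f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv",
    )


print(tunableop_path("bigscience/bloom-560m", 2, 0))
# .../tunableop_bigscience-bloom-560m_tp2_rank0.csv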
@@ -28,7 +28,6 @@ class FlashCohere(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -30,7 +30,6 @@ class FlashDbrx(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -28,7 +28,6 @@ class FlashGemma(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -15,11 +15,11 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.models import CausalLM
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
-from text_generation_server.utils.import_utils import SYSTEM
-
 
 class FlashGPT2(FlashCausalLM):
     def __init__(
@@ -31,6 +31,7 @@ class FlashGPT2(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -10,6 +10,7 @@ from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
 )
+
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -31,7 +32,6 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -27,6 +27,7 @@ from text_generation_server.utils import (
     HeterogeneousNextTokenChooser,
     StoppingCriteria,
 )
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -318,7 +319,6 @@ class BaseFlashMistral(FlashCausalLM):
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -29,7 +30,6 @@ class FlashNeoXSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -29,7 +30,6 @@ class FlashPhi(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -34,7 +35,6 @@ class FlashQwen2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -30,7 +30,6 @@ class FlashRWSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -33,7 +33,6 @@ class FlashSantacoderSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -33,7 +33,6 @@ class FlashStarcoder2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -21,6 +21,7 @@ from text_generation_server.utils import (
     Weights,
 )
+
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
 
 # we split individual characters inside special tokens like [START_DNA]
@@ -171,7 +172,6 @@ class GalacticaSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -15,3 +15,12 @@ else:
     cuda_graphs = None
 
 CUDA_GRAPHS = cuda_graphs
+
+# This is overridden at model loading.
+global MODEL_ID
+MODEL_ID = None
+
+
+def set_model_id(model_id: str):
+    global MODEL_ID
+    MODEL_ID = model_id
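MODEL_ID replaces the self.model_id attribute that this commit removes from the individual model classes: the id is recorded once at startup instead of being stored on every model. Consumers read it through the module, as flash_causal_lm.py does with `import text_generation_server.models.globals as tgi_globals` and `tgi_globals.MODEL_ID`, because a from-import of the name would capture the initial None and never see the later assignment. A small standalone sketch of that distinction, using a stand-in module object rather than the real TGI module:

import types

tgi_globals = types.ModuleType("tgi_globals_sketch")  # stand-in for models.globals
tgi_globals.MODEL_ID = None


def set_model_id(model_id: str):
    tgi_globals.MODEL_ID = model_id


stale = tgi_globals.MODEL_ID        # like `from ... import MODEL_ID` at import time
set_model_id("bigscience/bloom-560m")  # example id, set once at startup
print(stale)                        # None: the early binding never updates
print(tgi_globals.MODEL_ID)         # bigscience/bloom-560m: attribute access sees it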
@@ -28,7 +28,6 @@ class GPTNeoxSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -35,7 +35,6 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -22,7 +22,6 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 from text_generation_server.models.vlm_causal_lm import split
-
 import re
 
 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
@@ -577,6 +576,7 @@ class IdeficsCausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+
         from text_generation_server.models.custom_modeling.idefics_modeling import (
             IdeficsForVisionText2Text,
         )
@@ -412,7 +412,6 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, _rank, world_size = initialize_torch_distributed()
         if world_size > 1:
@@ -47,7 +47,6 @@ class MPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -26,7 +26,6 @@ class OPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -26,7 +26,6 @@ class Phi(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, _rank, _world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -16,6 +16,7 @@ class RW(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+
         if speculator:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")
 
@@ -23,7 +23,6 @@ class SantaCoder(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -17,6 +17,7 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+
 
 tracer = trace.get_tracer(__name__)
 
 
@@ -536,7 +537,7 @@ class Seq2SeqLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
+
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
@@ -29,7 +29,6 @@ class T5Sharded(Seq2SeqLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
 
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -21,6 +21,7 @@ from text_generation_server.models.vlm_causal_lm import (
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+from text_generation_server.models.globals import set_model_id
 
 
 class SignalHandler:
@@ -255,6 +256,7 @@ def serve(
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
 
+    set_model_id(model_id)
     asyncio.run(
         serve_inner(
             model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
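The server records the model id once, in serve, before the asyncio event loop starts, so model loading and warmup inside serve_inner already see the populated global and nothing has to thread the id through every model constructor. A simplified, standalone sketch of that ordering; MODEL_ID and set_model_id mirror the globals module added above, while serve_inner here is only a stand-in for the real coroutine:

import asyncio

MODEL_ID = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id


async def serve_inner(model_id: str):
    # Model loading and warmup would run here and can read the already-set
    # MODEL_ID, e.g. to name the per-model TunableOp CSV file.
    print(f"loading {model_id}; MODEL_ID={MODEL_ID}")


def serve(model_id: str):
    set_model_id(model_id)  # record the id before the event loop starts
    asyncio.run(serve_inner(model_id))


if __name__ == "__main__":
    serve("bigscience/bloom-560m")  # example model id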