From a040a5906851195fa1bf67dfd2ccc4c2fb4bcb7f Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 17 May 2024 08:46:14 +0000
Subject: [PATCH] refactor model_id, make tunableop default

---
 server/text_generation_server/models/bloom.py       |  1 -
 .../models/flash_causal_lm.py                       | 14 +++++++++-----
 .../text_generation_server/models/flash_cohere.py   |  1 -
 server/text_generation_server/models/flash_dbrx.py  |  1 -
 .../text_generation_server/models/flash_gemma.py    |  1 -
 server/text_generation_server/models/flash_gpt2.py  |  5 +++--
 .../text_generation_server/models/flash_llama.py    |  2 +-
 .../text_generation_server/models/flash_mistral.py  |  2 +-
 server/text_generation_server/models/flash_neox.py  |  2 +-
 server/text_generation_server/models/flash_phi.py   |  2 +-
 .../text_generation_server/models/flash_qwen2.py    |  2 +-
 server/text_generation_server/models/flash_rw.py    |  1 -
 .../models/flash_santacoder.py                      |  1 -
 .../models/flash_starcoder2.py                      |  1 -
 server/text_generation_server/models/galactica.py   |  2 +-
 server/text_generation_server/models/globals.py     |  9 +++++++++
 server/text_generation_server/models/gpt_neox.py    |  1 -
 server/text_generation_server/models/idefics.py     |  1 -
 .../models/idefics_causal_lm.py                     |  2 +-
 server/text_generation_server/models/mamba.py       |  1 -
 server/text_generation_server/models/mpt.py         |  1 -
 server/text_generation_server/models/opt.py         |  1 -
 server/text_generation_server/models/phi.py         |  1 -
 server/text_generation_server/models/rw.py          |  1 +
 server/text_generation_server/models/santacoder.py  |  1 -
 server/text_generation_server/models/seq2seq_lm.py  |  3 ++-
 server/text_generation_server/models/t5.py          |  1 -
 server/text_generation_server/server.py             |  2 ++
 28 files changed, 33 insertions(+), 30 deletions(-)

diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 6e7f2f1e..65c9f317 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -46,7 +46,6 @@ class BLOOMSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b0ac9ece..333efe33 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -31,6 +31,7 @@ from text_generation_server.models.cache_manager import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
+import text_generation_server.models.globals as tgi_globals
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -827,11 +828,14 @@ class FlashCausalLM(Model):
             )
 
         if SYSTEM == "rocm":
-            if os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
-                if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
+            if (
+                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
+                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
+            ):
+                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                     torch.cuda.tunable.tuning_enable(True)
 
-                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
+                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                     tuning_sequences = [
                         int(val)
                         for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
@@ -841,11 +845,11 @@ class FlashCausalLM(Model):
                 tunableop_filepath = os.path.join(
                     HUGGINGFACE_HUB_CACHE,
-                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                    f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                 )
 
                 logger.info(
-                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}."
+                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
                 )
 
                 if os.path.isfile(tunableop_filepath):
diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index d955fca7..8edaaa35 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -28,7 +28,6 @@ class FlashCohere(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py
index 60c032ad..6a9b9d7f 100644
--- a/server/text_generation_server/models/flash_dbrx.py
+++ b/server/text_generation_server/models/flash_dbrx.py
@@ -30,7 +30,6 @@ class FlashDbrx(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index 8783a259..70f1b65c 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -28,7 +28,6 @@ class FlashGemma(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py
index 5781f55e..2511148d 100644
--- a/server/text_generation_server/models/flash_gpt2.py
+++ b/server/text_generation_server/models/flash_gpt2.py
@@ -15,11 +15,11 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.models import CausalLM
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
-from text_generation_server.utils.import_utils import SYSTEM
-
 
 class FlashGPT2(FlashCausalLM):
     def __init__(
@@ -31,6 +31,7 @@ class FlashGPT2(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 5c552379..7395a41f 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -10,6 +10,7 @@ from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
 )
+
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -31,7 +32,6 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 018ddd10..295fcc41 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -27,6 +27,7 @@ from text_generation_server.utils import (
     HeterogeneousNextTokenChooser,
     StoppingCriteria,
 )
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -318,7 +319,6 @@ class BaseFlashMistral(FlashCausalLM):
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index eda1d658..fe1ebbaf 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     Weights,
 )
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -29,7 +30,6 @@ class FlashNeoXSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index 9dd576bc..1934254e 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -29,7 +30,6 @@ class FlashPhi(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py
index 482d5fd6..511e33e0 100644
--- a/server/text_generation_server/models/flash_qwen2.py
+++ b/server/text_generation_server/models/flash_qwen2.py
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.models import CausalLM
 
 tracer = trace.get_tracer(__name__)
 
@@ -34,7 +35,6 @@ class FlashQwen2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 68c01e5a..ccc90179 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -30,7 +30,6 @@ class FlashRWSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 9b48a5e6..e1add297 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -33,7 +33,6 @@ class FlashSantacoderSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py
index d498725a..80323fb6 100644
--- a/server/text_generation_server/models/flash_starcoder2.py
+++ b/server/text_generation_server/models/flash_starcoder2.py
@@ -33,7 +33,6 @@ class FlashStarcoder2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index c164703c..93d004d7 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -21,6 +21,7 @@ from text_generation_server.utils import (
     Weights,
 )
 
+
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
 # we split individual characters inside special tokens like [START_DNA]
 
@@ -171,7 +172,6 @@ class GalacticaSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 6f8d1017..e8a11958 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -15,3 +15,12 @@ else:
     cuda_graphs = None
 
 CUDA_GRAPHS = cuda_graphs
+
+# This is overridden at model loading.
+global MODEL_ID
+MODEL_ID = None
+
+
+def set_model_id(model_id: str):
+    global MODEL_ID
+    MODEL_ID = model_id
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 83007609..92fa5ce4 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -28,7 +28,6 @@ class GPTNeoxSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index a3000cfa..816c5e75 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -35,7 +35,6 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index e78a9655..dd26cc06 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -22,7 +22,6 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 from text_generation_server.models.vlm_causal_lm import split
-
 import re
 
 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
@@ -577,6 +576,7 @@ class IdeficsCausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+
         from text_generation_server.models.custom_modeling.idefics_modeling import (
             IdeficsForVisionText2Text,
         )
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 6a681aa1..36386365 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -412,7 +412,6 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, _rank, world_size = initialize_torch_distributed()
         if world_size > 1:
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 4525c417..6f6e837f 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -47,7 +47,6 @@ class MPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index f706c7dc..48584734 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -26,7 +26,6 @@ class OPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py
index dff76084..d4dff836 100644
--- a/server/text_generation_server/models/phi.py
+++ b/server/text_generation_server/models/phi.py
@@ -26,7 +26,6 @@ class Phi(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, _rank, _world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index d4764ded..c347c47d 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -16,6 +16,7 @@ class RW(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+
         if speculator:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")
 
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index e8ca2235..188faf21 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -23,7 +23,6 @@ class SantaCoder(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         if torch.cuda.is_available():
             device = torch.device("cuda")
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 80b70f84..c5473107 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -17,6 +17,7 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -536,7 +537,7 @@ class Seq2SeqLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
+
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index dc49d7bd..674e9318 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -29,7 +29,6 @@ class T5Sharded(Seq2SeqLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.model_id = model_id
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 92126fe6..152e10bd 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -21,6 +21,7 @@ from text_generation_server.models.vlm_causal_lm import (
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+from text_generation_server.models.globals import set_model_id
 
 
 class SignalHandler:
@@ -255,6 +256,7 @@ def serve(
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
 
+    set_model_id(model_id)
     asyncio.run(
         serve_inner(
             model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code