From 51b0c25f374fc42379973b43f18a8b60d9aec770 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 2 May 2024 15:33:00 +0000
Subject: [PATCH] add model id

---
 server/text_generation_server/models/bloom.py | 2 ++
 .../text_generation_server/models/causal_lm.py | 1 +
 .../models/flash_causal_lm.py | 18 ++++++++++++------
 .../models/flash_cohere.py | 2 ++
 .../models/flash_dbrx.py | 2 ++
 .../models/flash_gemma.py | 2 ++
 .../models/flash_llama.py | 2 ++
 .../models/flash_mistral.py | 2 ++
 .../models/flash_neox.py | 2 ++
 .../text_generation_server/models/flash_phi.py | 2 ++
 .../models/flash_qwen2.py | 2 ++
 .../text_generation_server/models/flash_rw.py | 2 ++
 .../models/flash_santacoder.py | 2 ++
 .../models/flash_starcoder2.py | 2 ++
 .../text_generation_server/models/galactica.py | 2 ++
 .../text_generation_server/models/gpt_neox.py | 2 ++
 .../text_generation_server/models/idefics.py | 2 ++
 server/text_generation_server/models/mamba.py | 2 ++
 server/text_generation_server/models/mpt.py | 2 ++
 server/text_generation_server/models/opt.py | 2 ++
 server/text_generation_server/models/phi.py | 2 ++
 .../models/santacoder.py | 2 ++
 .../models/seq2seq_lm.py | 2 ++
 server/text_generation_server/models/t5.py | 2 ++
 server/text_generation_server/utils/layers.py | 7 +++----
 25 files changed, 60 insertions(+), 10 deletions(-)

diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 67129ec3..9909e8de 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -46,6 +46,8 @@ class BLOOMSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 93ec6ba4..228817f8 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -486,6 +486,7 @@ class CausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
         if use_medusa:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 6a1bb9d3..a83ab362 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -834,15 +834,21 @@ class FlashCausalLM(Model):
 
         if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
             torch.cuda.tunable.tuning_enable(True)
-            tuning_sequences = range(1, 8)
-            tunableop_filename = f"tunableop_tp{self.world_size}_rank{self.rank}.csv"
+            tuning_sequences = list(range(1, 3))
+            tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
 
-            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join(tuning_sequences)}.")
-            torch.cuda.tunable.read_file(tunableop_filename)
+            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
+
+            if os.path.isfile(tunableop_filepath):
+                logger.info(f"The file {tunableop_filepath} already exists and will be reused.")
+                torch.cuda.tunable.read_file(tunableop_filepath)
 
-            for seqlen in range(1, 8):
+            os.makedirs("/data", exist_ok=True)
+
+            for seqlen in tuning_sequences:
+                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
-            torch.cuda.tunable.write_file(tunableop_filename)
+            torch.cuda.tunable.write_file(tunableop_filepath)
             torch.cuda.tunable.tuning_enable(False)
 
         if CUDA_GRAPHS:
diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index f85c7722..034652d5 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -28,6 +28,8 @@ class FlashCohere(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py
index 367d3db0..b92f65c0 100644
--- a/server/text_generation_server/models/flash_dbrx.py
+++ b/server/text_generation_server/models/flash_dbrx.py
@@ -30,6 +30,8 @@ class FlashDbrx(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index 7259b820..74df7d7f 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -29,6 +29,8 @@ class FlashGemma(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 609a188d..a078c010 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -31,6 +31,8 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 85e93543..dc9fbe6a 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -318,6 +318,8 @@ class BaseFlashMistral(FlashCausalLM):
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index f82e27db..6e45cc9f 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -29,6 +29,8 @@ class FlashNeoXSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index cb55f9e6..705c95ad 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -29,6 +29,8 @@ class FlashPhi(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py
index cb3cf6b0..a2fb3820 100644
--- a/server/text_generation_server/models/flash_qwen2.py
+++ b/server/text_generation_server/models/flash_qwen2.py
@@ -34,6 +34,8 @@ class FlashQwen2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index ccf38a0c..93fff9ea 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -30,6 +30,8 @@ class FlashRWSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index e66f1bf8..7033645a 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -33,6 +33,8 @@ class FlashSantacoderSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py
index 68e726d8..8610e4a1 100644
--- a/server/text_generation_server/models/flash_starcoder2.py
+++ b/server/text_generation_server/models/flash_starcoder2.py
@@ -33,6 +33,8 @@ class FlashStarcoder2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index a46f86be..e1a2079c 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -171,6 +171,8 @@ class GalacticaSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 1c4cfe7d..275705e4 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -28,6 +28,8 @@ class GPTNeoxSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index 30bf4aa6..35b765b7 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -35,6 +35,8 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 2aec4f95..dc14054e 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -412,6 +412,8 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, _rank, world_size = initialize_torch_distributed()
         if world_size > 1:
             raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 6b3f29a6..06b4b057 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -47,6 +47,8 @@ class MPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 703e5b58..a212ebe0 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -26,6 +26,8 @@ class OPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py
index cc4e2505..aaf073fa 100644
--- a/server/text_generation_server/models/phi.py
+++ b/server/text_generation_server/models/phi.py
@@ -26,6 +26,8 @@ class Phi(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, _rank, _world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device("cuda")
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 73c21cce..28d70ccc 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -23,6 +23,8 @@ class SantaCoder(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index e55a661c..5ebffc78 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -536,6 +536,8 @@ class Seq2SeqLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         if use_medusa:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")
 
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 3f3cb965..3d02ef0f 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -29,6 +29,8 @@ class T5Sharded(Seq2SeqLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 4b622204..6635be56 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1018,7 +1018,7 @@ try:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
     elif IS_ROCM_SYSTEM:
-        from vllm import pos_encoding_ops
+        from vllm._C import ops
 
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -1339,6 +1339,5 @@
                 freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                 self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
                 self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
-
-except ImportError:
-    pass
+except ImportError as e:
+    logger.warning(f"ImportError in layers.py, beware that this may cause issues later on: {e}")
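
Note on the flash_causal_lm.py hunk: model_id is stored on every model class so that the TunableOp CSV can be named per model, and the warmup now follows an enable / read / tune / write cycle around torch.cuda.tunable. The sketch below is a minimal, standalone illustration of that cycle, assuming a ROCm build of PyTorch 2.3+ that exposes torch.cuda.tunable; the model id, tensor-parallel topology, and run_decode_gemm helper are illustrative placeholders, not text-generation-inference code.

import os

import torch

# Placeholders for this sketch only; in the patch these come from the model instance.
model_id = "my-org/my-model"
world_size, rank = 1, 0
tuning_sequences = list(range(1, 3))  # decode batch sizes to tune, as in the patch

filepath = os.path.join(
    "/data",
    f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv",
)

torch.cuda.tunable.enable(True)         # turn TunableOp on
torch.cuda.tunable.tuning_enable(True)  # allow it to benchmark new GEMM shapes

if os.path.isfile(filepath):
    # Reuse GEMM selections from a previous run instead of re-tuning them.
    torch.cuda.tunable.read_file(filepath)
os.makedirs("/data", exist_ok=True)


def run_decode_gemm(seqlen: int) -> None:
    # Stand-in for FlashCausalLM.tunableop_warmup(seqlen): execute the GEMM
    # shapes seen during decoding so TunableOp can benchmark kernels for them.
    a = torch.randn(seqlen, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    torch.matmul(a, b)


for seqlen in tuning_sequences:
    run_decode_gemm(seqlen)

torch.cuda.tunable.write_file(filepath)   # persist the tuned GEMM table
torch.cuda.tunable.tuning_enable(False)   # serve with tuning frozen

Because the file name embeds the sanitized model id along with the TP size and rank, several models and shards can share the same /data volume without overwriting each other's tuned kernel tables, which is why model_id is threaded through every constructor in this patch.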