From ff5e16b0e2aa50cae9f5689953ddc3a549342b99 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 2 May 2024 13:29:20 +0000
Subject: [PATCH] working tunable

---
 .../custom_modeling/flash_llama_modeling.py  |  1 -
 .../models/flash_causal_lm.py                | 56 ++++++++-----------
 .../utils/paged_attention.py                 |  1 -
 3 files changed, 22 insertions(+), 36 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4cc9da8b..72ccc1cc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -17,7 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import torch
 import torch.distributed
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3aca1042..b9db0fe1 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -769,10 +769,7 @@ class FlashCausalLM(Model):
 
             if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
-
-            logger.info("calling self.generate_token(batch)")
             _, batch, _ = self.generate_token(batch)
-            logger.info("end it")
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@@ -814,6 +811,21 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+            if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1") != "0":
+                torch.cuda.tunable.tuning_enable(True)
+
+            tuning_sequences = range(1, 8)
+            tunableop_filename = f"tunableop_tp{self.world_size}_rank{self.rank}.csv"
+
+            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the optimal ROCm matrix multiplication kernels for the target sequence lengths {', '.join(str(seqlen) for seqlen in tuning_sequences)}.")
+            torch.cuda.tunable.read_file(tunableop_filename)  # reuse previously tuned kernels if the file exists
+
+            for seqlen in tuning_sequences:
+                self.tunableop_warmup(seqlen)
+            torch.cuda.tunable.write_file(tunableop_filename)  # persist the tuned kernel selection
+            torch.cuda.tunable.tuning_enable(False)
+
         if CUDA_GRAPHS:
             try:
                 logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -826,48 +838,24 @@ class FlashCausalLM(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        # if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
-        #     if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
-        #         torch.cuda.tunable.tuning_enable(True)
-        #         logger.info("enable tuning here")
-
-        logger.info("PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
-        for seqlen in range(1, 3):
-            logger.info(f"Warming up TunableOp for seqlen={seqlen}")
-            self.tunableop_warmup(seqlen, max_s, max_bt)
-        logger.info("call write file")
-        torch.cuda.tunable.write_file()
-        torch.cuda.tunable.tuning_enable(False)
-
-        logger.info("finished tunable op")
         return int(num_blocks * BLOCK_SIZE)
 
-    def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
+    def tunableop_warmup(self, seqlen: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-
-        input_lengths = (
-            torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
-        )
-        bs = 1
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
-        )
         kv_cache = get_cache_manager().kv_cache
 
-        logger.info("call self.model.forward")
+        # We pass a `cu_seqlen_prefill` so that we do not have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlen_prefill=None,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
+            cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
+            kv_cache=kv_cache,
+            block_tables=None,
+            input_lengths=None,
             slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
+            max_s=seqlen,
             lm_head_indices=None,
         )
 
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index dda15fa6..98784f69 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -10,7 +10,6 @@ try:
 except Exception as e:
     raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
 
-
 def reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
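
Note on the approach: the warmup above drives PyTorch TunableOp through an enable -> read_file -> tune -> write_file -> freeze cycle. The same cycle can be exercised in isolation with the public torch.cuda.tunable API (PyTorch >= 2.3). The sketch below is illustrative only: the filename, hidden size, and GEMM shapes are made up for the example and are not taken from this patch.

    import torch

    filename = "tunableop_example.csv"  # illustrative path, not the patch's per-rank file

    torch.cuda.tunable.enable(True)         # turn TunableOp on for this process
    torch.cuda.tunable.read_file(filename)  # reuse results from a previous run, if any
    torch.cuda.tunable.tuning_enable(True)  # allow new tunings during the warmup below

    # Tune the matmul shapes expected at runtime (sequence lengths 1..7, as in the patch).
    hidden_size = 4096  # made-up model dimension
    weight = torch.randn(hidden_size, hidden_size, dtype=torch.float16, device="cuda")
    for seqlen in range(1, 8):
        x = torch.randn(seqlen, hidden_size, dtype=torch.float16, device="cuda")
        _ = x @ weight  # each new GEMM shape is benchmarked once, best kernel cached

    torch.cuda.tunable.write_file(filename)  # persist the selected kernels to CSV
    torch.cuda.tunable.tuning_enable(False)  # subsequent matmuls use the frozen selection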