From 50b394d4011c14185d7557f9ab945e2b63ad481d Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 25 Oct 2024 22:14:26 +0200
Subject: [PATCH] add slots filtering kernel

---
 .../models/flash_causal_lm.py  | 28 ++++++--
 .../models/metadata_kernels.py | 67 +++++++++++++++++--
 2 files changed, 82 insertions(+), 13 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index ad6123d0..87e904f4 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -77,6 +77,7 @@ from text_generation_server.models.metadata_kernels import (
     block_tables_to_ragged,
     block_tables_to_padded,
     prepare_position_slot_ids,
+    slots_filtering,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -500,10 +501,11 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []
 
-        # slots to keep after filtering
-        slot_filtering_indices = torch.zeros(
-            self.slots.shape[0], dtype=torch.bool, device=device
-        )
+        if not has_triton():
+            # slots to keep after filtering
+            slot_filtering_indices = torch.zeros(
+                self.slots.shape[0], dtype=torch.bool, device=device
+            )
 
         # Create on CPU to only move to GPU once instead of at every copy
         slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
@@ -531,6 +533,7 @@ class FlashCausalLMBatch(Batch):
 
         num_blocks = 0
         max_blocks = 0
+        max_slots = 0
         cumulative_slot_tokens = 0
 
         for i, request_id in enumerate(request_ids):
@@ -578,8 +581,9 @@ class FlashCausalLMBatch(Batch):
             end_slot = self.cu_slots[idx + 1]
             slot_length = end_slot - start_slot
 
-            # Set slice
-            slot_filtering_indices[start_slot:end_slot] = True
+            if not has_triton():
+                # Set slice
+                slot_filtering_indices[start_slot:end_slot] = True
 
             cu_slots.append(cumulative_slot_tokens + slot_length)
 
@@ -593,6 +597,7 @@ class FlashCausalLMBatch(Batch):
 
             cumulative_slot_tokens += slot_length
             max_blocks = max(max_blocks, len(request_block_table))
+            max_slots = max(max_slots, slot_length)
 
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
@@ -603,9 +608,18 @@ class FlashCausalLMBatch(Batch):
         )
         prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
 
-        slots = self.slots[slot_filtering_indices]
         cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
 
+        if not has_triton():
+            slots = self.slots[slot_filtering_indices]
+        else:
+            slots = self.slots.new_empty(cumulative_slot_tokens)
+            gpu_cu_slots = cu_slots.to(device)
+            slots_indexing_start = self.cu_slots.to(device)[indices]
+            slots_filtering(
+                max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
+            )
+
         if self.prefilling:
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py
index 17312efb..b3e2160d 100644
--- a/server/text_generation_server/models/metadata_kernels.py
+++ b/server/text_generation_server/models/metadata_kernels.py
@@ -3,20 +3,28 @@ import triton
 
 import triton.language as tl
 
-from typing import List
+from loguru import logger
+from typing import List, Optional
 from torch.utils._triton import has_triton as has_triton_torch
 
 from text_generation_server.utils.import_utils import (
     SYSTEM,
 )
+from text_generation_server.utils.log import log_master
+
+_HAS_TRITON: Optional[bool] = None
 
 
 def has_triton():
-    # FIXME: it seems that has_triton_torch is bugged on RocM
-    # For now, only accept cuda
-    if SYSTEM == "cuda":
-        return has_triton_torch()
-    return False
+    global _HAS_TRITON
+    if _HAS_TRITON is None:
+        # FIXME: it seems that has_triton_torch is bugged on RocM
+        # For now, only accept cuda
+        _HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
+        if _HAS_TRITON:
+            log_master(logger.info, "Using optimized Triton indexing kernels.")
+
+    return _HAS_TRITON
 
 
 def block_tables_to_padded(
@@ -133,6 +141,53 @@ def prepare_position_slot_ids(
     )
 
 
+def slots_filtering(
+    max_slots: int,
+    slots: torch.Tensor,
+    filtered_slots: torch.Tensor,
+    cu_slots: torch.Tensor,
+    slots_start: torch.Tensor,
+):
+    def grid(meta):
+        return (
+            triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
+            len(slots_start),
+        )
+
+    triton_slots_filtering[grid](
+        slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
+    )
+
+
+@triton.jit
+def triton_slots_filtering(
+    # Inputs
+    slots_ptr,
+    filtered_slots_ptr,
+    slots_start_ptr,
+    cu_slots_ptr,
+    # Const values
+    BLOCK_SIZE: "tl.constexpr",
+):
+    # Position in block_tables_ragged.numel() / BLOCK_SIZE
+    pid = tl.program_id(axis=0)
+    # Position in batch
+    bid = tl.program_id(axis=1)
+
+    block_start = pid * BLOCK_SIZE
+    block_arange = block_start + tl.arange(0, BLOCK_SIZE)
+
+    filter_start = tl.load(slots_start_ptr + bid)
+
+    slot_start = tl.load(cu_slots_ptr + bid)
+    slot_end = tl.load(cu_slots_ptr + bid + 1)
+
+    mask = (slot_start + block_arange) < slot_end
+
+    slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
+    tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
+
+
 @triton.jit
 def triton_block_tables_to_padded(
     # Inputs
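
For reviewers, a minimal sketch of what the new kernel computes, outside of FlashCausalLMBatch.filter(): it replaces the boolean-mask gather with a per-request copy of each kept request's slots into a compact tensor. This assumes a CUDA device where has_triton() returns True; the tensor values below are made up for illustration and are not part of the patch.

# Illustrative sketch only: drives slots_filtering directly with made-up values.
import torch

from text_generation_server.models.metadata_kernels import has_triton, slots_filtering

if has_triton():
    device = torch.device("cuda")
    # Slots of two requests laid out back to back: request 0 -> [10..13], request 1 -> [20..22].
    slots = torch.tensor([10, 11, 12, 13, 20, 21, 22], dtype=torch.int64, device=device)
    # Keep only request 1; its slots start at offset 4 in the unfiltered `slots` tensor.
    slots_start = torch.tensor([4], dtype=torch.int64, device=device)
    # Cumulative slot counts of the filtered batch: a single request with 3 slots.
    cu_slots = torch.tensor([0, 3], dtype=torch.int64, device=device)
    filtered = slots.new_empty(3)
    # max_slots = 3: the longest kept request spans 3 slots.
    slots_filtering(3, slots, filtered, cu_slots, slots_start)
    assert filtered.tolist() == [20, 21, 22]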