add slots filtering kernel

This commit is contained in:
OlivierDehaene 2024-10-25 22:14:26 +02:00
parent b4ebfa52f4
commit 50b394d401
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
2 changed files with 82 additions and 13 deletions

View File

@ -77,6 +77,7 @@ from text_generation_server.models.metadata_kernels import (
block_tables_to_ragged,
block_tables_to_padded,
prepare_position_slot_ids,
slots_filtering,
)
tracer = trace.get_tracer(__name__)
@ -500,10 +501,11 @@ class FlashCausalLMBatch(Batch):
# Used to index into tensors
indices = []
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
self.slots.shape[0], dtype=torch.bool, device=device
)
if not has_triton():
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
self.slots.shape[0], dtype=torch.bool, device=device
)
# Create on CPU to only move to GPU once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
@ -531,6 +533,7 @@ class FlashCausalLMBatch(Batch):
num_blocks = 0
max_blocks = 0
max_slots = 0
cumulative_slot_tokens = 0
for i, request_id in enumerate(request_ids):
@ -578,8 +581,9 @@ class FlashCausalLMBatch(Batch):
end_slot = self.cu_slots[idx + 1]
slot_length = end_slot - start_slot
# Set slice
slot_filtering_indices[start_slot:end_slot] = True
if not has_triton():
# Set slice
slot_filtering_indices[start_slot:end_slot] = True
cu_slots.append(cumulative_slot_tokens + slot_length)
@ -593,6 +597,7 @@ class FlashCausalLMBatch(Batch):
cumulative_slot_tokens += slot_length
max_blocks = max(max_blocks, len(request_block_table))
max_slots = max(max_slots, slot_length)
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
@ -603,9 +608,18 @@ class FlashCausalLMBatch(Batch):
)
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
if not has_triton():
slots = self.slots[slot_filtering_indices]
else:
slots = self.slots.new_empty(cumulative_slot_tokens)
gpu_cu_slots = cu_slots.to(device)
slots_indexing_start = self.cu_slots.to(device)[indices]
slots_filtering(
max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
)
if self.prefilling:
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None

View File

@ -3,20 +3,28 @@ import triton
import triton.language as tl
from typing import List
from loguru import logger
from typing import List, Optional
from torch.utils._triton import has_triton as has_triton_torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
from text_generation_server.utils.log import log_master
_HAS_TRITON: Optional[bool] = None
def has_triton():
# FIXME: it seems that has_triton_torch is bugged on RocM
# For now, only accept cuda
if SYSTEM == "cuda":
return has_triton_torch()
return False
global _HAS_TRITON
if _HAS_TRITON is None:
# FIXME: it seems that has_triton_torch is bugged on RocM
# For now, only accept cuda
_HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
if _HAS_TRITON:
log_master(logger.info, "Using optimized Triton indexing kernels.")
return _HAS_TRITON
def block_tables_to_padded(
@ -133,6 +141,53 @@ def prepare_position_slot_ids(
)
def slots_filtering(
max_slots: int,
slots: torch.Tensor,
filtered_slots: torch.Tensor,
cu_slots: torch.Tensor,
slots_start: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
len(slots_start),
)
triton_slots_filtering[grid](
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
)
@triton.jit
def triton_slots_filtering(
# Inputs
slots_ptr,
filtered_slots_ptr,
slots_start_ptr,
cu_slots_ptr,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in block_tables_ragged.numel() / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
filter_start = tl.load(slots_start_ptr + bid)
slot_start = tl.load(cu_slots_ptr + bid)
slot_end = tl.load(cu_slots_ptr + bid + 1)
mask = (slot_start + block_arange) < slot_end
slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
@triton.jit
def triton_block_tables_to_padded(
# Inputs