# text-generation-inference/server/text_generation_server/models/metadata_kernels.py

import torch
import triton
import triton.language as tl
from loguru import logger
from typing import List, Optional
from torch.utils._triton import has_triton as has_triton_torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
from text_generation_server.utils.log import log_master

_HAS_TRITON: Optional[bool] = None


def has_triton():
global _HAS_TRITON
if _HAS_TRITON is None:
        # FIXME: has_triton_torch seems to be buggy on ROCm,
        # so for now only enable Triton kernels on CUDA.
_HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
if _HAS_TRITON:
log_master(logger.info, "Using optimized Triton indexing kernels.")
return _HAS_TRITON


def block_tables_to_padded(
    max_blocks: int,
    cu_seqlen: torch.Tensor,
    block_tables: torch.Tensor,
    block_tables_ragged: torch.Tensor,
):
    """Scatter the ragged block table into the padded `block_tables` rows."""
def grid(meta):
return (
triton.cdiv(max_blocks, meta["BLOCK_SIZE"]),
len(block_tables),
)
triton_block_tables_to_padded[grid](
cu_seqlen,
block_tables,
block_tables_ragged,
block_tables.shape[1],
BLOCK_SIZE=256,
)
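

# Illustrative sketch (hypothetical helper, values assumed): two requests using
# 3 and 2 KV-cache blocks, flattened back to back, scattered into the padded
# layout. Launching the kernel directly requires a CUDA device with Triton.
def _example_block_tables_to_padded():
    assert has_triton(), "requires CUDA with Triton"
    device = torch.device("cuda")
    block_tables_ragged = torch.tensor(
        [0, 1, 2, 3, 4], dtype=torch.int32, device=device
    )
    cu_seqlen = torch.tensor([0, 3, 5], dtype=torch.int32, device=device)
    block_tables = torch.zeros(2, 3, dtype=torch.int32, device=device)
    block_tables_to_padded(
        max_blocks=3,
        cu_seqlen=cu_seqlen,
        block_tables=block_tables,
        block_tables_ragged=block_tables_ragged,
    )
    # block_tables is now [[0, 1, 2], [3, 4, 0]].
    return block_tables

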
def block_tables_to_ragged(
*,
block_tables: torch.Tensor,
input_lengths: List[int],
cache_lengths: List[int],
input_lengths_tensor: torch.Tensor,
cache_lengths_tensor: torch.Tensor,
max_current_length: int,
) -> torch.Tensor:
"""Convert block table to ragged format compatible with FlashInfer."""
assert len(input_lengths) == len(cache_lengths)
total_len = sum(input_lengths) + sum(cache_lengths)
block_tables_ragged = torch.empty(
total_len, dtype=torch.int32, device=block_tables.device
)
if has_triton():
cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
torch.cumsum(
input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
)
def grid(meta):
return (
triton.cdiv(max_current_length, meta["BLOCK_SIZE"]),
len(cache_lengths),
)
triton_block_tables_to_ragged[grid](
cu_seqlen,
block_tables,
block_tables_ragged,
block_tables.shape[1],
BLOCK_SIZE=256,
)
else:
offset = 0
for i, (input_length, cache_length) in enumerate(
zip(input_lengths, cache_lengths)
):
seq_len = cache_length + input_length
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
offset += seq_len
return block_tables_ragged
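

# Illustrative sketch (hypothetical helper, values assumed): request 0 has
# cache 1 + input 2 tokens, request 1 has cache 1 + input 1. On CPU or without
# Triton this exercises the pure-PyTorch fallback path.
def _example_block_tables_to_ragged():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    block_tables = torch.tensor(
        [[0, 1, 2], [3, 4, 0]], dtype=torch.int32, device=device
    )
    input_lengths = [2, 1]
    cache_lengths = [1, 1]
    ragged = block_tables_to_ragged(
        block_tables=block_tables,
        input_lengths=input_lengths,
        cache_lengths=cache_lengths,
        input_lengths_tensor=torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        ),
        cache_lengths_tensor=torch.tensor(
            cache_lengths, dtype=torch.int32, device=device
        ),
        max_current_length=3,
    )
    # ragged == [0, 1, 2, 3, 4]: 3 entries for request 0, then 2 for request 1.
    return ragged

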
def copy_next_input_ids_inplace(
max_next_input_ids: int,
all_input_ids: torch.Tensor,
cache_lengths: torch.Tensor,
input_lengths: torch.Tensor,
prompt_lengths: torch.Tensor,
next_input_ids: torch.Tensor,
cu_accepted_ids: torch.Tensor,
):
    """Append each request's accepted next token ids to its row of `all_input_ids`.

    `cu_accepted_ids` delimits each request's accepted ids inside
    `next_input_ids`; tokens are only written once a request is past its
    prompt (i.e. no longer prefilling).
    """
def grid(meta):
return (
triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]),
len(all_input_ids),
)
triton_copy_next_input_ids_inplace[grid](
all_input_ids,
cache_lengths,
input_lengths,
prompt_lengths,
next_input_ids,
cu_accepted_ids,
all_input_ids.shape[1],
BLOCK_SIZE=16,
)
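

# Illustrative sketch (hypothetical helper, values assumed): one request past
# its 4-token prompt accepts two speculative tokens, which land right after
# its cache_length + input_length prefix. Requires CUDA with Triton.
def _example_copy_next_input_ids_inplace():
    assert has_triton(), "requires CUDA with Triton"
    device = torch.device("cuda")
    all_input_ids = torch.zeros(1, 8, dtype=torch.int64, device=device)
    all_input_ids[0, :4] = torch.tensor([10, 11, 12, 13], device=device)
    copy_next_input_ids_inplace(
        max_next_input_ids=2,
        all_input_ids=all_input_ids,
        cache_lengths=torch.tensor([3], dtype=torch.int32, device=device),
        input_lengths=torch.tensor([1], dtype=torch.int32, device=device),
        prompt_lengths=torch.tensor([4], dtype=torch.int32, device=device),
        next_input_ids=torch.tensor([14, 15], dtype=torch.int64, device=device),
        cu_accepted_ids=torch.tensor([0, 2], dtype=torch.int32, device=device),
    )
    # all_input_ids[0] now starts with [10, 11, 12, 13, 14, 15].
    return all_input_ids

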
def prepare_position_slot_ids(
    max_input_length: int,
    cache_lengths: torch.Tensor,
    cu_seqlen: torch.Tensor,
    cu_slots: torch.Tensor,
    position_ids: torch.Tensor,
    slot_indices: torch.Tensor,
):
    """Fill `position_ids` and `slot_indices` for every input token in the batch.

    For request `i`, positions continue from `cache_lengths[i]`, and slot
    indices point into the flat slots tensor starting at
    `cu_slots[i] + cache_lengths[i]`.
    """
def grid(meta):
return (
triton.cdiv(max_input_length, meta["BLOCK_SIZE"]),
len(cache_lengths),
)
triton_prepare_position_slot_ids[grid](
cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
)
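

# Illustrative sketch (hypothetical helper, values assumed): two requests with
# cache lengths 2 and 0. Positions continue after each cache, and slot indices
# are offset into each request's region of the flat slots tensor. Requires
# CUDA with Triton.
def _example_prepare_position_slot_ids():
    assert has_triton(), "requires CUDA with Triton"
    device = torch.device("cuda")
    # Request 0: cache 2 + input 2 tokens; request 1: cache 0 + input 3 tokens.
    cache_lengths = torch.tensor([2, 0], dtype=torch.int32, device=device)
    cu_seqlen = torch.tensor([0, 2, 5], dtype=torch.int32, device=device)
    cu_slots = torch.tensor([0, 4, 8], dtype=torch.int32, device=device)
    position_ids = torch.empty(5, dtype=torch.int64, device=device)
    slot_indices = torch.empty(5, dtype=torch.int64, device=device)
    prepare_position_slot_ids(
        max_input_length=3,
        cache_lengths=cache_lengths,
        cu_seqlen=cu_seqlen,
        cu_slots=cu_slots,
        position_ids=position_ids,
        slot_indices=slot_indices,
    )
    # position_ids == [2, 3, 0, 1, 2]; slot_indices == [2, 3, 4, 5, 6].
    return position_ids, slot_indices

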
def slots_filtering(
max_slots: int,
slots: torch.Tensor,
filtered_slots: torch.Tensor,
cu_slots: torch.Tensor,
slots_start: torch.Tensor,
):
    """Compact per-request slot ranges: for request `i`, copy
    `cu_slots[i+1] - cu_slots[i]` slots starting at `slots_start[i]` from
    `slots` into `filtered_slots[cu_slots[i]:cu_slots[i+1]]`."""
def grid(meta):
return (
triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
len(slots_start),
)
triton_slots_filtering[grid](
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
)
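

# Illustrative sketch (hypothetical helper, values assumed): keep a contiguous
# window of each request's slots, as happens when filtering a batch. Requires
# CUDA with Triton.
def _example_slots_filtering():
    assert has_triton(), "requires CUDA with Triton"
    device = torch.device("cuda")
    slots = torch.arange(10, dtype=torch.int64, device=device)
    # Keep 2 slots for request 0 starting at flat offset 1,
    # and 3 slots for request 1 starting at flat offset 6.
    cu_slots = torch.tensor([0, 2, 5], dtype=torch.int64, device=device)
    slots_start = torch.tensor([1, 6], dtype=torch.int64, device=device)
    filtered_slots = torch.empty(5, dtype=torch.int64, device=device)
    slots_filtering(
        max_slots=3,
        slots=slots,
        filtered_slots=filtered_slots,
        cu_slots=cu_slots,
        slots_start=slots_start,
    )
    # filtered_slots == [1, 2, 6, 7, 8].
    return filtered_slots

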
@triton.jit
def triton_slots_filtering(
    # Input
    slots_ptr,
    # Output
    filtered_slots_ptr,
    # Inputs
    slots_start_ptr,
    cu_slots_ptr,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
    # Position along the slots axis (max_slots / BLOCK_SIZE programs)
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
filter_start = tl.load(slots_start_ptr + bid)
slot_start = tl.load(cu_slots_ptr + bid)
slot_end = tl.load(cu_slots_ptr + bid + 1)
mask = (slot_start + block_arange) < slot_end
slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)


@triton.jit
def triton_block_tables_to_padded(
    # Input
    cu_seqlen_ptr,
    # Output
    block_tables_ptr,
    # Input (ragged source)
    block_tables_ragged_ptr,
# Stride
stride_block_tables,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
    # Position along the block axis (max_blocks / BLOCK_SIZE programs)
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
mask = (seq_start + block_arange) < seq_end
blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask)
tl.store(
block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask
)


@triton.jit
def triton_block_tables_to_ragged(
    # Inputs
    cu_seqlen_ptr,
    block_tables_ptr,
    # Output
    block_tables_ragged_ptr,
# Stride
stride_block_tables,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
    # Position along the sequence axis (max_current_length / BLOCK_SIZE programs)
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
mask = (seq_start + block_arange) < seq_end
blocks = tl.load(
block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask
)
tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask)


@triton.jit
def triton_copy_next_input_ids_inplace(
    # Output (updated in place)
    all_input_ids_ptr,
    # Inputs
    cache_lengths_ptr,
    input_lengths_ptr,
    prompt_lengths_ptr,
    next_input_ids_ptr,
    cu_accepted_ids_ptr,
# Stride
stride_all_input_ids,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
    # Position along max_next_input_ids / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
# Used for correctly indexing in all_input_ids
cache_length = tl.load(cache_lengths_ptr + bid)
input_length = tl.load(input_lengths_ptr + bid)
prompt_length = tl.load(prompt_lengths_ptr + bid)
# Start/End of next_input_ids for this request
next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid)
next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1)
# Mask values out of range
mask = (next_input_ids_start + block_arange) < next_input_ids_end
# Mask values for request still prefilling
decode_mask = (cache_length + input_length + block_arange) >= prompt_length
mask = mask & decode_mask
# Load this request next input ids
next_input_ids = tl.load(
next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask
)
# Store in all_input_ids, since it is a 2D tensor, apply stride * bid
tl.store(
all_input_ids_ptr
+ stride_all_input_ids * bid
+ cache_length
+ input_length
+ block_arange,
next_input_ids,
mask=mask,
)


@triton.jit
def triton_prepare_position_slot_ids(
# Inputs
cache_lengths_ptr,
cu_seqlen_ptr,
cu_slots_ptr,
# Outputs
position_ids_ptr,
slot_indices_ptr,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in max_input_length / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
cache_length = tl.load(cache_lengths_ptr + bid)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
slot_start = tl.load(cu_slots_ptr + bid)
mask = (seq_start + block_arange) < seq_end
tl.store(
position_ids_ptr + seq_start + block_arange,
cache_length + block_arange,
mask=mask,
)
tl.store(
slot_indices_ptr + seq_start + block_arange,
slot_start + cache_length + block_arange,
mask=mask,
)