mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
add slots filtering kernel
This commit is contained in:
parent
b4ebfa52f4
commit
50b394d401
@ -77,6 +77,7 @@ from text_generation_server.models.metadata_kernels import (
|
||||
block_tables_to_ragged,
|
||||
block_tables_to_padded,
|
||||
prepare_position_slot_ids,
|
||||
slots_filtering,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
@ -500,6 +501,7 @@ class FlashCausalLMBatch(Batch):
|
||||
# Used to index into tensors
|
||||
indices = []
|
||||
|
||||
if not has_triton():
|
||||
# slots to keep after filtering
|
||||
slot_filtering_indices = torch.zeros(
|
||||
self.slots.shape[0], dtype=torch.bool, device=device
|
||||
@ -531,6 +533,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
num_blocks = 0
|
||||
max_blocks = 0
|
||||
max_slots = 0
|
||||
cumulative_slot_tokens = 0
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
@ -578,6 +581,7 @@ class FlashCausalLMBatch(Batch):
|
||||
end_slot = self.cu_slots[idx + 1]
|
||||
slot_length = end_slot - start_slot
|
||||
|
||||
if not has_triton():
|
||||
# Set slice
|
||||
slot_filtering_indices[start_slot:end_slot] = True
|
||||
|
||||
@ -593,6 +597,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
cumulative_slot_tokens += slot_length
|
||||
max_blocks = max(max_blocks, len(request_block_table))
|
||||
max_slots = max(max_slots, slot_length)
|
||||
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
block_tables_tensor = self.block_tables_tensor[indices]
|
||||
@ -603,9 +608,18 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
|
||||
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
|
||||
|
||||
if not has_triton():
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
else:
|
||||
slots = self.slots.new_empty(cumulative_slot_tokens)
|
||||
gpu_cu_slots = cu_slots.to(device)
|
||||
slots_indexing_start = self.cu_slots.to(device)[indices]
|
||||
slots_filtering(
|
||||
max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
|
||||
)
|
||||
|
||||
if self.prefilling:
|
||||
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
|
||||
position_ids = None
|
||||
|
@ -3,20 +3,28 @@ import triton
|
||||
|
||||
import triton.language as tl
|
||||
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
from typing import List, Optional
|
||||
from torch.utils._triton import has_triton as has_triton_torch
|
||||
|
||||
from text_generation_server.utils.import_utils import (
|
||||
SYSTEM,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
_HAS_TRITON: Optional[bool] = None
|
||||
|
||||
|
||||
def has_triton():
|
||||
global _HAS_TRITON
|
||||
if _HAS_TRITON is None:
|
||||
# FIXME: it seems that has_triton_torch is bugged on RocM
|
||||
# For now, only accept cuda
|
||||
if SYSTEM == "cuda":
|
||||
return has_triton_torch()
|
||||
return False
|
||||
_HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
|
||||
if _HAS_TRITON:
|
||||
log_master(logger.info, "Using optimized Triton indexing kernels.")
|
||||
|
||||
return _HAS_TRITON
|
||||
|
||||
|
||||
def block_tables_to_padded(
|
||||
@ -133,6 +141,53 @@ def prepare_position_slot_ids(
|
||||
)
|
||||
|
||||
|
||||
def slots_filtering(
|
||||
max_slots: int,
|
||||
slots: torch.Tensor,
|
||||
filtered_slots: torch.Tensor,
|
||||
cu_slots: torch.Tensor,
|
||||
slots_start: torch.Tensor,
|
||||
):
|
||||
def grid(meta):
|
||||
return (
|
||||
triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
|
||||
len(slots_start),
|
||||
)
|
||||
|
||||
triton_slots_filtering[grid](
|
||||
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_slots_filtering(
|
||||
# Inputs
|
||||
slots_ptr,
|
||||
filtered_slots_ptr,
|
||||
slots_start_ptr,
|
||||
cu_slots_ptr,
|
||||
# Const values
|
||||
BLOCK_SIZE: "tl.constexpr",
|
||||
):
|
||||
# Position in block_tables_ragged.numel() / BLOCK_SIZE
|
||||
pid = tl.program_id(axis=0)
|
||||
# Position in batch
|
||||
bid = tl.program_id(axis=1)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
filter_start = tl.load(slots_start_ptr + bid)
|
||||
|
||||
slot_start = tl.load(cu_slots_ptr + bid)
|
||||
slot_end = tl.load(cu_slots_ptr + bid + 1)
|
||||
|
||||
mask = (slot_start + block_arange) < slot_end
|
||||
|
||||
slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
|
||||
tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_block_tables_to_padded(
|
||||
# Inputs
|
||||
|
Loading…
Reference in New Issue
Block a user