Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 20:34:54 +00:00)

commit 50b394d401
parent b4ebfa52f4

    add slots filtering kernel
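Background note, as an illustrative sketch with toy values (not taken from the diff): in this batch layout, "slots" are the flat KV-cache slot ids of all requests stored back to back, and cu_slots holds the prefix offsets so request i owns slots[cu_slots[i]:cu_slots[i + 1]]. Filtering a batch down to the surviving requests means re-concatenating only their slot ranges:

    import torch

    slots = torch.tensor([10, 11, 12, 20, 21, 30, 31, 32, 33])  # flat slot ids, 3 requests
    cu_slots = torch.tensor([0, 3, 5, 9])  # request i owns slots[cu_slots[i]:cu_slots[i + 1]]
    keep = [0, 2]                          # requests surviving the filter

    # Filtering the batch concatenates only the kept requests' slot ranges
    filtered = torch.cat([slots[cu_slots[i]:cu_slots[i + 1]] for i in keep])
    # -> tensor([10, 11, 12, 30, 31, 32, 33])

The commit below adds a Triton kernel (slots_filtering) that performs this copy on the GPU and keeps the previous boolean-mask gather as a fallback when Triton is unavailable.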
@@ -77,6 +77,7 @@ from text_generation_server.models.metadata_kernels import (
     block_tables_to_ragged,
     block_tables_to_padded,
     prepare_position_slot_ids,
+    slots_filtering,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -500,10 +501,11 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []
 
-        # slots to keep after filtering
-        slot_filtering_indices = torch.zeros(
-            self.slots.shape[0], dtype=torch.bool, device=device
-        )
+        if not has_triton():
+            # slots to keep after filtering
+            slot_filtering_indices = torch.zeros(
+                self.slots.shape[0], dtype=torch.bool, device=device
+            )
 
         # Create on CPU to only move to GPU once instead of at every copy
         slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
@@ -531,6 +533,7 @@ class FlashCausalLMBatch(Batch):
 
         num_blocks = 0
         max_blocks = 0
+        max_slots = 0
         cumulative_slot_tokens = 0
 
         for i, request_id in enumerate(request_ids):
@@ -578,8 +581,9 @@ class FlashCausalLMBatch(Batch):
             end_slot = self.cu_slots[idx + 1]
             slot_length = end_slot - start_slot
 
-            # Set slice
-            slot_filtering_indices[start_slot:end_slot] = True
+            if not has_triton():
+                # Set slice
+                slot_filtering_indices[start_slot:end_slot] = True
 
             cu_slots.append(cumulative_slot_tokens + slot_length)
 
@@ -593,6 +597,7 @@ class FlashCausalLMBatch(Batch):
 
             cumulative_slot_tokens += slot_length
             max_blocks = max(max_blocks, len(request_block_table))
+            max_slots = max(max_slots, slot_length)
 
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
@@ -603,9 +608,18 @@ class FlashCausalLMBatch(Batch):
         )
         prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
 
-        slots = self.slots[slot_filtering_indices]
         cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
 
+        if not has_triton():
+            slots = self.slots[slot_filtering_indices]
+        else:
+            slots = self.slots.new_empty(cumulative_slot_tokens)
+            gpu_cu_slots = cu_slots.to(device)
+            slots_indexing_start = self.cu_slots.to(device)[indices]
+            slots_filtering(
+                max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
+            )
+
         if self.prefilling:
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
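The fallback path gathers with a boolean mask over the whole old slots tensor, while the Triton path fills a freshly allocated tensor by copying each kept request's contiguous slot range. The hunk builds the kernel inputs incrementally inside the request loop; an equivalent vectorized sketch with toy values (variable names local to this example, not from the diff):

    import torch
    import torch.nn.functional as F

    old_cu_slots = torch.tensor([0, 3, 5, 9])  # old per-request [start, end) offsets
    indices = torch.tensor([0, 2])             # kept requests

    lengths = old_cu_slots[indices + 1] - old_cu_slots[indices]  # slots per kept request -> [3, 4]
    cumulative_slot_tokens = int(lengths.sum())                   # size of the new slots tensor -> 7
    max_slots = int(lengths.max())                                # longest copy, sizes the launch grid -> 4
    new_cu_slots = F.pad(lengths.cumsum(0), (1, 0))               # new prefix offsets -> [0, 3, 7]
    slots_indexing_start = old_cu_slots[indices]                  # old start of each kept range -> [0, 5]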
@@ -3,20 +3,28 @@ import triton
 
 import triton.language as tl
 
-from typing import List
+from loguru import logger
+from typing import List, Optional
 from torch.utils._triton import has_triton as has_triton_torch
 
 from text_generation_server.utils.import_utils import (
     SYSTEM,
 )
+from text_generation_server.utils.log import log_master
 
+_HAS_TRITON: Optional[bool] = None
+
 
 def has_triton():
-    # FIXME: it seems that has_triton_torch is bugged on RocM
-    # For now, only accept cuda
-    if SYSTEM == "cuda":
-        return has_triton_torch()
-    return False
+    global _HAS_TRITON
+    if _HAS_TRITON is None:
+        # FIXME: it seems that has_triton_torch is bugged on RocM
+        # For now, only accept cuda
+        _HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
+        if _HAS_TRITON:
+            log_master(logger.info, "Using optimized Triton indexing kernels.")
+
+    return _HAS_TRITON
 
 
 def block_tables_to_padded(
@@ -133,6 +141,53 @@ def prepare_position_slot_ids(
     )
 
 
+def slots_filtering(
+    max_slots: int,
+    slots: torch.Tensor,
+    filtered_slots: torch.Tensor,
+    cu_slots: torch.Tensor,
+    slots_start: torch.Tensor,
+):
+    def grid(meta):
+        return (
+            triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
+            len(slots_start),
+        )
+
+    triton_slots_filtering[grid](
+        slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
+    )
+
+
+@triton.jit
+def triton_slots_filtering(
+    # Inputs
+    slots_ptr,
+    filtered_slots_ptr,
+    slots_start_ptr,
+    cu_slots_ptr,
+    # Const values
+    BLOCK_SIZE: "tl.constexpr",
+):
+    # Position in block_tables_ragged.numel() / BLOCK_SIZE
+    pid = tl.program_id(axis=0)
+    # Position in batch
+    bid = tl.program_id(axis=1)
+
+    block_start = pid * BLOCK_SIZE
+    block_arange = block_start + tl.arange(0, BLOCK_SIZE)
+
+    filter_start = tl.load(slots_start_ptr + bid)
+
+    slot_start = tl.load(cu_slots_ptr + bid)
+    slot_end = tl.load(cu_slots_ptr + bid + 1)
+
+    mask = (slot_start + block_arange) < slot_end
+
+    slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
+    tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
+
+
 @triton.jit
 def triton_block_tables_to_padded(
     # Inputs
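For reference, a pure-PyTorch equivalent of the kernel above, useful as a correctness check (an illustrative sketch; the argument meaning is assumed to match slots_filtering, and max_slots only sizes the Triton launch grid, so it is unused here):

    import torch

    def slots_filtering_reference(max_slots, slots, filtered_slots, cu_slots, slots_start):
        # Copy each kept request's contiguous slot range from the old flat tensor
        # (starting at slots_start[bid]) to its new position (cu_slots[bid]:cu_slots[bid + 1]).
        for bid in range(slots_start.shape[0]):
            out_start = int(cu_slots[bid])
            out_end = int(cu_slots[bid + 1])
            in_start = int(slots_start[bid])
            filtered_slots[out_start:out_end] = slots[in_start : in_start + (out_end - out_start)]

    # Toy check: keep requests 0 and 2 of a 3-request batch.
    slots = torch.tensor([10, 11, 12, 20, 21, 30, 31, 32, 33])
    cu_slots = torch.tensor([0, 3, 7])      # new layout: 3 + 4 kept slots
    slots_start = torch.tensor([0, 5])      # where each kept range starts in the old tensor
    filtered = slots.new_empty(7)
    slots_filtering_reference(4, slots, filtered, cu_slots, slots_start)
    # filtered -> tensor([10, 11, 12, 30, 31, 32, 33])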