Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-19 22:02:06 +00:00
add chunked attn support

commit e5618d6e40 (parent 5861da1ad7)
@@ -12,8 +12,9 @@ from text_generation_server.utils import initialize_torch_distributed
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
-from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
 import torch.nn.functional as F
+import numpy as np
 
 tracer = trace.get_tracer(__name__)
 
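The new BLOCK_SIZE import is the paged KV-cache page size from models.globals; in this commit it becomes the page_size argument of the helper added in the next hunk, which asserts that the attention chunk spans a whole number of pages:

    # Sketch of the invariant the new helper enforces (names as in this diff):
    # the local-attention chunk must be a whole number of KV-cache pages.
    assert attn_chunk_size % BLOCK_SIZE == 0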
@@ -27,6 +28,124 @@ REPLICATED_ATTENTION_MODELS = [
 ]
 
 
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
+
+# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+    # Handle if we are starting in the middle of a local attention block:
+    # we assume q_seqlens > 0 (for all elements). For each batch idx we compute
+    # the number of tokens that are not in the first local attention block and
+    # then we can simply use a cdiv for the rest.
+    # For example, if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # then we would get:
+    #   new_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx and figure out the number of "virtual" requests we
+    # have to make.
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First get the batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    # (TODO: make a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute the reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute seqlens_q_local, handling the fact that the
+    # first and last blocks could be partial.
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block, since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute seqlens_k_local: basically a full local attention block for all
+    # but the last block in each batch.
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+    )
+
+    # For the example, the local attention blocks start at:
+    #                           _b0_  _____b1_____  _b2_
+    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // page_size
+    assert attn_chunk_size % page_size == 0, (
+        f"attn_chunk_size {attn_chunk_size} is not "
+        f"divisible by page_size {page_size}"
+    )
+    pages_per_local_batch = attn_chunk_size // page_size
+
+    # Create a block_table for the local attention blocks.
+    # For our example, if we have a block table like (assuming page_size=2):
+    #   block_table = [
+    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+    #   ]
+    # then for the local batches we would want a block table like:
+    #   block_table_local = [
+    #     [ 0,  1 ], < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2,  3 ], < local-batch 1, (batch 0, starting from k[4])
+    #     [12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+    #     [14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+    #     [16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+    #     [18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+    #     [22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+    #     [24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+    #   ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch),
+    ) + np.expand_dims(block_starts, axis=1)
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
+    batch_indices = np.repeat(
+        np.arange(actual_batch_size, dtype=np.int32),
+        local_blocks * pages_per_local_batch,
+    )
+    block_table_local = block_table[batch_indices, block_indices].view(
+        virtual_batches, -1
+    )
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
 # # Qwen2VL
 # transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
 #     "tgi"
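A quick way to validate the worked example from the function's comments: the standalone sketch below, assuming numpy and torch are installed and cdiv/make_local_attention_virtual_batches are defined exactly as in this hunk. Note the two negative-division tricks the helper relies on: -(a // -b) is ceiling division, and for positive n, n % -c lies in (-c, 0], so attn_chunk_size + (seq_lens % -attn_chunk_size) is the token count of the last, possibly partial, chunk (e.g. 4 + (17 % -4) == 1).

    import numpy as np
    import torch

    # Inputs from the worked example in the comments above.
    attn_chunk_size = 4
    page_size = 2
    q_seqlens = np.array([4, 10, 5], dtype=np.int32)
    seq_lens = np.array([6, 17, 9], dtype=np.int32)  # k_seqlens
    query_start_loc = np.concatenate(([0], np.cumsum(q_seqlens))).astype(np.int32)
    # A 3x10 block table matching the comment (page ids 0..29).
    block_table = torch.arange(30, dtype=torch.int32).view(3, 10)

    seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local = (
        make_local_attention_virtual_batches(
            attn_chunk_size, query_start_loc, seq_lens, block_table, page_size
        )
    )
    print(seqlens_q_local.tolist())    # [2, 2, 1, 4, 4, 1, 4, 1]
    print(seqlens_k_local.tolist())    # [4, 2, 4, 4, 4, 1, 4, 1]
    print(block_table_local.tolist())  # [[0, 1], [2, 3], [12, 13], [14, 15],
                                       #  [16, 17], [18, 19], [22, 23], [24, 25]]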
@@ -51,8 +170,13 @@ def tgi_flash_attention_forward(
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     use_sdpa: Optional[bool] = False,
+    local_seqlen: Optional[Seqlen] = None,
+    local_block_tables: Optional[torch.Tensor] = None,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
+    if module.use_rope:
+        seqlen = local_seqlen
+        block_tables = local_block_tables
     kv_cache = kv_cache[module.layer_idx]
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
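A note on the new use_rope branch: in Llama 4, the decoder layers that apply RoPE are the chunked-attention layers, while the NoPE layers attend globally, so the flag doubles as "use the local virtual-batch metadata". A hypothetical sketch of that layer pattern (the 3:1 interleaving and the no_rope_layers name follow Transformers' Llama 4 modeling code; this diff does not show where use_rope is set):

    # Every fourth layer skips RoPE and attends globally; the rest use RoPE
    # and chunked attention (assumed interleaving, per the released configs).
    num_layers, layer_idx = 48, 5
    no_rope_layers = [int((i + 1) % 4 != 0) for i in range(num_layers)]
    use_rope = bool(no_rope_layers[layer_idx])  # True -> take the local branch above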
@@ -313,7 +437,9 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
     def get_position_ids(self, input_ids, image_grid_thw, position_ids):
         return position_ids
 
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
         return {
             "input_ids": input_ids.unsqueeze(0),
             "position_ids": position_ids.unsqueeze(0),
@@ -372,6 +498,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            seqlen=seqlen,
+            block_tables=block_tables,
         )
         # This is equivalent to `self.model.forward`, see the monkey patch in __init__
         logits = self.model.original_forward(
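Since pre_process_inputs now takes **kwargs, the call site above can pass the full set of batch fields (seqlen and block_tables included) and every subclass shares one calling convention. A minimal sketch of an override under that convention (hypothetical subclass name):

    class MyVlmCausalLM(TransformersFlashVlmCausalLM):
        def pre_process_inputs(self, **kwargs):
            inputs = super().pre_process_inputs(**kwargs)
            # pull only the extra fields this model needs; the rest are ignored
            inputs["cache_position"] = kwargs["position_ids"]
            return inputs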
@@ -396,6 +524,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             attention_mask=inputs.get("attention_mask", None),
             use_sdpa=inputs.get("use_sdpa", False),
             cache_position=inputs.get("cache_position", None),
+            local_seqlen=inputs.get("local_seqlen", None),
+            local_block_tables=inputs.get("local_block_tables", None),
         ).logits
 
         logits = self.post_process_outputs(logits, lm_head_indices)
@@ -480,7 +610,10 @@ class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
     def post_process_outputs(self, logits, lm_head_indices):
         return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)
 
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+
         input_ids = input_ids.unsqueeze(0)
         position_ids = position_ids.transpose(0, 1).unsqueeze(1)
         return {"input_ids": input_ids, "position_ids": position_ids}
@@ -542,7 +675,11 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
 
         return final_attention_mask
 
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+        cu_seqlen_prefill = kwargs["cu_seqlen_prefill"]
+
         inputs = {
             "input_ids": input_ids.unsqueeze(0),
             "position_ids": position_ids.unsqueeze(0),
@@ -559,8 +696,48 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
 
 
 class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
-        inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+        seqlen = kwargs["seqlen"]
+        block_tables = kwargs["block_tables"]
+
+        inputs = super().pre_process_inputs(**kwargs)
         inputs["cache_position"] = position_ids
         inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
+        from loguru import logger
+
+        logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}")
+        cu_seqlen_k = seqlen.cu_seqlen_k
+        cu_seqlen_q = seqlen.cu_seqlen_q
+        seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
+        (
+            seqlens_q_local_np,
+            virt_q_cu_seqlens_np,
+            virt_k_seqlens_np,
+            virt_block_table,
+        ) = make_local_attention_virtual_batches(
+            self.model.config.text_config.attention_chunk_size,
+            cu_seqlen_q.cpu().numpy(),
+            seq_lens_np.cpu().numpy(),
+            block_tables,
+            BLOCK_SIZE,
+        )
+        local_seqlen = Seqlen(
+            input_lengths=torch.from_numpy(virt_k_seqlens_np).to(
+                input_ids.device, non_blocking=True
+            ),
+            cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to(
+                input_ids.device, non_blocking=True
+            ),
+            cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to(
+                input_ids.device, non_blocking=True
+            ),
+            max_q=int(seqlens_q_local_np.max()),
+            max_k=int(virt_k_seqlens_np.max()),
+        )
+
+        inputs["local_seqlen"] = local_seqlen
+        inputs["local_block_tables"] = virt_block_table
+
         return inputs
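Taken together, the flow this commit adds for Llama 4 looks like this (a condensed sketch of the hunks above, not new behavior):

    # 1. pre_process_inputs() computes the local virtual batches from
    #    attention_chunk_size and BLOCK_SIZE and attaches them to the inputs.
    inputs = self.pre_process_inputs(
        input_ids=input_ids,
        position_ids=position_ids,
        cu_seqlen_prefill=cu_seqlen_prefill,
        seqlen=seqlen,
        block_tables=block_tables,
    )
    # 2. The local metadata rides through the patched forward as extra kwargs.
    logits = self.model.original_forward(
        ...,  # existing arguments elided
        local_seqlen=inputs.get("local_seqlen", None),
        local_block_tables=inputs.get("local_block_tables", None),
    ).logits
    # 3. In each attention layer, tgi_flash_attention_forward swaps in the
    #    local metadata on RoPE (chunked-attention) layers:
    #        if module.use_rope:
    #            seqlen = local_seqlen
    #            block_tables = local_block_tables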