Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00

Factoring cu_seqlen_qk for better abstracting over every model.

This commit is contained in:
parent 65980ed75a
commit 4b1364da92
@@ -1,6 +1,8 @@
 from text_generation_server.utils.import_utils import SYSTEM
 import os
 
+from .common import Seqlen
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
server/text_generation_server/layers/attention/common.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from text_generation_server.models.globals import FLASH_DECODING
+import torch
+from typing import Optional
+
+
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+
+    def __init__(self, input_lengths):
+        self.input_lengths = input_lengths
+        if FLASH_DECODING:
+            device = self.input_lengths.device
+            shape = self.input_lengths.shape
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.empty(shape[-1] + 1, device=device, dtype=torch.int32)
+            cu_seqlen_k[0] = 0
+            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+
+            self.cu_seqlen_q = cu_seqlen_q
+            self.cu_seqlen_k = cu_seqlen_k
+        else:
+            self.cu_seqlen_q = None
+            self.cu_seqlen_k = None
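For a concrete feel of what the new helper produces, here is a minimal standalone sketch that mirrors the FLASH_DECODING branch above outside the server; the tensor values are made up for illustration and only torch is required:

import torch

# Per-sequence lengths for three requests in a decode batch (illustrative values).
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# One query token per sequence during decode, so cu_seqlen_q is just 0..batch_size.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# Cumulative key lengths with a leading zero, exactly as in Seqlen.__init__ above.
cu_seqlen_k = torch.empty(input_lengths.shape[-1] + 1, dtype=torch.int32)
cu_seqlen_k[0] = 0
torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 3, 8, 10]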
@@ -1,6 +1,7 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.layers.attention import Seqlen
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -40,8 +41,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    seqlen: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -66,7 +66,6 @@ def paged_attention(
     block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -88,8 +87,8 @@ def paged_attention(
             key_cache,
             value_cache,
             None,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
             None,
             block_tables,
             None,
@@ -106,6 +105,7 @@ def paged_attention(
         )
         return out2[0]
     else:
+        input_lengths = seqlen.input_lengths
         from vllm._C import ops
 
         use_v1 = max_s <= 8192 and (
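Inside the CUDA wrapper the two cumulative-length tensors are now read off the single Seqlen argument, so callers hand over one object instead of two loose tensors. A reduced sketch of that dispatch, using standalone stand-ins rather than the server code (fake_paged_attention and the printed strings are invented for illustration):

import torch
from dataclasses import dataclass
from typing import Optional


@dataclass
class Seqlen:  # stand-in for text_generation_server.layers.attention.Seqlen
    input_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor] = None
    cu_seqlen_k: Optional[torch.Tensor] = None


def fake_paged_attention(seqlen: Seqlen, flash_decoding: bool) -> None:
    # Mirrors only the argument plumbing this commit changes, not the kernels.
    if flash_decoding:
        # Flash-decoding branch: the varlen kernel consumes the cumulative lengths.
        print("varlen kernel gets", seqlen.cu_seqlen_q, seqlen.cu_seqlen_k)
    else:
        # vLLM fallback: recover the raw per-sequence lengths, as the diff adds in the else branch.
        input_lengths = seqlen.input_lengths
        print("vLLM kernel gets", input_lengths)


fake_paged_attention(Seqlen(input_lengths=torch.tensor([3, 5, 2])), flash_decoding=False)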
@@ -260,8 +260,7 @@ class FlashCohereAttention(torch.nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         slots,
         max_s,
     ):
@@ -314,8 +313,7 @@ class FlashCohereAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
 
@@ -389,8 +387,7 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         slots,
         max_s,
     ):
@@ -404,8 +401,7 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            input_lengths,
             slots,
             max_s,
         )
@@ -469,23 +465,6 @@ class FlashCohereModel(torch.nn.Module):
         )
 
         residual = None
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=input_ids.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths
 
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
@@ -496,8 +475,7 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 slots,
                 max_s,
             )
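The block removed from FlashCohereModel (and, further below, from FlashLlamaModel) is exactly what Seqlen now computes centrally. A quick standalone check of that equivalence, with made-up lengths and CPU tensors:

import torch

input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# Old per-model computation (the lines deleted above).
old_cu_seqlen_k = torch.cat(
    [
        torch.zeros((1,), device=input_lengths.device, dtype=input_lengths.dtype),
        input_lengths.cumsum(dim=-1),
    ]
).to(dtype=torch.int32)

# New shared computation (what Seqlen.__init__ does under FLASH_DECODING).
new_cu_seqlen_k = torch.empty(input_lengths.shape[-1] + 1, dtype=torch.int32)
new_cu_seqlen_k[0] = 0
torch.cumsum(input_lengths, -1, out=new_cu_seqlen_k[1:])

assert torch.equal(old_cu_seqlen_k, new_cu_seqlen_k)  # both are [0, 3, 8, 10]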
@@ -344,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -253,7 +253,6 @@ class FlashGemmaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -253,7 +253,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -173,8 +173,7 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
         adapter_data,
     ):
@@ -218,8 +217,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
 
@@ -356,8 +354,7 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
         adapter_data,
     ):
@@ -372,8 +369,7 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            input_lengths,
             max_s,
             adapter_data,
         )
@@ -443,23 +439,6 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=inputs_embeds.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths
 
         residual = None
         for i, layer in enumerate(self.layers):
@@ -472,8 +451,7 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
                 adapter_data,
             )
@@ -237,7 +237,6 @@ class MistralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -299,7 +299,6 @@ class MixtralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -176,7 +176,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -215,7 +215,6 @@ class FlashPhiAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -157,7 +157,6 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -225,7 +225,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -349,7 +348,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -309,7 +309,6 @@ class FlashMQAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -263,7 +263,6 @@ class Starcoder2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -31,10 +31,12 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
     FLASH_DECODING,
+    BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
     MODEL_ID,
 )
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -47,9 +49,6 @@ from text_generation_server.utils.import_utils import (
 
 tracer = trace.get_tracer(__name__)
 
-BLOCK_SIZE: int = (
-    256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
-)
 
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
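With the local BLOCK_SIZE constant deleted here and imported from text_generation_server.models.globals instead, the environment-driven definition presumably lives next to FLASH_DECODING in that module. A hedged sketch of what the centralized definition would look like, reconstructed from the lines removed above (globals.py itself is not shown in this diff, so treat this as an assumption):

import os

# Assumed contents of text_generation_server/models/globals.py after this commit;
# mirrors the expression deleted from flash_causal_lm.py above.
FLASH_DECODING = os.getenv("FLASH_DECODING", "").lower() in {"1", "true"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16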
@@ -927,6 +926,7 @@ class FlashCausalLM(Model):
             "slots": slots,
             "input_lengths": input_lengths,
         }
+        input_lengths = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
@@ -1086,6 +1086,7 @@ class FlashCausalLM(Model):
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        seqlen = Seqlen(input_lengths=input_lengths)
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1096,7 +1097,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=self.kv_cache,
             block_tables=None,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
@@ -1172,6 +1173,7 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = Seqlen(input_lengths=input_lengths)
            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
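Taken together, the batching code now wraps the raw lengths once at the boundary and every model receives the same object all the way down to paged_attention. A condensed sketch of that call pattern, assuming the text-generation-inference server package is importable (CPU tensor and values are illustrative only):

import torch
from text_generation_server.layers.attention import Seqlen

# Raw per-sequence lengths as FlashCausalLM already tracks them (illustrative values).
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# One wrapping call replaces the per-model cu_seqlen_q/cu_seqlen_k plumbing; the
# cumulative tensors are only populated when FLASH_DECODING is enabled, otherwise
# the object simply carries the raw lengths for the vLLM path.
input_lengths = Seqlen(input_lengths=input_lengths)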