Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00

commit 91f55ea2b5 (parent be87c840b8)

    Removing flash decoding part so it gets merged.
@@ -70,17 +70,7 @@ impl Infer {
         tokenizer_config: HubTokenizerConfig,
         processor_config: HubProcessorConfig,
     ) -> Self {
-        // Infer shared state
-        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
-            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
-        } else {
-            false
-        };
-        let block_size = if flashdecoding { 256 } else { 16 };
-        let block_size = std::env::var("BLOCK_SIZE")
-            .map(|b| b.parse().unwrap_or(block_size))
-            .unwrap_or(block_size);
-        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+        let queue = Queue::new(requires_padding, 16, window_size, speculate);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -1,6 +1,5 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING

 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5

@@ -22,14 +21,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


 def paged_attention(
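
For reference, the FLASH_DECODING branch removed above bypassed the vllm cache_ops kernel and wrote keys and values into the cache with a plain indexed copy. A minimal pure-PyTorch sketch of that scatter, with toy shapes and a hypothetical helper name (illustration only, not TGI's API):

import torch

def scatter_kv_into_cache(key, value, key_cache, value_cache, slots):
    # Sketch of the removed path: the cache is viewed as one flat table of
    # (num_blocks * block_size) slots and token rows are copied in by slot index.
    shape = key_cache.shape  # assumed layout: (num_blocks, block_size, num_heads, head_size)
    key_cache.view(-1, shape[-2], shape[-1])[slots] = key
    value_cache.view(-1, shape[-2], shape[-1])[slots] = value

# Toy usage: 2 blocks of 4 slots each, 1 head of size 8, write 3 tokens.
key_cache = torch.zeros(2, 4, 1, 8)
value_cache = torch.zeros(2, 4, 1, 8)
key, value = torch.randn(3, 1, 8), torch.randn(3, 1, 8)
slots = torch.tensor([0, 1, 5])  # flat slot indices across all blocks
scatter_kv_into_cache(key, value, key_cache, value_cache, slots)
assert torch.equal(key_cache.view(-1, 1, 8)[5], key[2])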
@@ -40,8 +32,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py

@@ -65,45 +56,15 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k

     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
-        max_q = 1
-        max_k = max_s
-        import flash_attn_2_cuda
-
-        flash_attn_2_cuda.varlen_fwd(
-            query,
-            key_cache,
-            value_cache,
-            out,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            block_tables,
-            None,
-            max_q,
-            max_k,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            -1,
-            0,
-            False,
-            None,
-        )
-    else:
-        from vllm._C import ops
+    from vllm._C import ops

-        use_v1 = max_s <= 8192 and (
-            max_num_partitions == 1 or num_seqs * num_heads > 512
-        )
-        if use_v1:
-            ops.paged_attention_v1(
-                out,
+    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+    if use_v1:
+        ops.paged_attention_v1(
+            out,
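
The kept vLLM decode path still relies on the V1/V2 heuristic visible in the context lines above. As a rough standalone illustration (plain Python, hypothetical helper name, assuming the partition size of 512 that this module's constant uses):

_PARTITION_SIZE = 512  # assumed value of the module-level constant

def use_paged_attention_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    # V1 is preferred when everything fits in one partition, or when there are
    # already many (sequence, head) pairs to parallelize over; otherwise V2
    # with its extra reduction step is used.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

# One 300-token sequence with 32 heads fits in a single partition -> V1.
print(use_paged_attention_v1(max_s=300, num_seqs=1, num_heads=32))   # True
# A 4096-token sequence needs 8 partitions and 1 * 32 <= 512 -> V2.
print(use_paged_attention_v1(max_s=4096, num_seqs=1, num_heads=32))  # False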
@@ -1,7 +1,6 @@
 import os
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from loguru import logger

 major, minor = torch.cuda.get_device_capability()

@@ -28,14 +27,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


 def paged_attention(

@@ -46,8 +38,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -71,45 +62,15 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k

     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
-        max_q = 1
-        max_k = max_s
-        import flash_attn_2_cuda
-
-        flash_attn_2_cuda.varlen_fwd(
-            query,
-            key_cache,
-            value_cache,
-            out,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            block_tables,
-            None,
-            max_q,
-            max_k,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            -1,
-            0,
-            False,
-            None,
-        )
-    else:
-        from vllm._C import ops
+    from vllm._C import ops

-        use_v1 = max_s <= 8192 and (
-            max_num_partitions == 1 or num_seqs * num_heads > 512
-        )
-        if use_v1:
-            ops.paged_attention_v1(
-                out,
+    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+    if use_v1:
+        ops.paged_attention_v1(
+            out,
@@ -59,8 +59,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: torch.Tensor,
     max_s: int,
 ):
     query = query.contiguous()

@@ -73,7 +72,7 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        cu_seqlen_q,
+        input_lengths,
         block_size,
         max_s,
         None,
@@ -3,9 +3,8 @@ import torch

 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING

-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+BLOCK_SIZE: int = 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None

@@ -31,23 +30,6 @@ class CacheManager:
         else:
             x = self.block_size // element_size

-        if FLASH_DECODING:
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        else:
-            self.kv_cache = [
-                (
-                    torch.empty(
+        self.kv_cache = [
+            (
+                torch.empty(
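
The CacheManager hunk above also drops the flash-decoding cache layout, leaving only the paged layout the remaining kernels use. A toy comparison of the two allocations (sizes picked for illustration, only the key cache shown; the kept shape is assumed from the vLLM-style cache, not visible in this hunk, with x = block_size // element_size as in the surrounding code):

import torch

num_blocks, block_size, num_heads, head_size = 8, 16, 4, 64
dtype = torch.float16
x = block_size // torch.tensor([], dtype=dtype).element_size()  # 16 // 2 = 8

# Removed FLASH_DECODING layout: one contiguous row per cache slot.
flash_decoding_key = torch.empty(num_blocks, block_size, num_heads, head_size, dtype=dtype)

# Assumed paged layout kept by this commit (vLLM-style key cache): head_size is
# split by x so the kernel reads fixed-size chunks per block.
paged_key = torch.empty(num_blocks, num_heads, head_size // x, block_size, x, dtype=dtype)

print(flash_decoding_key.shape)  # torch.Size([8, 16, 4, 64])
print(paged_key.shape)           # torch.Size([8, 4, 8, 16, 8])

Both hold the same number of elements; the difference is purely how a block is laid out in memory.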
@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,

@@ -260,9 +259,8 @@ class FlashCohereAttention(torch.nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
         slots,
+        input_lengths,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)

@@ -314,8 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )

@@ -389,9 +386,8 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
         slots,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -404,9 +400,8 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            cu_seqlen_q,
-            cu_seqlen_k,
             slots,
+            input_lengths,
             max_s,
         )

@@ -469,24 +464,6 @@ class FlashCohereModel(torch.nn.Module):
         )

         residual = None
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=input_ids.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths
-
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
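
For context, the block removed from FlashCohereModel above (and the nearly identical one removed from FlashLlamaModel further down) only derived cumulative sequence-length tensors from input_lengths for the decode path. A small self-contained sketch of that derivation, with a worked example:

import torch

def decode_cu_seqlens(input_lengths: torch.Tensor):
    # Sketch of the removed logic: during decode every sequence contributes
    # exactly one query token, so cu_seqlen_q is simply 0..batch_size, while
    # cu_seqlen_k is the running total of the cached key lengths.
    cu_seqlen_q = torch.arange(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    cu_seqlen_k = torch.cat(
        [
            torch.zeros((1,), device=input_lengths.device, dtype=input_lengths.dtype),
            input_lengths.cumsum(dim=-1),
        ]
    ).to(dtype=torch.int32)
    return cu_seqlen_q, cu_seqlen_k

# Example: three sequences with 3, 5 and 2 cached tokens.
q, k = decode_cu_seqlens(torch.tensor([3, 5, 2]))
print(q.tolist())  # [0, 1, 2, 3]
print(k.tolist())  # [0, 3, 8, 10]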
@@ -496,9 +473,8 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
                 slots,
+                input_lengths,
                 max_s,
             )

@@ -455,7 +455,6 @@ class DbrxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -253,7 +253,6 @@ class FlashGemmaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -244,7 +244,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,

@@ -134,8 +133,7 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)

@@ -178,8 +176,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )

@@ -280,8 +277,7 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -295,8 +291,7 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            input_lengths,
             max_s,
         )

@@ -363,23 +358,6 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=inputs_embeds.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths

         residual = None
         for i, layer in enumerate(self.layers):

@@ -392,8 +370,7 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
@@ -28,8 +28,8 @@ from typing import Optional, List, Tuple

 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
-    attention,
     paged_attention,
+    attention,
     reshape_and_cache,
 )
 from text_generation_server.layers import (

@@ -220,7 +220,6 @@ class MistralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -299,7 +299,6 @@ class MixtralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -176,7 +176,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -215,7 +215,6 @@ class FlashPhiAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -176,7 +176,6 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -198,7 +198,9 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        paged_attention.reshape_and_cache(
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        )

         # output
         attn_output = torch.empty_like(query)

@@ -206,7 +208,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),

@@ -217,7 +219,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -225,7 +227,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -349,7 +350,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -309,7 +309,6 @@ class FlashMQAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            None,
             input_lengths,
             max_s,
         )

@@ -263,7 +263,6 @@ class Starcoder2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )

@@ -5,9 +5,6 @@ from loguru import logger
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
-if FLASH_DECODING:
-    logger.info("Using FLASH_DECODING")
 if cuda_graphs is not None:
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]