mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Improve support for GPUs with capability < 8
- For models that cannot use flashinfer, use flash-attn v1 + paged attention for models with a compute capability older than 8. - Disable prefix caching when using paged attention. - When using flash-attn v1, pass the key/value, rather than the cache, since v1 cannot use block tables.
This commit is contained in:
parent
7efcb5e0ed
commit
bee5ee1f03
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -4243,6 +4243,7 @@ dependencies = [
|
||||
"hf-hub",
|
||||
"nix 0.28.0",
|
||||
"once_cell",
|
||||
"pyo3",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
minijinja = { version = "2.2.0", features = ["json"] }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
|
||||
|
||||
[profile.release]
|
||||
incremental = true
|
||||
|
@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
|
||||
hf-hub = "0.3.2"
|
||||
nix = { version = "0.28.0", features = ["signal"] }
|
||||
once_cell = "1.19.0"
|
||||
pyo3 = { workspace = true }
|
||||
serde = { version = "1.0.188", features = ["derive"] }
|
||||
serde_json = "1.0.107"
|
||||
thiserror = "1.0.59"
|
||||
|
22
launcher/src/gpu.rs
Normal file
22
launcher/src/gpu.rs
Normal file
@ -0,0 +1,22 @@
|
||||
use std::sync::LazyLock;
|
||||
|
||||
pub static COMPUTE_CAPABILITY: LazyLock<Option<(isize, isize)>> =
|
||||
LazyLock::new(get_cuda_capability);
|
||||
|
||||
fn get_cuda_capability() -> Option<(isize, isize)> {
|
||||
use pyo3::prelude::*;
|
||||
|
||||
let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
|
||||
let torch = py.import_bound("torch.cuda")?;
|
||||
let get_device_capability = torch.getattr("get_device_capability")?;
|
||||
get_device_capability.call0()?.extract()
|
||||
};
|
||||
|
||||
match pyo3::Python::with_gil(py_get_capability) {
|
||||
Ok(capability) => Some(capability),
|
||||
Err(err) => {
|
||||
tracing::warn!("Cannot determine GPU compute capability: {}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
@ -26,6 +26,7 @@ use thiserror::Error;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
||||
|
||||
mod env_runtime;
|
||||
mod gpu;
|
||||
|
||||
fn get_config(
|
||||
model_id: &str,
|
||||
@ -65,6 +66,7 @@ fn get_config(
|
||||
}
|
||||
|
||||
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||
let compute_capability = *gpu::COMPUTE_CAPABILITY;
|
||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||
if let Some(config) = config {
|
||||
@ -77,6 +79,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
|
||||
"paged"
|
||||
} else {
|
||||
"flashdecoding"
|
||||
};
|
||||
|
||||
match config.head_dim {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||
@ -89,10 +98,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
// flashinfer ?
|
||||
if attention.is_none() {
|
||||
tracing::info!(
|
||||
"Forcing flash decoding because model {} requires it",
|
||||
"Forcing attention to '{fallback_attention}' because model {} requires it",
|
||||
config.model_type.as_ref().unwrap()
|
||||
);
|
||||
attention = Some("flashdecoding".to_string());
|
||||
attention = Some(fallback_attention.to_string());
|
||||
}
|
||||
}
|
||||
Some("t5") => {}
|
||||
@ -101,8 +110,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
}
|
||||
_ => {
|
||||
if attention.is_none() {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
attention = Some("flashdecoding".to_string());
|
||||
tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
attention = Some(fallback_attention.to_string());
|
||||
}
|
||||
if prefix_caching.is_none() {
|
||||
prefix_caching = Some("0".to_string());
|
||||
@ -110,8 +119,17 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
}
|
||||
}
|
||||
}
|
||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||
let prefix_caching = if attention == "paged"
|
||||
&& prefix_caching.is_none()
|
||||
&& compute_capability.is_some()
|
||||
{
|
||||
tracing::info!("Disabling prefix caching because it is not supported with 'flashinfer'");
|
||||
"false".to_string()
|
||||
} else {
|
||||
prefix_caching.unwrap_or("true".to_string())
|
||||
};
|
||||
|
||||
(prefix_caching, attention)
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
|
||||
] }
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
|
||||
pyo3 = { workspace = true }
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -11,11 +11,24 @@ if SYSTEM == "cuda":
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
SUPPORTS_WINDOWING,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
elif SYSTEM == "rocm":
|
||||
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||
from .rocm import (
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
elif SYSTEM == "ipex":
|
||||
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||
from .ipex import (
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
||||
|
||||
@ -24,6 +37,7 @@ __all__ = [
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"Seqlen",
|
||||
]
|
||||
|
@ -287,16 +287,14 @@ elif V2:
|
||||
else:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=None,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap=None,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
@ -338,16 +336,22 @@ else:
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
causal,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
# Prefill in the cache with every kind of attention, unless we
|
||||
# have a configuration that requires flash-attention v1, which
|
||||
# does not support block tables.
|
||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
||||
|
@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
|
||||
def attention(
|
||||
|
@ -13,6 +13,9 @@ _PARTITION_SIZE = 512
|
||||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||
ENGINE = "triton" if use_triton else "ck"
|
||||
|
||||
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
try:
|
||||
from vllm._C import cache_ops
|
||||
except Exception as e:
|
||||
@ -156,7 +159,6 @@ if ENGINE != "triton":
|
||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||
) from e
|
||||
else:
|
||||
|
||||
for idx in range(torch.cuda.device_count()):
|
||||
name = torch.cuda.get_device_name(idx)
|
||||
if "MI210" not in name and "MI250" not in name:
|
||||
|
@ -39,6 +39,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else key,
|
||||
kv_cache[1] if SYSTEM != "ipex" else value,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
@ -327,8 +328,8 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else key,
|
||||
kv_cache[1] if SYSTEM != "ipex" else value,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -25,7 +25,6 @@ from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
@ -39,6 +38,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -25,12 +25,12 @@ from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -24,7 +24,7 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else key,
|
||||
kv_cache[1] if SYSTEM != "ipex" else value,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -38,6 +38,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else key,
|
||||
kv_cache[1] if SYSTEM != "ipex" else value,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -27,6 +27,7 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -41,6 +41,7 @@ from text_generation_server.layers import (
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -39,10 +39,10 @@ from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
|
||||
@ -267,8 +267,8 @@ class MixtralAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -26,7 +26,6 @@ from transformers.activations import ACT2FN
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
@ -40,6 +39,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
qkv[:, 0],
|
||||
kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
|
||||
kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -19,13 +19,13 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
if cu_seqlen_prefill is not None:
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -17,11 +17,11 @@ from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -5,7 +5,6 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
TensorParallelColumnLinear,
|
||||
@ -13,6 +12,7 @@ from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastLayerNorm
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.attention import (
|
||||
@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -18,11 +18,11 @@ from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
def load_multi_mqa(
|
||||
@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -39,6 +39,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
FastRMSNorm,
|
||||
@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
class Starcoder2Config(PretrainedConfig):
|
||||
@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
|
||||
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
Loading…
Reference in New Issue
Block a user