diff --git a/Cargo.lock b/Cargo.lock
index e535004e..6796212f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4243,6 +4243,7 @@ dependencies = [
  "hf-hub",
  "nix 0.28.0",
  "once_cell",
+ "pyo3",
  "reqwest",
  "serde",
  "serde_json",
diff --git a/Cargo.toml b/Cargo.toml
index 032dc857..a783fadb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 minijinja = { version = "2.2.0", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
 
 [profile.release]
 incremental = true
diff --git a/flake.lock b/flake.lock
index d811be5e..14e23b77 100644
--- a/flake.lock
+++ b/flake.lock
@@ -978,16 +978,16 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1726743157,
-        "narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
-        "owner": "danieldk",
-        "repo": "tgi-nix",
-        "rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
+        "lastModified": 1727353315,
+        "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
+        "owner": "huggingface",
+        "repo": "text-generation-inference-nix",
+        "rev": "1d42c4125ebafb87707118168995675cc5050b9d",
         "type": "github"
       },
       "original": {
-        "owner": "danieldk",
-        "repo": "tgi-nix",
+        "owner": "huggingface",
+        "repo": "text-generation-inference-nix",
         "type": "github"
       }
     }
diff --git a/flake.nix b/flake.nix
index 260b2554..1b396453 100644
--- a/flake.nix
+++ b/flake.nix
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:danieldk/tgi-nix";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -132,49 +132,12 @@
             pre-commit
             ruff
           ]);
       };
-      impure = mkShell {
-        buildInputs =
-          [
-            openssl.dev
-            pkg-config
-            (rust-bin.stable.latest.default.override {
-              extensions = [
-                "rust-analyzer"
-                "rust-src"
-              ];
-            })
-            protobuf
-          ]
-          ++ (with python3.pkgs; [
-            venvShellHook
-            docker
-            pip
-            ipdb
-            click
-            pyright
-            pytest
-            pytest-asyncio
-            redocly
-            ruff
-            syrupy
-          ]);
+      impure = callPackage ./nix/impure-shell.nix { inherit server; };
 
-        inputsFrom = [ server ];
-
-        venvDir = "./.venv";
-
-        postVenvCreation = ''
-          unset SOURCE_DATE_EPOCH
-          ( cd server ; python -m pip install --no-dependencies -e . )
-          ( cd clients/python ; python -m pip install --no-dependencies -e . )
-        '';
-        postShellHook = ''
-          unset SOURCE_DATE_EPOCH
-          export PATH=$PATH:~/.cargo/bin
-        '';
+      impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
+        server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
       };
     };
diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml
index eb219423..033a9a04 100644
--- a/launcher/Cargo.toml
+++ b/launcher/Cargo.toml
@@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
 hf-hub = "0.3.2"
 nix = { version = "0.28.0", features = ["signal"] }
 once_cell = "1.19.0"
+pyo3 = { workspace = true }
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.107"
 thiserror = "1.0.59"
diff --git a/launcher/src/gpu.rs b/launcher/src/gpu.rs
new file mode 100644
index 00000000..755d246a
--- /dev/null
+++ b/launcher/src/gpu.rs
@@ -0,0 +1,26 @@
+use std::sync::LazyLock;
+
+pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
+    LazyLock::new(get_cuda_capability);
+
+fn get_cuda_capability() -> Option<(usize, usize)> {
+    use pyo3::prelude::*;
+
+    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
+        let torch = py.import_bound("torch.cuda")?;
+        let get_device_capability = torch.getattr("get_device_capability")?;
+        get_device_capability.call0()?.extract()
+    };
+
+    match pyo3::Python::with_gil(py_get_capability) {
+        Ok((major, minor)) if major < 0 || minor < 0 => {
+            tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
+            None
+        }
+        Ok((major, minor)) => Some((major as usize, minor as usize)),
+        Err(err) => {
+            tracing::warn!("Cannot determine GPU compute capability: {}", err);
+            None
+        }
+    }
+}
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index deb18478..583220a6 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -26,6 +26,7 @@ use thiserror::Error;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 
 mod env_runtime;
+mod gpu;
 
 fn get_config(
     model_id: &str,
@@ -65,6 +66,7 @@
 }
 
 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
+    let compute_capability = *gpu::COMPUTE_CAPABILITY;
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
@@ -77,6 +79,13 @@
                 prefix_caching = Some("0".to_string());
             }
         }
+
+        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+            "paged"
+        } else {
+            "flashdecoding"
+        };
+
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
@@ -89,10 +98,14 @@
                         // flashinfer ?
                         if attention.is_none() {
                             tracing::info!(
-                                "Forcing flash decoding because model {} requires it",
+                                "Forcing attention to '{fallback_attention}' because model {} requires it",
                                 config.model_type.as_ref().unwrap()
                             );
-                            attention = Some("flashdecoding".to_string());
+                            attention = Some(fallback_attention.to_string());
+                        }
+                        if fallback_attention == "paged" && prefix_caching.is_none() {
+                            tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
+                            prefix_caching = Some("0".to_string());
                         }
                     }
                     Some("t5") => {}
@@ -101,8 +114,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
             _ => {
                 if attention.is_none() {
-                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    attention = Some("flashdecoding".to_string());
+                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
+                    attention = Some(fallback_attention.to_string());
                 }
                 if prefix_caching.is_none() {
                     prefix_caching = Some("0".to_string());
@@ -110,8 +123,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
         }
     }
-    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
+    let attention = attention.unwrap_or("flashinfer".to_string());
+    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
+
     (prefix_caching, attention)
 }
diff --git a/nix/impure-shell.nix b/nix/impure-shell.nix
new file mode 100644
index 00000000..a4dad4ba
--- /dev/null
+++ b/nix/impure-shell.nix
@@ -0,0 +1,54 @@
+{
+  mkShell,
+  openssl,
+  pkg-config,
+  protobuf,
+  python3,
+  pyright,
+  redocly,
+  ruff,
+  rust-bin,
+  server,
+}:
+
+mkShell {
+  buildInputs =
+    [
+      openssl.dev
+      pkg-config
+      (rust-bin.stable.latest.default.override {
+        extensions = [
+          "rust-analyzer"
+          "rust-src"
+        ];
+      })
+      protobuf
+      pyright
+      redocly
+      ruff
+    ]
+    ++ (with python3.pkgs; [
+      venvShellHook
+      docker
+      pip
+      ipdb
+      click
+      pytest
+      pytest-asyncio
+      syrupy
+    ]);
+
+  inputsFrom = [ server ];
+
+  venvDir = "./.venv";
+
+  postVenvCreation = ''
+    unset SOURCE_DATE_EPOCH
+    ( cd server ; python -m pip install --no-dependencies -e . )
+    ( cd clients/python ; python -m pip install --no-dependencies -e . )
+  '';
+  postShellHook = ''
+    unset SOURCE_DATE_EPOCH
+    export PATH=$PATH:~/.cargo/bin
+  '';
+}
diff --git a/nix/server.nix b/nix/server.nix
index 5921da7f..7406d563 100644
--- a/nix/server.nix
+++ b/nix/server.nix
@@ -13,6 +13,7 @@
   flash-attn,
   flash-attn-layer-norm,
   flash-attn-rotary,
+  flash-attn-v1,
   grpc-interceptor,
   grpcio-reflection,
   grpcio-status,
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 6a752db6..83d85327 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
-pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
+pyo3 = { workspace = true }
 
 [build-dependencies]
diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py
index a2f97700..4f2b9807 100644
--- a/server/text_generation_server/layers/attention/__init__.py
+++ b/server/text_generation_server/layers/attention/__init__.py
@@ -18,16 +18,16 @@ elif SYSTEM == "rocm":
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
         PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
         PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -37,7 +37,7 @@ __all__ = [
     "attention",
     "paged_attention",
     "reshape_and_cache",
-    "SUPPORTS_WINDOWING",
     "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
     "Seqlen",
 ]
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 6c645770..51af928d 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -287,16 +287,14 @@ elif V2:
 else:
 
     def attention(
-        q,
-        k,
-        v,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-        causal=None,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
+        softmax_scale: float,
+        window_size_left: int = -1,
+        causal: bool = True,
         softcap=None,
     ):
         if window_size_left != -1:
@@ -338,14 +336,14 @@ else:
             k,
             v,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             0,
             None,
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index de7d673f..646a763d 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -215,7 +215,6 @@ if ENGINE != "triton":
             "or install flash attention with `cd server && make install install-flash-attention`"
         ) from e
 else:
-
     for idx in range(torch.cuda.device_count()):
         name = torch.cuda.get_device_name(idx)
         if "MI210" not in name and "MI250" not in name:
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index d47bb104..44c015cf 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index a77ec234..75e43d88 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -27,8 +27,8 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
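
For reviewers who want to sanity-check the capability-based selection outside the launcher, here is a minimal standalone Rust sketch of the rule this diff introduces: when the model type or head dim rules out flashinfer, devices with compute capability below 8.0 fall back to "paged" attention (with prefix caching disabled), while newer devices fall back to "flashdecoding"; "flashinfer" stays the default otherwise. This is illustrative only and not code from the PR: the `fallback_attention` and `resolve` helpers and the hard-coded capabilities are made up for the example, and the real launcher obtains the capability through `gpu::COMPUTE_CAPABILITY` via pyo3/torch rather than taking it as a parameter.

// Illustrative sketch, not launcher code: mirrors the fallback rule added to resolve_attention().
fn fallback_attention(compute_capability: Option<(usize, usize)>) -> &'static str {
    // Pre-Ampere GPUs (compute capability < 8.0) cannot use flashdecoding.
    match compute_capability {
        Some((major, _)) if major < 8 => "paged",
        _ => "flashdecoding",
    }
}

// Hypothetical helper combining the backend choice with the prefix-caching decision.
fn resolve(compute_capability: Option<(usize, usize)>, needs_fallback: bool) -> (String, String) {
    let attention = if needs_fallback {
        fallback_attention(compute_capability).to_string()
    } else {
        // Default when flashinfer supports the model.
        "flashinfer".to_string()
    };
    // Prefix caching is not supported with 'paged' attention, so it is turned off there.
    let prefix_caching = if attention == "paged" { "0" } else { "true" }.to_string();
    (prefix_caching, attention)
}

fn main() {
    // A T4 (sm 7.5) running a model that cannot use flashinfer:
    assert_eq!(resolve(Some((7, 5)), true), ("0".to_string(), "paged".to_string()));
    // An A100 (sm 8.0) in the same situation keeps flashdecoding and prefix caching:
    assert_eq!(resolve(Some((8, 0)), true), ("true".to_string(), "flashdecoding".to_string()));
    // Unknown capability (e.g. the pyo3 probe failed) behaves like a recent GPU:
    assert_eq!(resolve(None, true), ("true".to_string(), "flashdecoding".to_string()));
}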