Merge remote-tracking branch 'upstream/main' into rocm_6.2_updates

Mohit Sharma 2024-09-27 15:36:12 +00:00
commit 473d9a892d
15 changed files with 134 additions and 75 deletions

Cargo.lock (generated)

@@ -4243,6 +4243,7 @@ dependencies = [
 "hf-hub",
 "nix 0.28.0",
 "once_cell",
+"pyo3",
 "reqwest",
 "serde",
 "serde_json",


@@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 minijinja = { version = "2.2.0", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }

 [profile.release]
 incremental = true
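The new workspace-level pyo3 dependency enables the "auto-initialize" feature, which starts the embedded Python interpreter on the first `Python::with_gil` call instead of requiring an explicit `pyo3::prepare_freethreaded_python()`. A minimal standalone sketch of that pattern (illustrative only, not part of this diff; the `sys.version` lookup is just an example):

    use pyo3::prelude::*;

    fn python_version() -> PyResult<String> {
        // With "auto-initialize", with_gil boots the interpreter on demand.
        Python::with_gil(|py| {
            let sys = py.import_bound("sys")?;
            sys.getattr("version")?.extract()
        })
    }

    fn main() {
        match python_version() {
            Ok(v) => println!("embedded Python: {v}"),
            Err(err) => eprintln!("Python not available: {err}"),
        }
    }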


@@ -978,16 +978,16 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1726743157,
-      "narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
-      "owner": "danieldk",
-      "repo": "tgi-nix",
-      "rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
+      "lastModified": 1727353315,
+      "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
+      "owner": "huggingface",
+      "repo": "text-generation-inference-nix",
+      "rev": "1d42c4125ebafb87707118168995675cc5050b9d",
       "type": "github"
     },
     "original": {
-      "owner": "danieldk",
-      "repo": "tgi-nix",
+      "owner": "huggingface",
+      "repo": "text-generation-inference-nix",
       "type": "github"
     }
   }


@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:danieldk/tgi-nix";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -132,49 +132,12 @@
             pre-commit
             ruff
           ]);
         };

-        impure = mkShell {
-          buildInputs =
-            [
-              openssl.dev
-              pkg-config
-              (rust-bin.stable.latest.default.override {
-                extensions = [
-                  "rust-analyzer"
-                  "rust-src"
-                ];
-              })
-              protobuf
-            ]
-            ++ (with python3.pkgs; [
-              venvShellHook
-              docker
-              pip
-              ipdb
-              click
-              pyright
-              pytest
-              pytest-asyncio
-              redocly
-              ruff
-              syrupy
-            ]);
-
-          inputsFrom = [ server ];
-
-          venvDir = "./.venv";
-          postVenvCreation = ''
-            unset SOURCE_DATE_EPOCH
-            ( cd server ; python -m pip install --no-dependencies -e . )
-            ( cd clients/python ; python -m pip install --no-dependencies -e . )
-          '';
-          postShellHook = ''
-            unset SOURCE_DATE_EPOCH
-            export PATH=$PATH:~/.cargo/bin
-          '';
-        };
+        impure = callPackage ./nix/impure-shell.nix { inherit server; };
+
+        impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
+          server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
+        };
       };


@@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
 hf-hub = "0.3.2"
 nix = { version = "0.28.0", features = ["signal"] }
 once_cell = "1.19.0"
+pyo3 = { workspace = true }
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.107"
 thiserror = "1.0.59"

launcher/src/gpu.rs (new file)

@@ -0,0 +1,26 @@
use std::sync::LazyLock;

pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
    LazyLock::new(get_cuda_capability);

fn get_cuda_capability() -> Option<(usize, usize)> {
    use pyo3::prelude::*;

    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
        let torch = py.import_bound("torch.cuda")?;
        let get_device_capability = torch.getattr("get_device_capability")?;
        get_device_capability.call0()?.extract()
    };

    match pyo3::Python::with_gil(py_get_capability) {
        Ok((major, minor)) if major < 0 || minor < 0 => {
            tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
            None
        }
        Ok((major, minor)) => Some((major as usize, minor as usize)),
        Err(err) => {
            tracing::warn!("Cannot determine GPU compute capability: {}", err);
            None
        }
    }
}
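For orientation, a minimal sketch of how this lazily initialized static can be read elsewhere in the launcher (a hypothetical call site, not part of this diff; it mirrors the `resolve_attention` change below). The first dereference runs `get_cuda_capability`; on a machine without a usable CUDA-enabled torch it returns `None` after logging a warning:

    // Hypothetical call site inside the launcher crate (illustrative only;
    // `gpu` is the new module above and `tracing` is already a launcher dependency).
    fn log_gpu_generation() {
        // The first dereference runs get_cuda_capability(); later reads reuse the cached value.
        match *gpu::COMPUTE_CAPABILITY {
            Some((major, minor)) if major >= 8 => {
                tracing::info!("Ampere or newer GPU detected: compute capability {major}.{minor}")
            }
            Some((major, minor)) => {
                tracing::info!("Pre-Ampere GPU detected: compute capability {major}.{minor}")
            }
            None => tracing::warn!("No usable CUDA GPU detected"),
        }
    }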


@@ -26,6 +26,7 @@ use thiserror::Error;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter};

 mod env_runtime;
+mod gpu;

 fn get_config(
     model_id: &str,
@@ -65,6 +66,7 @@ fn get_config(
 }

 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
+    let compute_capability = *gpu::COMPUTE_CAPABILITY;
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
@@ -77,6 +79,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
                 prefix_caching = Some("0".to_string());
             }
         }
+
+        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+            "paged"
+        } else {
+            "flashdecoding"
+        };
+
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
@@ -89,10 +98,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
                     // flashinfer ?
                     if attention.is_none() {
                         tracing::info!(
-                            "Forcing flash decoding because model {} requires it",
+                            "Forcing attention to '{fallback_attention}' because model {} requires it",
                             config.model_type.as_ref().unwrap()
                         );
-                        attention = Some("flashdecoding".to_string());
+                        attention = Some(fallback_attention.to_string());
+                    }
+                    if fallback_attention == "paged" && prefix_caching.is_none() {
+                        tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
+                        prefix_caching = Some("0".to_string());
                     }
                 }
                 Some("t5") => {}
@@ -101,8 +114,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
             }
             _ => {
                 if attention.is_none() {
-                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    attention = Some("flashdecoding".to_string());
+                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
+                    attention = Some(fallback_attention.to_string());
                 }
                 if prefix_caching.is_none() {
                     prefix_caching = Some("0".to_string());
@@ -110,8 +123,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
             }
         }
     }
+    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
     let attention = attention.unwrap_or("flashinfer".to_string());
-    let prefix_caching = prefix_caching.unwrap_or("true".to_string());

     (prefix_caching, attention)
 }
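Taken together, the launcher change can be summarized in a standalone sketch (illustrative only, not the actual resolve_attention, which also consults the USE_PREFIX_CACHING/ATTENTION environment variables, head_dim and LoRA adapters): a compute capability below 8.0 makes "paged" the fallback attention and turns prefix caching off, anything else keeps "flashdecoding". The 8.0 threshold presumably tracks flash-attention v2's requirement for Ampere or newer GPUs.

    /// Simplified view of the fallback added in this commit: which attention backend
    /// to force, and whether prefix caching can stay enabled (illustrative only).
    fn fallback_attention(compute_capability: Option<(usize, usize)>) -> (&'static str, bool) {
        match compute_capability {
            // Pre-Ampere cards fall back to paged attention, which does not support prefix caching.
            Some((major, _)) if major < 8 => ("paged", false),
            // Unknown capability or >= 8.0: keep flashdecoding as the fallback.
            _ => ("flashdecoding", true),
        }
    }

    fn main() {
        assert_eq!(fallback_attention(Some((7, 5))), ("paged", false)); // e.g. T4 / Turing
        assert_eq!(fallback_attention(Some((8, 0))), ("flashdecoding", true)); // e.g. A100
        assert_eq!(fallback_attention(None), ("flashdecoding", true));
        println!("fallback selection matches the launcher change");
    }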

nix/impure-shell.nix (new file)

@ -0,0 +1,54 @@
{
mkShell,
openssl,
pkg-config,
protobuf,
python3,
pyright,
redocly,
ruff,
rust-bin,
server,
}:
mkShell {
buildInputs =
[
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
"rust-analyzer"
"rust-src"
];
})
protobuf
pyright
redocly
ruff
]
++ (with python3.pkgs; [
venvShellHook
docker
pip
ipdb
click
pytest
pytest-asyncio
syrupy
]);
inputsFrom = [ server ];
venvDir = "./.venv";
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin
'';
}
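Assuming these shells are exposed under the flake's devShells (as the surrounding flake.nix block suggests), the refactored impure environment would be entered with `nix develop .#impure`, and the flash-attention v1 variant with `nix develop .#impure-flash-attn-v1`.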


@@ -13,6 +13,7 @@
   flash-attn,
   flash-attn-layer-norm,
   flash-attn-rotary,
+  flash-attn-v1,
   grpc-interceptor,
   grpcio-reflection,
   grpcio-status,


@@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
-pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
+pyo3 = { workspace = true }

 [build-dependencies]


@@ -18,16 +18,16 @@ elif SYSTEM == "rocm":
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
         PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
         PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -37,7 +37,7 @@ __all__ = [
     "attention",
     "paged_attention",
     "reshape_and_cache",
-    "SUPPORTS_WINDOWING",
     "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
     "Seqlen",
 ]


@@ -287,16 +287,14 @@
 else:

     def attention(
-        q,
-        k,
-        v,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-        causal=None,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
+        softmax_scale: float,
+        window_size_left: int = -1,
+        causal: bool = True,
         softcap=None,
     ):
         if window_size_left != -1:
@@ -338,14 +336,14 @@
             k,
             v,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             0,
             None,


@@ -215,7 +215,6 @@ if ENGINE != "triton":
         "or install flash attention with `cd server && make install install-flash-attention`"
     ) from e
 else:
-
     for idx in range(torch.cuda.device_count()):
         name = torch.cuda.get_device_name(idx)
         if "MI210" not in name and "MI250" not in name:


@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,


@@ -27,8 +27,8 @@
 from torch import nn
 from transformers.activations import ACT2FN

-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,