Merge branch 'huggingface:main' into vb/followup-doc-fixes

Vaibhav Srivastav 2024-08-09 14:08:37 +02:00 committed by GitHub
commit 27daf69ea8
7 changed files with 518 additions and 44 deletions

flake.lock (new file, 99 lines)

@@ -0,0 +1,99 @@
{
"nodes": {
"flake-compat": {
"locked": {
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1723099294,
"narHash": "sha256-kkijy6sXo/SOhFw7ZEfYHbj1FJHxoeetOVOn3qNHc5o=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "45892b6ec142eaf300d777926983a433b5842c88",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "cudnn-9.3",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": [
"tgi-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1723188417,
"narHash": "sha256-GXdFuRMU9N9W0CryUTJIWhJkGwQFHSR2EW5xR0ZyBjk=",
"owner": "danieldk",
"repo": "tgi-nix",
"rev": "491db7e234ecf79513ddb94da6ecc14167b9c0b3",
"type": "github"
},
"original": {
"owner": "danieldk",
"repo": "tgi-nix",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

flake.nix (new file, 73 lines)

@@ -0,0 +1,73 @@
{
inputs = {
tgi-nix.url = "github:danieldk/tgi-nix";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
{
self,
nixpkgs,
flake-utils,
tgi-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
config = {
allowUnfree = true;
cudaSupport = true;
};
pkgs = import nixpkgs {
inherit config system;
overlays = [ tgi-nix.overlay ];
};
in
{
devShells.default =
with pkgs;
mkShell {
buildInputs =
[
cargo
openssl.dev
pkg-config
]
++ (with python3.pkgs; [
venvShellHook
pip
einops
fbgemm-gpu
flash-attn
flash-attn-layer-norm
flash-attn-rotary
grpc-interceptor
grpcio-reflection
grpcio-status
hf-transfer
loguru
marlin-kernels
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc
opentelemetry-semantic-conventions
peft
tokenizers
torch
transformers
vllm
]);
venvDir = "./.venv";
postVenv = ''
unset SOURCE_DATE_EPOCH
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
'';
};
}
);
}

Changed file: Seqlen definition (text_generation_server.layers.attention)

@@ -1,10 +1,10 @@
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER
import torch
from typing import Optional
if FLASH_DECODING:
if FLASH_DECODING or FLASH_INFER:
@dataclass
class Seqlen:

Changed file: CUDA attention kernels (reshape_and_cache, paged_attention)

@@ -1,6 +1,10 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.models.globals import (
FLASH_DECODING,
BLOCK_SIZE,
FLASH_INFER,
)
from text_generation_server.layers.attention import Seqlen
from typing import Optional
@@ -23,7 +27,7 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if FLASH_DECODING:
if FLASH_DECODING or FLASH_INFER:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
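
The FLASH_DECODING / FLASH_INFER branch above writes new keys and values through a flattened view of the paged KV cache, so a global slot index addresses a token position directly. A minimal sketch of that indexing, assuming an illustrative [num_blocks, block_size, num_kv_heads, head_size] layout (the real shapes come from the model and cache configuration):

```python
import torch

# Toy cache: 4 blocks of 16 slots, 2 KV heads, head size 8 (illustrative sizes).
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Two new tokens whose keys should land in global slots 5 and 33.
key = torch.randn(2, num_kv_heads, head_size)
slots = torch.tensor([5, 33])

# Viewing the cache as [num_blocks * block_size, num_kv_heads, head_size]
# lets the flat slot index address any (block, offset) position directly.
key_cache.view(-1, num_kv_heads, head_size)[slots] = key

assert torch.equal(key_cache[0, 5], key[0])   # slot 5  -> block 0, offset 5
assert torch.equal(key_cache[2, 1], key[1])   # slot 33 -> block 2, offset 1
```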
@@ -72,7 +76,16 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if FLASH_DECODING:
if FLASH_INFER:
from text_generation_server.layers.attention.flash_infer import decode_state
return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
elif FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda
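
Note that the FLASH_INFER branch does not receive the flashinfer wrapper as an argument: decode_state is a ContextVar that the model installs around each forward pass (see the flash_infer module below) and that paged_attention reads back with .get(). A minimal, self-contained sketch of that plumbing, with a dummy state standing in for the real wrapper:

```python
from contextlib import contextmanager
from contextvars import ContextVar

decode_state = ContextVar("decode_state")

class DummyState:
    """Stand-in for the flashinfer decode wrapper."""
    def forward(self, query):
        return f"decode({query})"

@contextmanager
def use_decode_state(state):
    # Install the state for the duration of one forward pass.
    token = decode_state.set(state)
    try:
        yield
    finally:
        decode_state.reset(token)

def paged_attention(query):
    # Mirrors `decode_state.get().forward(...)` in the hunk above.
    return decode_state.get().forward(query)

with use_decode_state(DummyState()):
    print(paged_attention("q"))  # -> decode(q)
```

The per-batch flashinfer state therefore travels through the context variable rather than through every attention call signature.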
@@ -206,7 +219,32 @@ except ImportError:
SUPPORTS_WINDOWING = V2
if V2:
if FLASH_INFER:
def attention(
q,
k,
v,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
from text_generation_server.layers.attention.flash_infer import prefill_state
return prefill_state.get().forward(
q,
k,
v,
causal=causal,
window_left=window_size_left,
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
elif V2:
def attention(
q,

New file: text_generation_server/layers/attention/flash_infer.py (164 lines)

@@ -0,0 +1,164 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager
import flashinfer
import torch
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
"prefill_state"
)
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
workspace: Optional[torch.Tensor] = None
def get_workspace(device):
"""Get shared flashinfer workspace."""
global workspace
if workspace is None:
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
return workspace
def create_prefill_state(
*,
device: torch.device,
):
"""Create a prefill state."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_state(
*,
state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
cu_seqlens: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
query_dtype: str = "float16",
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
token = prefill_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
kv_indptr=cu_seqlens,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_state.reset(token)
def create_decode_state(
*,
device: torch.device,
num_heads: int,
num_kv_heads: int,
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 4,
)
def create_decode_state_cuda_graphs(
*,
device: torch.device,
block_tables: torch.Tensor,
block_tables_ptr: torch.Tensor,
last_page_len: torch.Tensor,
num_heads: int,
num_kv_heads: int,
):
"""
Create a decode state for use with CUDA Graphs. `block_tables`,
`block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=True,
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 4,
)
@contextmanager
def use_decode_state(
*,
state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
input_lengths: torch.Tensor,
block_tables: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
):
"""
Context manager to set the active flashinfer decoding state to the given
`state` and parameters. This state will be used by all calls to the
`paged_attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = decode_state.set(state)
try:
state.begin_forward(
indptr=indptr,
indices=block_tables,
last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
q_data_type=query_dtype,
)
yield
finally:
state.end_forward()
if token is not None:
decode_state.reset(token)
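
The paged-KV metadata computed inside use_decode_state is easier to follow with concrete numbers. A worked example of the same arithmetic, assuming page_size = 16 and three sequences of 1, 16, and 33 tokens:

```python
import torch

page_size = 16
input_lengths = torch.tensor([1, 16, 33], dtype=torch.int32)

# Pages per sequence, rounded up, then cumulated into block-table offsets.
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
print(indptr)          # [0, 1, 2, 5]: 1, 1, and 3 pages per sequence

# Number of tokens occupying the last page of each sequence.
last_page_len = torch.empty_like(input_lengths)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
print(last_page_len)   # [1, 16, 1]
```

These are the indptr and last_page_len values handed to begin_forward above.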

Changed file: FlashCausalLM model

@@ -1,3 +1,4 @@
from contextlib import nullcontext
import math
import os
import time
@@ -15,7 +16,7 @@ from transformers import (
AutoTokenizer,
GenerationConfig,
)
from typing import Iterable, Optional, Tuple, List, Type, Dict
from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
MEM_POOL,
FLASH_DECODING,
FLASH_INFER,
BLOCK_SIZE,
CUDA_GRAPHS,
get_adapter_to_index,
@@ -907,6 +909,7 @@ class FlashCausalLM(Model):
config.sliding_window = None
self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -935,6 +938,21 @@ class FlashCausalLM(Model):
self.cuda_graphs = {}
self.kv_cache = []
if FLASH_INFER:
from text_generation_server.layers.attention.flash_infer import (
create_prefill_state,
create_decode_state,
)
self.prefill_state = create_prefill_state(device=device)
if not CUDA_GRAPHS:
self.decode_state = create_decode_state(
device=device,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
super().__init__(
model_id=model_id,
model=model,
@@ -972,7 +990,7 @@ class FlashCausalLM(Model):
else:
x = BLOCK_SIZE // element_size
if FLASH_DECODING:
if FLASH_DECODING or FLASH_INFER:
self.kv_cache = [
(
torch.empty(
@@ -1044,8 +1062,35 @@
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
if FLASH_INFER:
from text_generation_server.layers.attention.flash_infer import (
create_decode_state_cuda_graphs,
)
block_tables_ptr = torch.zeros(
bs + 1, dtype=torch.int32, device=self.device
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
block_tables=block_tables.view(-1),
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
self.cuda_graphs[bs]["state"] = state
else:
state = None
torch.cuda.synchronize()
# Run once outside to warmup
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths=input_lengths,
state=state,
):
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -1058,6 +1103,7 @@
prefill_cache_indices=None,
lm_head_indices=None,
)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
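
This hunk follows the standard CUDA Graphs recipe: run the forward once outside the graph to warm up, synchronize, capture it into a torch.cuda.CUDAGraph, and later replay it after refreshing pre-allocated input buffers in place. A generic, stand-alone sketch of that pattern (requires a CUDA device; the buffers and the toy kernel are illustrative, not TGI internals):

```python
import torch

assert torch.cuda.is_available()
static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

def step():
    # Stand-in for model.forward: reads and writes fixed, pre-allocated buffers.
    static_out.copy_(static_in * 2 + 1)

step()                   # warm up once outside the graph
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    step()               # capture records the kernels against the static buffers

static_in.fill_(3.0)     # update inputs in place ...
graph.replay()           # ... and replay the captured kernels
print(static_out)        # all elements are 7.0
```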
@@ -1295,6 +1341,11 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=input_lengths,
):
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@@ -1325,8 +1376,16 @@
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
state = cuda_graph.get("state")
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths=input_lengths,
state=state,
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
@@ -1698,3 +1757,39 @@ class FlashCausalLM(Model):
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
def _forward_context(
self,
*,
block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths: torch.Tensor,
state: Optional[Any] = None,
) -> ContextManager:
if not FLASH_INFER:
return nullcontext()
from text_generation_server.layers.attention.flash_infer import (
use_decode_state,
use_prefill_state,
)
if cu_seqlen_prefill is not None:
return use_prefill_state(
state=state if state is not None else self.prefill_state,
cu_seqlens=cu_seqlen_prefill,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
)
else:
assert input_lengths is not None
return use_decode_state(
state=state if state is not None else self.decode_state,
input_lengths=input_lengths,
block_tables=block_tables.view(-1),
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
)
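
_forward_context always hands its caller a context manager: a flashinfer prefill or decode state when FLASH_INFER is set, and a plain nullcontext() otherwise, so the with statements at the call sites need no branching. A small sketch of that pattern with a stand-in state (the flag and state here are illustrative):

```python
from contextlib import contextmanager, nullcontext

FLASH_INFER = False  # toggle to see both behaviours

@contextmanager
def _fake_state():
    print("state installed")
    yield
    print("state released")

def forward_context():
    # Same shape as _forward_context: either a real state or a no-op.
    return _fake_state() if FLASH_INFER else nullcontext()

with forward_context():
    print("running forward")  # identical call site whether the flag is on or off
```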

Changed file: text_generation_server/models/globals.py

@@ -5,6 +5,10 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
if FLASH_INFER:
log_master(logger.info, "Using FLASH_INFER")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
@@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
log_master(logger.info, "Using FLASH_DECODING")
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
try: