Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)

Merge branch 'huggingface:main' into vb/followup-doc-fixes

This commit is contained in: commit 27daf69ea8
flake.lock (new file, 99 lines)
@@ -0,0 +1,99 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1696426674,
+        "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1710146030,
+        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1723099294,
+        "narHash": "sha256-kkijy6sXo/SOhFw7ZEfYHbj1FJHxoeetOVOn3qNHc5o=",
+        "owner": "danieldk",
+        "repo": "nixpkgs",
+        "rev": "45892b6ec142eaf300d777926983a433b5842c88",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "ref": "cudnn-9.3",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "flake-utils": "flake-utils",
+        "nixpkgs": [
+          "tgi-nix",
+          "nixpkgs"
+        ],
+        "tgi-nix": "tgi-nix"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    },
+    "tgi-nix": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1723188417,
+        "narHash": "sha256-GXdFuRMU9N9W0CryUTJIWhJkGwQFHSR2EW5xR0ZyBjk=",
+        "owner": "danieldk",
+        "repo": "tgi-nix",
+        "rev": "491db7e234ecf79513ddb94da6ecc14167b9c0b3",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "repo": "tgi-nix",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
flake.nix (new file, 73 lines)
@@ -0,0 +1,73 @@
+{
+  inputs = {
+    tgi-nix.url = "github:danieldk/tgi-nix";
+    nixpkgs.follows = "tgi-nix/nixpkgs";
+    flake-utils.url = "github:numtide/flake-utils";
+  };
+  outputs =
+    {
+      self,
+      nixpkgs,
+      flake-utils,
+      tgi-nix,
+    }:
+    flake-utils.lib.eachDefaultSystem (
+      system:
+      let
+        config = {
+          allowUnfree = true;
+          cudaSupport = true;
+        };
+        pkgs = import nixpkgs {
+          inherit config system;
+          overlays = [ tgi-nix.overlay ];
+        };
+      in
+      {
+        devShells.default =
+          with pkgs;
+          mkShell {
+            buildInputs =
+              [
+                cargo
+                openssl.dev
+                pkg-config
+              ]
+              ++ (with python3.pkgs; [
+                venvShellHook
+                pip
+
+                einops
+                fbgemm-gpu
+                flash-attn
+                flash-attn-layer-norm
+                flash-attn-rotary
+                grpc-interceptor
+                grpcio-reflection
+                grpcio-status
+                hf-transfer
+                loguru
+                marlin-kernels
+                opentelemetry-api
+                opentelemetry-exporter-otlp
+                opentelemetry-instrumentation-grpc
+                opentelemetry-semantic-conventions
+                peft
+                tokenizers
+                torch
+                transformers
+                vllm
+              ]);
+
+            venvDir = "./.venv";
+
+            postVenv = ''
+              unset SOURCE_DATE_EPOCH
+            '';
+            postShellHook = ''
+              unset SOURCE_DATE_EPOCH
+            '';
+          };
+      }
+    );
+}
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER
 import torch
 from typing import Optional


-if FLASH_DECODING:
+if FLASH_DECODING or FLASH_INFER:

     @dataclass
     class Seqlen:
@@ -1,6 +1,10 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.models.globals import (
+    FLASH_DECODING,
+    BLOCK_SIZE,
+    FLASH_INFER,
+)
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional

@@ -23,7 +27,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
+    if FLASH_DECODING or FLASH_INFER:
         shape = key_cache.shape
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
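The flattened cache write in this hunk is easier to see on a toy tensor. The sketch below is illustrative only (the shapes and values are made up, not taken from the diff): viewing the cache as (num_blocks * block_size, num_kv_heads, head_size) lets a flat slot index, slot = block * block_size + offset, address a token directly, which is the layout the FLASH_DECODING and FLASH_INFER branch relies on.

# Illustrative sketch of the flattened KV-cache write; hypothetical shapes.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
key = torch.randn(3, num_kv_heads, head_size)  # three new tokens
slots = torch.tensor([0, 1, 17])  # flat slot = block * block_size + offset

shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key

# Slot 17 lands in block 1, offset 1.
assert torch.equal(key_cache[1, 1], key[2])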
@@ -72,7 +76,16 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
+    if FLASH_INFER:
+        from text_generation_server.layers.attention.flash_infer import decode_state
+
+        return decode_state.get().forward(
+            query.contiguous(),
+            paged_kv_cache=(key_cache, value_cache),
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+    elif FLASH_DECODING:
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda
@@ -206,7 +219,32 @@ except ImportError:

 SUPPORTS_WINDOWING = V2

-if V2:
+if FLASH_INFER:

+    def attention(
+        q,
+        k,
+        v,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+        causal=True,
+        softcap=0.0,
+    ):
+        from text_generation_server.layers.attention.flash_infer import prefill_state
+
+        return prefill_state.get().forward(
+            q,
+            k,
+            v,
+            causal=causal,
+            window_left=window_size_left,
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+
+elif V2:
+
     def attention(
         q,
server/text_generation_server/layers/attention/flash_infer.py (new file, 164 lines)
@@ -0,0 +1,164 @@
+from typing import Optional
+from contextvars import ContextVar
+from contextlib import contextmanager
+
+import flashinfer
+import torch
+
+prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
+    "prefill_state"
+)
+
+decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
+    "decode_state"
+)
+
+workspace: Optional[torch.Tensor] = None
+
+
+def get_workspace(device):
+    """Get shared flashinfer workspace."""
+    global workspace
+    if workspace is None:
+        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
+    return workspace
+
+
+def create_prefill_state(
+    *,
+    device: torch.device,
+):
+    """Create a prefill state."""
+    workspace_buffer = get_workspace(device)
+    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
+    )
+
+
+@contextmanager
+def use_prefill_state(
+    *,
+    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
+    cu_seqlens: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    query_dtype: str = "float16",
+):
+    """
+    Context manager to set the active flashinfer prefill state to the given
+    `state` and parameters. This state will be used by all calls to the
+    `attention` function while the context manager is active.
+    """
+
+    token = prefill_state.set(state)
+    try:
+        state.begin_forward(
+            qo_indptr=cu_seqlens,
+            kv_indptr=cu_seqlens,
+            num_qo_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_size,
+            q_data_type=query_dtype,
+        )
+        yield
+    finally:
+        state.end_forward()
+        if token is not None:
+            prefill_state.reset(token)
+
+
+def create_decode_state(
+    *,
+    device: torch.device,
+    num_heads: int,
+    num_kv_heads: int,
+):
+    """Create a decode state."""
+    workspace_buffer = get_workspace(device)
+    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer,
+        kv_layout="NHD",
+        use_cuda_graph=False,
+        use_tensor_cores=num_heads // num_kv_heads > 4,
+    )
+
+
+def create_decode_state_cuda_graphs(
+    *,
+    device: torch.device,
+    block_tables: torch.Tensor,
+    block_tables_ptr: torch.Tensor,
+    last_page_len: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+):
+    """
+    Create a decode state for use with CUDA Graphs. `block_tables`,
+    `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
+    therefore stored as part of the state.
+    """
+    workspace_buffer = get_workspace(device)
+    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer,
+        kv_layout="NHD",
+        use_cuda_graph=True,
+        paged_kv_indices_buffer=block_tables,
+        paged_kv_indptr_buffer=block_tables_ptr,
+        paged_kv_last_page_len_buffer=last_page_len,
+        use_tensor_cores=num_heads // num_kv_heads > 4,
+    )
+
+
+@contextmanager
+def use_decode_state(
+    *,
+    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
+    input_lengths: torch.Tensor,
+    block_tables: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    page_size: int,
+    query_dtype: str = "float16",
+):
+    """
+    Context manager to set the active flashinfer decoding state to the given
+    `state` and parameters. This state will be used by all calls to the
+    `paged_attention` function while the context manager is active.
+    """
+    indptr = torch.zeros(
+        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
+    )
+    # Round up to page size and then calculate the cumulative sum to get
+    # the indices into the block table.
+    torch.add(input_lengths, page_size - 1, out=indptr[1:])
+    indptr[1:].div_(page_size, rounding_mode="floor")
+    indptr[1:].cumsum_(-1)
+
+    # Get the lengths of the last page in a block.
+    last_page_len = torch.empty(
+        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+    )
+    torch.sub(input_lengths, 1, out=last_page_len)
+    last_page_len.remainder_(page_size)
+    last_page_len += 1
+
+    token = decode_state.set(state)
+
+    try:
+        state.begin_forward(
+            indptr=indptr,
+            indices=block_tables,
+            last_page_len=last_page_len,
+            num_qo_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_size,
+            page_size=page_size,
+            q_data_type=query_dtype,
+        )
+        yield
+    finally:
+        state.end_forward()
+        if token is not None:
+            decode_state.reset(token)
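Taken together, the helpers in this new module are meant to be used in pairs: create a wrapper state once per model, then enter the matching context manager around a forward pass so that the module-level ContextVars point at an initialized wrapper, which the attention layer then fetches via prefill_state.get() or decode_state.get(). A minimal decode-side usage sketch follows; the tensor shapes and the body of the with block are illustrative assumptions, not the actual TGI call sites (those appear in the FlashCausalLM._forward_context hunk further down).

# Hedged usage sketch for flash_infer.py; shapes and the inner call are illustrative.
import torch
from text_generation_server.layers.attention.flash_infer import (
    create_decode_state,
    decode_state,
    use_decode_state,
)

device = torch.device("cuda")
num_heads, num_kv_heads, head_size, page_size = 32, 8, 128, 16

# One state per model; the 128 MiB workspace buffer is shared via get_workspace().
state = create_decode_state(
    device=device, num_heads=num_heads, num_kv_heads=num_kv_heads
)

input_lengths = torch.tensor([5, 20], dtype=torch.int32, device=device)
# Flat block table: ceil(5/16) = 1 page for the first sequence, 2 for the second.
block_tables = torch.tensor([0, 1, 2], dtype=torch.int32, device=device)

with use_decode_state(
    state=state,
    input_lengths=input_lengths,
    block_tables=block_tables,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_size=head_size,
    page_size=page_size,
):
    # Inside the context the active wrapper is visible to the attention layer,
    # which calls decode_state.get().forward(...) as in the paged_attention hunk.
    wrapper = decode_state.get()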
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 import math
 import os
 import time
@@ -15,7 +16,7 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
 )
-from typing import Iterable, Optional, Tuple, List, Type, Dict
+from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
     FLASH_DECODING,
+    FLASH_INFER,
     BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
@@ -907,6 +909,7 @@ class FlashCausalLM(Model):
             config.sliding_window = None

         self.num_layers = config.num_hidden_layers
+        self.num_heads = config.num_attention_heads
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -935,6 +938,21 @@ class FlashCausalLM(Model):
         self.cuda_graphs = {}
         self.kv_cache = []

+        if FLASH_INFER:
+            from text_generation_server.layers.attention.flash_infer import (
+                create_prefill_state,
+                create_decode_state,
+            )
+
+            self.prefill_state = create_prefill_state(device=device)
+
+            if not CUDA_GRAPHS:
+                self.decode_state = create_decode_state(
+                    device=device,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                )
+
         super().__init__(
             model_id=model_id,
             model=model,
@@ -972,7 +990,7 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size

-        if FLASH_DECODING:
+        if FLASH_DECODING or FLASH_INFER:
             self.kv_cache = [
                 (
                     torch.empty(
@@ -1044,38 +1062,66 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

+        if FLASH_INFER:
+            from text_generation_server.layers.attention.flash_infer import (
+                create_decode_state_cuda_graphs,
+            )
+
+            block_tables_ptr = torch.zeros(
+                bs + 1, dtype=torch.int32, device=self.device
+            )
+            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
+            state = create_decode_state_cuda_graphs(
+                device=input_ids.device,
+                block_tables=block_tables.view(-1),
+                block_tables_ptr=block_tables_ptr,
+                last_page_len=last_page_len,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+            self.cuda_graphs[bs]["state"] = state
+        else:
+            state = None
+
         torch.cuda.synchronize()
         # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=None,
-            kv_cache=self.kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths_,
-            max_s=max_s,
-            prefill_cache_indices=None,
-            lm_head_indices=None,
-        )
-        torch.cuda.synchronize()
-
-        with torch.cuda.graph(graph, pool=MEM_POOL):
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=None,
-                kv_cache=self.kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=None,
-                lm_head_indices=None,
-            )
-            self.cuda_graphs[bs]["logits"] = logits
-            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
+        with self._forward_context(
+            block_tables=block_tables,
+            cu_seqlen_prefill=None,
+            input_lengths=input_lengths,
+            state=state,
+        ):
+            self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=self.kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths_,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+
+            torch.cuda.synchronize()
+
+            with torch.cuda.graph(graph, pool=MEM_POOL):
+                input_lengths = Seqlen(input_lengths=input_lengths)
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=None,
+                    kv_cache=self.kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=None,
+                    lm_head_indices=None,
+                )
+                self.cuda_graphs[bs]["logits"] = logits
+                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

     def warmup(self, batch: FlashCausalLMBatch):
@@ -1295,23 +1341,28 @@ class FlashCausalLM(Model):
         cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-                adapter_data=adapter_data,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            return logits, speculative_logits
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=input_lengths,
+            ):
+                input_lengths = Seqlen(input_lengths=input_lengths)
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    adapter_data=adapter_data,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -1325,8 +1376,16 @@ class FlashCausalLM(Model):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

-        # Replay the graph
-        cuda_graph["graph"].replay()
+        state = cuda_graph.get("state")
+        with self._forward_context(
+            block_tables=block_tables,
+            cu_seqlen_prefill=None,
+            input_lengths=input_lengths,
+            state=state,
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()

         # Slice output to the correct shape
         speculative_logits = (
             cuda_graph["speculative_logits"][:bs]
@@ -1698,3 +1757,39 @@ class FlashCausalLM(Model):
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
+
+    def _forward_context(
+        self,
+        *,
+        block_tables: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        input_lengths: torch.Tensor,
+        state: Optional[Any] = None,
+    ) -> ContextManager:
+        if not FLASH_INFER:
+            return nullcontext()
+
+        from text_generation_server.layers.attention.flash_infer import (
+            use_decode_state,
+            use_prefill_state,
+        )
+
+        if cu_seqlen_prefill is not None:
+            return use_prefill_state(
+                state=state if state is not None else self.prefill_state,
+                cu_seqlens=cu_seqlen_prefill,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+            )
+        else:
+            assert input_lengths is not None
+            return use_decode_state(
+                state=state if state is not None else self.decode_state,
+                input_lengths=input_lengths,
+                block_tables=block_tables.view(-1),
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                page_size=BLOCK_SIZE,
+            )
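A note on _forward_context (an inference from the hunk above, not text from the commit): by returning nullcontext() when FLASH_INFER is off, every call site keeps a single uniform with-block instead of branching on the backend. A standalone sketch of that pattern, with placeholder names:

# Minimal sketch of the "always return a context manager" pattern; names are made up.
from contextlib import contextmanager, nullcontext

@contextmanager
def fake_flashinfer_context():
    print("begin_forward")    # stands in for state.begin_forward(...)
    try:
        yield
    finally:
        print("end_forward")  # stands in for state.end_forward()

def forward_context(use_flashinfer: bool):
    return fake_flashinfer_context() if use_flashinfer else nullcontext()

for flag in (False, True):
    with forward_context(flag):
        pass  # model.forward(...) would run here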
@@ -5,6 +5,10 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

+FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
+if FLASH_INFER:
+    log_master(logger.info, "Using FLASH_INFER")
+
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
@@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
     log_master(logger.info, "Using FLASH_DECODING")

+
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
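A note on the new flag (inferred from the snippet above, not stated in the commit): like FLASH_DECODING, FLASH_INFER is evaluated once at module import time, so the environment variable has to be set before text_generation_server.models.globals is imported, and only the exact spellings "1", "true", and "True" enable it. A small sketch of that parsing behaviour:

# Hedged sketch of the FLASH_INFER flag parsing; mirrors the expression in globals.py.
import os

os.environ["FLASH_INFER"] = "1"  # must be set before the globals module is imported
assert os.getenv("FLASH_INFER") in {"1", "true", "True"}

os.environ["FLASH_INFER"] = "yes"  # any other spelling leaves the flag disabled
assert os.getenv("FLASH_INFER") not in {"1", "true", "True"}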