Mirror of https://github.com/huggingface/text-generation-inference

Merge branch 'huggingface:main' into vb/followup-doc-fixes
Commit: 27daf69ea8
flake.lock (new file, 99 lines)
@@ -0,0 +1,99 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1696426674,
        "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1710146030,
        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1723099294,
        "narHash": "sha256-kkijy6sXo/SOhFw7ZEfYHbj1FJHxoeetOVOn3qNHc5o=",
        "owner": "danieldk",
        "repo": "nixpkgs",
        "rev": "45892b6ec142eaf300d777926983a433b5842c88",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "ref": "cudnn-9.3",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "flake-utils": "flake-utils",
        "nixpkgs": [
          "tgi-nix",
          "nixpkgs"
        ],
        "tgi-nix": "tgi-nix"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "tgi-nix": {
      "inputs": {
        "flake-compat": "flake-compat",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1723188417,
        "narHash": "sha256-GXdFuRMU9N9W0CryUTJIWhJkGwQFHSR2EW5xR0ZyBjk=",
        "owner": "danieldk",
        "repo": "tgi-nix",
        "rev": "491db7e234ecf79513ddb94da6ecc14167b9c0b3",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "repo": "tgi-nix",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix (new file, 73 lines)
@@ -0,0 +1,73 @@
{
  inputs = {
    tgi-nix.url = "github:danieldk/tgi-nix";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };
  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      tgi-nix,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
        pkgs = import nixpkgs {
          inherit config system;
          overlays = [ tgi-nix.overlay ];
        };
      in
      {
        devShells.default =
          with pkgs;
          mkShell {
            buildInputs =
              [
                cargo
                openssl.dev
                pkg-config
              ]
              ++ (with python3.pkgs; [
                venvShellHook
                pip

                einops
                fbgemm-gpu
                flash-attn
                flash-attn-layer-norm
                flash-attn-rotary
                grpc-interceptor
                grpcio-reflection
                grpcio-status
                hf-transfer
                loguru
                marlin-kernels
                opentelemetry-api
                opentelemetry-exporter-otlp
                opentelemetry-instrumentation-grpc
                opentelemetry-semantic-conventions
                peft
                tokenizers
                torch
                transformers
                vllm
              ]);

            venvDir = "./.venv";

            postVenv = ''
              unset SOURCE_DATE_EPOCH
            '';
            postShellHook = ''
              unset SOURCE_DATE_EPOCH
            '';
          };
      }
    );
}
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER
 import torch
 from typing import Optional


-if FLASH_DECODING:
+if FLASH_DECODING or FLASH_INFER:

     @dataclass
     class Seqlen:
@@ -1,6 +1,10 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.models.globals import (
+    FLASH_DECODING,
+    BLOCK_SIZE,
+    FLASH_INFER,
+)
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional

@@ -23,7 +27,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
+    if FLASH_DECODING or FLASH_INFER:
         shape = key_cache.shape
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
@@ -72,7 +76,16 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
+    if FLASH_INFER:
+        from text_generation_server.layers.attention.flash_infer import decode_state
+
+        return decode_state.get().forward(
+            query.contiguous(),
+            paged_kv_cache=(key_cache, value_cache),
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+    elif FLASH_DECODING:
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda
@@ -206,7 +219,32 @@ except ImportError:

 SUPPORTS_WINDOWING = V2

-if V2:
+if FLASH_INFER:
+
+    def attention(
+        q,
+        k,
+        v,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+        causal=True,
+        softcap=0.0,
+    ):
+        from text_generation_server.layers.attention.flash_infer import prefill_state
+
+        return prefill_state.get().forward(
+            q,
+            k,
+            v,
+            causal=causal,
+            window_left=window_size_left,
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+
+elif V2:

     def attention(
         q,
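Note that the FLASH_INFER branches above do not receive the flashinfer wrappers as arguments; they read them from module-level ContextVars (`prefill_state`, `decode_state`) defined in flash_infer.py below. A minimal, self-contained sketch of that pattern (hypothetical names, standard library only), for readers unfamiliar with contextvars:

from contextlib import contextmanager
from contextvars import ContextVar

# Stand-in for flashinfer's prefill/decode wrapper objects.
active_state: ContextVar[str] = ContextVar("active_state")


@contextmanager
def use_state(state: str):
    # Install the state for the duration of the block, mirroring
    # use_prefill_state/use_decode_state in flash_infer.py.
    token = active_state.set(state)
    try:
        yield
    finally:
        active_state.reset(token)


def attention_kernel() -> str:
    # Call sites fetch the active state with .get(); outside a
    # use_state block this raises LookupError.
    return f"forward with {active_state.get()}"


with use_state("decode state for this batch"):
    print(attention_kernel())  # forward with decode state for this batch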
server/text_generation_server/layers/attention/flash_infer.py (new file, 164 lines)
@@ -0,0 +1,164 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager

import flashinfer
import torch

prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
    "prefill_state"
)

decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)

workspace: Optional[torch.Tensor] = None


def get_workspace(device):
    """Get shared flashinfer workspace."""
    global workspace
    if workspace is None:
        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    return workspace


def create_prefill_state(
    *,
    device: torch.device,
):
    """Create a prefill state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
    )


@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.
    """

    token = prefill_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            kv_indptr=cu_seqlens,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            prefill_state.reset(token)


def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """Create a decode state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )


def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Create a decode state for use with CUDA Graphs. `block_tables`,
    `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
    therefore stored as part of the state.
    """
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )


@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer decoding state to the given
    `state` and parameters. This state will be used by all calls to the
    `paged_attention` function while the context manager is active.
    """
    indptr = torch.zeros(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Get the lengths of the last page in a block.
    last_page_len = torch.empty(
        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
    )
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1

    token = decode_state.set(state)

    try:
        state.begin_forward(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            decode_state.reset(token)
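A small worked example of the page-table arithmetic in `use_decode_state` above, with illustrative values not taken from the diff: page_size = 16 and two sequences of 5 and 17 tokens. It reuses the exact tensor operations from the new file.

import torch

input_lengths = torch.tensor([5, 17], dtype=torch.int32)
page_size = 16

indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])  # [20, 32]
indptr[1:].div_(page_size, rounding_mode="floor")        # pages per sequence: [1, 2]
indptr[1:].cumsum_(-1)                                   # offsets into the block table

last_page_len = torch.empty_like(input_lengths)
torch.sub(input_lengths, 1, out=last_page_len)           # [4, 16]
last_page_len.remainder_(page_size)                      # [4, 0]
last_page_len += 1                                       # tokens used in each last page

print(indptr.tolist())         # [0, 1, 3]
print(last_page_len.tolist())  # [5, 1]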
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 import math
 import os
 import time
@@ -15,7 +16,7 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
 )
-from typing import Iterable, Optional, Tuple, List, Type, Dict
+from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
     FLASH_DECODING,
+    FLASH_INFER,
     BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
@@ -907,6 +909,7 @@ class FlashCausalLM(Model):
             config.sliding_window = None

         self.num_layers = config.num_hidden_layers
+        self.num_heads = config.num_attention_heads
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -935,6 +938,21 @@ class FlashCausalLM(Model):
         self.cuda_graphs = {}
         self.kv_cache = []

+        if FLASH_INFER:
+            from text_generation_server.layers.attention.flash_infer import (
+                create_prefill_state,
+                create_decode_state,
+            )
+
+            self.prefill_state = create_prefill_state(device=device)
+
+            if not CUDA_GRAPHS:
+                self.decode_state = create_decode_state(
+                    device=device,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                )
+
         super().__init__(
             model_id=model_id,
             model=model,
@@ -972,7 +990,7 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size

-        if FLASH_DECODING:
+        if FLASH_DECODING or FLASH_INFER:
             self.kv_cache = [
                 (
                     torch.empty(
@@ -1044,38 +1062,66 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

+        if FLASH_INFER:
+            from text_generation_server.layers.attention.flash_infer import (
+                create_decode_state_cuda_graphs,
+            )
+
+            block_tables_ptr = torch.zeros(
+                bs + 1, dtype=torch.int32, device=self.device
+            )
+            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
+            state = create_decode_state_cuda_graphs(
+                device=input_ids.device,
+                block_tables=block_tables.view(-1),
+                block_tables_ptr=block_tables_ptr,
+                last_page_len=last_page_len,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+            self.cuda_graphs[bs]["state"] = state
+        else:
+            state = None
+
         torch.cuda.synchronize()
         # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=None,
-            kv_cache=self.kv_cache,
+        with self._forward_context(
             block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths_,
-            max_s=max_s,
-            prefill_cache_indices=None,
-            lm_head_indices=None,
-        )
-        torch.cuda.synchronize()
-
-        with torch.cuda.graph(graph, pool=MEM_POOL):
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
+            cu_seqlen_prefill=None,
+            input_lengths=input_lengths,
+            state=state,
+        ):
+            self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
                 kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths,
+                input_lengths=input_lengths_,
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
-            self.cuda_graphs[bs]["logits"] = logits
-            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
+            torch.cuda.synchronize()
+
+            with torch.cuda.graph(graph, pool=MEM_POOL):
+                input_lengths = Seqlen(input_lengths=input_lengths)
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=None,
+                    kv_cache=self.kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=None,
+                    lm_head_indices=None,
+                )
+                self.cuda_graphs[bs]["logits"] = logits
+                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

     def warmup(self, batch: FlashCausalLMBatch):
@@ -1295,23 +1341,28 @@ class FlashCausalLM(Model):
             cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
+            with self._forward_context(
                 block_tables=block_tables,
-                slots=slots,
+                cu_seqlen_prefill=cu_seqlen_prefill,
                 input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-                adapter_data=adapter_data,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            return logits, speculative_logits
+            ):
+                input_lengths = Seqlen(input_lengths=input_lengths)
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    adapter_data=adapter_data,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -1325,8 +1376,16 @@ class FlashCausalLM(Model):
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

-        # Replay the graph
-        cuda_graph["graph"].replay()
+        state = cuda_graph.get("state")
+        with self._forward_context(
+            block_tables=block_tables,
+            cu_seqlen_prefill=None,
+            input_lengths=input_lengths,
+            state=state,
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()

         # Slice output to the correct shape
         speculative_logits = (
             cuda_graph["speculative_logits"][:bs]
@@ -1698,3 +1757,39 @@ class FlashCausalLM(Model):
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
+
+    def _forward_context(
+        self,
+        *,
+        block_tables: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        input_lengths: torch.Tensor,
+        state: Optional[Any] = None,
+    ) -> ContextManager:
+        if not FLASH_INFER:
+            return nullcontext()
+
+        from text_generation_server.layers.attention.flash_infer import (
+            use_decode_state,
+            use_prefill_state,
+        )
+
+        if cu_seqlen_prefill is not None:
+            return use_prefill_state(
+                state=state if state is not None else self.prefill_state,
+                cu_seqlens=cu_seqlen_prefill,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+            )
+        else:
+            assert input_lengths is not None
+            return use_decode_state(
+                state=state if state is not None else self.decode_state,
+                input_lengths=input_lengths,
+                block_tables=block_tables.view(-1),
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                page_size=BLOCK_SIZE,
+            )
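A condensed, self-contained sketch (hypothetical names, not the actual method) of the dispatch idea behind the new `_forward_context` helper: it is a no-op `nullcontext` when FLASH_INFER is off, and otherwise activates the prefill or decode state depending on whether `cu_seqlen_prefill` is present.

from contextlib import contextmanager, nullcontext
from typing import ContextManager, Optional

FLASH_INFER_ENABLED = True  # stand-in for the FLASH_INFER flag in globals.py


@contextmanager
def _activate(kind: str):
    # Placeholder for use_prefill_state / use_decode_state.
    print(f"begin_forward on {kind} state")
    try:
        yield
    finally:
        print(f"end_forward on {kind} state")


def forward_context(cu_seqlen_prefill: Optional[object]) -> ContextManager:
    if not FLASH_INFER_ENABLED:
        return nullcontext()
    # Prefill when cumulative sequence lengths are given, decode otherwise,
    # the same branch used by _forward_context above.
    return _activate("prefill" if cu_seqlen_prefill is not None else "decode")


with forward_context(cu_seqlen_prefill=None):
    pass  # model.forward(...) runs here with the decode state active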
@@ -5,6 +5,10 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

+FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
+if FLASH_INFER:
+    log_master(logger.info, "Using FLASH_INFER")
+
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
@@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
     log_master(logger.info, "Using FLASH_DECODING")

+
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
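Since globals.py evaluates these flags once at import time, the environment variable has to be set before any text_generation_server module is imported. A minimal illustration of the assumed workflow (not part of the diff):

import os

# Must be set before importing text_generation_server.models.globals,
# which reads the flag at import time and logs "Using FLASH_INFER".
os.environ["FLASH_INFER"] = "1"

# Same expression the module uses to parse the flag.
flash_infer = os.getenv("FLASH_INFER") in {"1", "true", "True"}
assert flash_infer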