diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index bd0717ce..b986a082 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER import torch from typing import Optional -if FLASH_DECODING: +if FLASH_DECODING or FLASH_INFER: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index dff742dc..998d28e7 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,6 +1,10 @@ import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE +from text_generation_server.models.globals import ( + FLASH_DECODING, + BLOCK_SIZE, + FLASH_INFER, +) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -23,7 +27,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -72,7 +76,16 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - if FLASH_DECODING: + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import decode_state + + return decode_state.get().forward( + query.contiguous(), + paged_kv_cache=(key_cache, value_cache), + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + elif FLASH_DECODING: max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -202,7 +215,32 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if V2: +if FLASH_INFER: + + def attention( + q, + k, + v, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + softcap=0.0, + ): + from text_generation_server.layers.attention.flash_infer import prefill_state + + return prefill_state.get().forward( + q, + k, + v, + causal=causal, + window_left=window_size_left, + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + +elif V2: def attention( q, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flash_infer.py new file mode 100644 index 00000000..56b53b2c --- /dev/null +++ b/server/text_generation_server/layers/attention/flash_infer.py @@ -0,0 +1,164 @@ +from typing import Optional +from contextvars import ContextVar +from contextlib import contextmanager + +import flashinfer +import torch + +prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( + "prefill_state" +) + +decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( + "decode_state" +) + +workspace: Optional[torch.Tensor] = None + + +def get_workspace(device): + """Get shared flashinfer workspace.""" + global workspace + if workspace is None: + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + return workspace + + +def create_prefill_state( + *, + device: torch.device, +): + """Create a prefill 
state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_state( + *, + state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, + cu_seqlens: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + token = prefill_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + kv_indptr=cu_seqlens, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_state.reset(token) + + +def create_decode_state( + *, + device: torch.device, + num_heads: int, + num_kv_heads: int, +): + """Create a decode state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=False, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +def create_decode_state_cuda_graphs( + *, + device: torch.device, + block_tables: torch.Tensor, + block_tables_ptr: torch.Tensor, + last_page_len: torch.Tensor, + num_heads: int, + num_kv_heads: int, +): + """ + Create a decode state for use with CUDA Graphs. `block_tables`, + `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are + therefore stored as part of the state. + """ + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=True, + paged_kv_indices_buffer=block_tables, + paged_kv_indptr_buffer=block_tables_ptr, + paged_kv_last_page_len_buffer=last_page_len, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +@contextmanager +def use_decode_state( + *, + state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, + input_lengths: torch.Tensor, + block_tables: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer decoding state to the given + `state` and parameters. This state will be used by all calls to the + `paged_attention` function while the context manager is active. + """ + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. 
+ last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = decode_state.set(state) + + try: + state.begin_forward( + indptr=indptr, + indices=block_tables, + last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + page_size=page_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + decode_state.reset(token) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 36bb2662..12aa7dcd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext import math import os import time @@ -15,7 +16,7 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Iterable, Optional, Tuple, List, Type, Dict +from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, FLASH_DECODING, + FLASH_INFER, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -907,6 +909,7 @@ class FlashCausalLM(Model): config.sliding_window = None self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -935,6 +938,21 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_prefill_state, + create_decode_state, + ) + + self.prefill_state = create_prefill_state(device=device) + + if not CUDA_GRAPHS: + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + super().__init__( model_id=model_id, model=model, @@ -972,7 +990,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: self.kv_cache = [ ( torch.empty( @@ -1044,38 +1062,66 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables.view(-1), + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + self.cuda_graphs[bs]["state"] = state + else: + state = None + torch.cuda.synchronize() # Run once outside to warmup - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths_, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - 
) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def warmup(self, batch: FlashCausalLMBatch): @@ -1295,23 +1341,28 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, + cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -1325,8 +1376,16 @@ class FlashCausalLM(Model): cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - # Replay the graph - cuda_graph["graph"].replay() + state = cuda_graph.get("state") + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + # Replay the graph + cuda_graph["graph"].replay() + # Slice output to the correct shape speculative_logits = ( cuda_graph["speculative_logits"][:bs] @@ -1698,3 +1757,39 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + def _forward_context( + self, + *, + block_tables: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + input_lengths: torch.Tensor, + state: Optional[Any] = None, + ) -> 
ContextManager: + if not FLASH_INFER: + return nullcontext() + + from text_generation_server.layers.attention.flash_infer import ( + use_decode_state, + use_prefill_state, + ) + + if cu_seqlen_prefill is not None: + return use_prefill_state( + state=state if state is not None else self.prefill_state, + cu_seqlens=cu_seqlen_prefill, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + ) + else: + assert input_lengths is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths, + block_tables=block_tables.view(-1), + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 8d2431db..42b43c87 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,10 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master +FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} +if FLASH_INFER: + log_master(logger.info, "Using FLASH_INFER") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} @@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: log_master(logger.info, "Using FLASH_DECODING") + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try:
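
Note on the paged-KV bookkeeping in `use_decode_state` above: `indptr` is built by rounding each sequence length up to a whole number of pages and taking a cumulative sum, and `last_page_len` is `((length - 1) % page_size) + 1`. A minimal standalone sketch of that arithmetic (plain PyTorch, no flashinfer required; the lengths and the page size of 16 are illustrative values, not taken from the PR):

```python
import torch

# Illustrative inputs: three sequences of 1, 17 and 32 tokens, page size 16.
input_lengths = torch.tensor([1, 17, 32], dtype=torch.int32)
page_size = 16

# indptr[i + 1] - indptr[i] = number of KV pages used by sequence i.
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")  # ceil(length / page_size)
indptr[1:].cumsum_(-1)

# Number of tokens that actually occupy the last page of each sequence.
last_page_len = ((input_lengths - 1) % page_size) + 1

print(indptr)         # tensor([0, 1, 3, 5], dtype=torch.int32)
print(last_page_len)  # tensor([ 1,  1, 16], dtype=torch.int32)
```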
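
The diff threads the flashinfer wrappers through `ContextVar`s rather than through the model's call signature: `FlashCausalLM._forward_context` sets the active prefill/decode state (or returns a `nullcontext()` when `FLASH_INFER` is off), and `attention`/`paged_attention` in `cuda.py` read it back with `.get()`. A small self-contained sketch of that pattern, with hypothetical names, just to illustrate the control flow (not the PR's actual classes):

```python
from contextlib import contextmanager, nullcontext
from contextvars import ContextVar

# Hypothetical stand-in for prefill_state/decode_state in flash_infer.py.
_active_state: ContextVar[str] = ContextVar("_active_state")

@contextmanager
def use_state(state: str):
    # Mirrors use_prefill_state/use_decode_state: set on entry, reset on exit.
    token = _active_state.set(state)
    try:
        yield
    finally:
        _active_state.reset(token)

def attention() -> str:
    # Mirrors cuda.py fetching the active wrapper via prefill_state.get().forward(...).
    return f"attention ran with {_active_state.get()!r}"

def forward(flash_infer_enabled: bool) -> str:
    # Mirrors _forward_context(): a real context when flashinfer is enabled,
    # a nullcontext() otherwise.
    ctx = use_state("decode state") if flash_infer_enabled else nullcontext()
    with ctx:
        return attention() if flash_infer_enabled else "fallback attention kernel"

print(forward(True))   # attention ran with 'decode state'
print(forward(False))  # fallback attention kernel
```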