From 38773453ae0d29fba3dc79a38d589ebdc5451093 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 19 Aug 2024 09:28:38 +0200 Subject: [PATCH 01/32] nix: update to CUDA 12.4 (#2429) * Update to CUDA 12.4 * poetry2nix: follow tgi-nix nixpkgs --- flake.lock | 47 +++++++++++++++++------------------------------ flake.nix | 5 ++++- 2 files changed, 21 insertions(+), 31 deletions(-) diff --git a/flake.lock b/flake.lock index 7c772377..d0c2adbc 100644 --- a/flake.lock +++ b/flake.lock @@ -739,32 +739,16 @@ }, "nixpkgs_6": { "locked": { - "lastModified": 1719763542, - "narHash": "sha256-mXkOj9sJ0f69Nkc2dGGOWtof9d1YNY8Le/Hia3RN+8Q=", - "owner": "NixOS", + "lastModified": 1723912943, + "narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=", + "owner": "danieldk", "repo": "nixpkgs", - "rev": "e6cdd8a11b26b4d60593733106042141756b54a3", + "rev": "b82cdca86dbb30013b76c4b55d48806476820a5c", "type": "github" }, "original": { - "owner": "NixOS", - "ref": "nixos-unstable-small", - "repo": "nixpkgs", - "type": "github" - } - }, - "nixpkgs_7": { - "locked": { - "lastModified": 1723418128, - "narHash": "sha256-k1pEqsnB6ikZyasXbtV6A9akPZMKlsyENPDUA6PXoJo=", - "owner": "nixos", - "repo": "nixpkgs", - "rev": "129f579cbb5b4c1ad258fd96bdfb78eb14802727", - "type": "github" - }, - "original": { - "owner": "nixos", - "ref": "nixos-unstable-small", + "owner": "danieldk", + "ref": "cuda-12.4", "repo": "nixpkgs", "type": "github" } @@ -773,16 +757,19 @@ "inputs": { "flake-utils": "flake-utils_7", "nix-github-actions": "nix-github-actions", - "nixpkgs": "nixpkgs_6", + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ], "systems": "systems_8", "treefmt-nix": "treefmt-nix" }, "locked": { - "lastModified": 1723512448, - "narHash": "sha256-VSTtxGKre1p6zd6ACuBmfDcR+BT9+ml8Y3KrSbfGFYU=", + "lastModified": 1723854676, + "narHash": "sha256-+BrHfNuXrqeE7PoV6xDaoh0joYiJkvTTCIV0fFR3THw=", "owner": "nix-community", "repo": "poetry2nix", - "rev": "ed52f844c4dd04dde45550c3189529854384124e", + "rev": "d650118bce34c0238b9b54f23f7f173f9e4db867", "type": "github" }, "original": { @@ -1048,14 +1035,14 @@ "tgi-nix": { "inputs": { "flake-compat": "flake-compat_4", - "nixpkgs": "nixpkgs_7" + "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1723532088, - "narHash": "sha256-6h/P/BkFDw8txlikonKXp5IbluHSPhHJTQRftJLkbLQ=", + "lastModified": 1723973328, + "narHash": "sha256-q5FmW4YFQcRb6fXHnrxL0uno6xcw9dcg+pFBbVM1xeQ=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "32335a37ce0f703bab901baf7b74eb11e9972d5f", + "rev": "d2038f36589a8a179834e5771ffd081620ba94c3", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index cf05746a..6bfe74e8 100644 --- a/flake.nix +++ b/flake.nix @@ -8,7 +8,10 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; - poetry2nix.url = "github:nix-community/poetry2nix"; + poetry2nix = { + url = "github:nix-community/poetry2nix"; + inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + }; rust-overlay = { url = "github:oxalica/rust-overlay"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; From b70ae0969f11bae03a3c6194fc8c592a1d8a65b3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 20 Aug 2024 11:15:30 +0200 Subject: [PATCH 02/32] Prefix caching (#2402) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Prefix caching WIP * Fixing prefix attention. * Fixing flashinfer import. * Fixing black. * Fixing medusa (still wrong outputs, but functional). * Just medusa values now. * Fixing medusa without prefix caching. * Fixing prefix caching. * Medusa requires reshaping. * Removing the logs. * Remove router.nix * Fixup: - Remove logs - Disable VLMs (they do not work) - Disable prefix caching when user wants prefill logprobs. * Update flake.lock --------- Co-authored-by: Daniël de Kok --- backends/v3/src/queue.rs | 13 +- backends/v3/src/radix.rs | 1 + flake.lock | 6 +- flake.nix | 1 + .../layers/attention/__init__.py | 7 +- .../layers/attention/cuda.py | 27 +- .../{flash_infer.py => flashinfer.py} | 76 ++++++ .../text_generation_server/layers/medusa.py | 2 + .../custom_modeling/flash_cohere_modeling.py | 2 + .../custom_modeling/flash_dbrx_modeling.py | 2 + .../flash_deepseek_v2_modeling.py | 2 + .../custom_modeling/flash_gemma2_modeling.py | 2 + .../custom_modeling/flash_gemma_modeling.py | 2 + .../custom_modeling/flash_gpt2_modeling.py | 2 + .../custom_modeling/flash_llama_modeling.py | 2 + .../custom_modeling/flash_mistral_modeling.py | 2 + .../custom_modeling/flash_mixtral_modeling.py | 2 + .../custom_modeling/flash_neox_modeling.py | 2 + .../custom_modeling/flash_phi_modeling.py | 2 + .../custom_modeling/flash_qwen2_modeling.py | 2 + .../custom_modeling/flash_rw_modeling.py | 4 + .../flash_santacoder_modeling.py | 2 + .../flash_starcoder2_modeling.py | 2 + .../models/flash_causal_lm.py | 252 ++++++++++++++---- .../text_generation_server/models/globals.py | 6 +- .../models/vlm_causal_lm.py | 89 +++++-- 26 files changed, 405 insertions(+), 107 deletions(-) rename server/text_generation_server/layers/attention/{flash_infer.py => flashinfer.py} (65%) diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 0fb05a98..faa57c11 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -316,10 +316,15 @@ impl State { + self.speculate - 1; - match block_allocator - .allocate(tokens, entry.request.input_ids.clone()) - .await - { + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. + let input_ids = if entry.request.decoder_input_details { + None + } else { + entry.request.input_ids.clone() + }; + + match block_allocator.allocate(tokens, input_ids).await { None => { // Entry is over budget // Add it back to the front diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index ef963532..5bac1a31 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -205,6 +205,7 @@ pub struct RadixTrie { /// call that a real time lookup would require. time: u64, } + impl Default for RadixTrie { fn default() -> Self { Self::new() diff --git a/flake.lock b/flake.lock index d0c2adbc..cd5d6d2a 100644 --- a/flake.lock +++ b/flake.lock @@ -900,11 +900,11 @@ ] }, "locked": { - "lastModified": 1723515680, - "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=", + "lastModified": 1723602049, + "narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3", + "rev": "ea0bf33a11a26a62c60123c49d96011da396602c", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 6bfe74e8..299e6b3d 100644 --- a/flake.nix +++ b/flake.nix @@ -84,6 +84,7 @@ grpcio-status grpcio-tools hf-transfer + ipdb loguru mamba-ssm marlin-kernels diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index f9b1715e..56fc5319 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -6,7 +6,12 @@ from .common import Seqlen if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": - from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .cuda import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "ipex": diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 8703eb94..b3b7ea4f 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -76,7 +76,7 @@ def paged_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import decode_state + from text_generation_server.layers.attention.flashinfer import decode_state return decode_state.get().forward( query.contiguous(), @@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2 if ATTENTION == "flashinfer": def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -231,14 +233,15 @@ if ATTENTION == "flashinfer": causal=True, softcap=0.0, ): - from text_generation_server.layers.attention.flash_infer import prefill_state + assert window_size_left == -1, "Windowing is not supported with flash infer" + from text_generation_server.layers.attention.flashinfer import ( + prefill_with_paged_kv_state, + ) - return prefill_state.get().forward( - q, - k, - v, + return prefill_with_paged_kv_state.get().forward( + q.contiguous(), causal=causal, - window_left=window_size_left, + paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, ) @@ -249,6 +252,8 @@ elif V2: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -289,6 +294,8 @@ else: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flashinfer.py similarity index 65% rename from server/text_generation_server/layers/attention/flash_infer.py rename to server/text_generation_server/layers/attention/flashinfer.py index 56b53b2c..e1ef62c5 100644 --- a/server/text_generation_server/layers/attention/flash_infer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con "prefill_state" ) +prefill_with_paged_kv_state: ContextVar[ + flashinfer.BatchPrefillWithPagedKVCacheWrapper +] = ContextVar("prefill_with_paged_kv_state") + decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( "decode_state" ) @@ -24,6 +28,78 @@ def get_workspace(device): return workspace +def create_prefill_with_paged_kv_state( + *, + device: torch.device, +): + """Create a prefill state that uses the KV cache.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_with_paged_kv_state( + *, + state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, + block_tables: torch.Tensor, + cu_seqlens: torch.Tensor, + input_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + if page_size == 1: + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + else: + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = prefill_with_paged_kv_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + paged_kv_indptr=indptr, + paged_kv_indices=block_tables, + paged_kv_last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + page_size=page_size, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_with_paged_kv_state.reset(token) + + def create_prefill_state( *, device: torch.device, diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py index 7579ccdb..139c4dc2 100644 --- a/server/text_generation_server/layers/medusa.py +++ b/server/text_generation_server/layers/medusa.py @@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module): ) def forward(self, x): + if not self.heads: + return None speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) return speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e02a31d9..1eb8c6c3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d3d1d1ef..fc0dca5b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0905d3c2..b25becd5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 54d212e6..faf0f325 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 178efadb..33738a59 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a19cff8c..d30b5a0a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 9ea19a87..3253d2dc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dda53ff3..5a150267 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 85431c6c..ad426ffe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 67237d5c..b684e035 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0], qkv[:, 1], qkv[:, 2], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9e18348..efe27c13 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 865cc85d..879b8abd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 10f995a3..c72a9b90 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, @@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module): query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c2676782..109304be 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module): query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index e562eb89..200d4ef0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5e2fd20a..dd4203e0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,6 +43,7 @@ from text_generation_server.models.globals import ( ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, + PREFIX_CACHING, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor max_seqlen: int @@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch): prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] + # Prefixes + prefix_ids: List[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -213,6 +220,7 @@ class FlashCausalLMBatch(Batch): prefix_offsets = [] read_offsets = [] all_input_ids = [] + prefix_ids = [] requests_idx_mapping = {} all_prefill_logprobs = True @@ -230,7 +238,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 + cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -240,6 +248,7 @@ class FlashCausalLMBatch(Batch): block_tables = [] slots = [] + prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -255,6 +264,19 @@ class FlashCausalLMBatch(Batch): ): tokenized_input = tokenized_input[1:] + orig_input_length = len(tokenized_input) + + if PREFIX_CACHING: + prefix_len = r.prefix_len + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 + else: + prefix_len = 0 + + prefix_ids.append(tokenized_input[:prefix_len]) + tokenized_input = tokenized_input[prefix_len:] + input_length = len(tokenized_input) input_lengths.append(input_length) @@ -264,7 +286,9 @@ class FlashCausalLMBatch(Batch): all_input_ids.append(tokenized_input) # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + request_position_ids = torch.arange( + prefix_len, orig_input_length, dtype=torch.int32 + ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs @@ -288,11 +312,17 @@ class FlashCausalLMBatch(Batch): # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length - total_tokens = input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to blocks. + block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to slots. We don't need slots for the + # cached prefix (if present). + slot_tokens = input_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] @@ -303,16 +333,20 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots + request_slots = r.slots[ + prefix_len: #: orig_input_length + max_new_tokens + speculative_length + ] block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) + + slots.extend(request_slots) + prefix_lens.append(prefix_len) num_blocks += len(request_blocks) - start_slots.append(cumulative_max_length) + start_slots.append(cumulative_slot_tokens) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) @@ -348,7 +382,7 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length - cumulative_max_length += total_tokens + cumulative_slot_tokens += slot_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( @@ -425,12 +459,14 @@ class FlashCausalLMBatch(Batch): ) slots = torch.tensor(slots, dtype=torch.int64, device=device) + block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) + prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) return cls( batch_id=pb.id, @@ -445,6 +481,8 @@ class FlashCausalLMBatch(Batch): block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -455,6 +493,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -510,8 +549,10 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] + prefix_lens = [] prefix_offsets = [] read_offsets = [] @@ -533,11 +574,14 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] + prefix_len = self.prefix_lens[idx] max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) + prefix_ids.append(self.prefix_ids[idx]) input_lengths.append(request_input_length) + prefix_lens.append(prefix_len) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -582,6 +626,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] + prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( @@ -617,10 +662,13 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens=None, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -681,6 +729,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) + prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) @@ -698,7 +747,9 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] + prefix_lens = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] prefix_offsets = [] @@ -760,10 +811,14 @@ class FlashCausalLMBatch(Batch): start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) + prefix_lens.extend(batch.prefix_lens) all_input_ids.extend(batch.all_input_ids) + prefix_ids.extend(batch.prefix_ids) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) @@ -809,6 +864,8 @@ class FlashCausalLMBatch(Batch): slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, @@ -820,6 +877,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -970,19 +1028,22 @@ class FlashCausalLM(Model): self.kv_cache = [] if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_prefill_state, create_decode_state, + create_prefill_with_paged_kv_state, ) self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) - if not CUDA_GRAPHS: - self.decode_state = create_decode_state( - device=device, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) super().__init__( model_id=model_id, @@ -1074,12 +1135,23 @@ class FlashCausalLM(Model): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - block_tables = ( - torch.arange(max_bt, dtype=torch.int32, device=self.device) - .repeat(bs) - .reshape((bs, max_bt)) + input_lengths = [max_s] * bs + prefix_lengths = [0] * bs + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) + prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lengths, + ) self.cuda_graphs[bs] = { "input_ids": input_ids, @@ -1087,14 +1159,14 @@ class FlashCausalLM(Model): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "input_lengths": input_lengths, + "input_lengths": input_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths) + input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, ) @@ -1104,7 +1176,7 @@ class FlashCausalLM(Model): last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) state = create_decode_state_cuda_graphs( device=input_ids.device, - block_tables=block_tables.view(-1), + block_tables=block_tables, block_tables_ptr=block_tables_ptr, last_page_len=last_page_len, num_heads=self.num_heads, @@ -1120,7 +1192,10 @@ class FlashCausalLM(Model): block_tables=block_tables, cu_seqlen_prefill=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, state=state, + prefix_lens=prefix_lengths, + prefix_lens_tensor=prefix_lengths_tensor, ): self.model.forward( input_ids=input_ids, @@ -1138,7 +1213,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) + input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1146,7 +1221,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_tensor, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1334,6 +1409,9 @@ class FlashCausalLM(Model): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -1354,6 +1432,7 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1372,10 +1451,20 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths=input_lengths, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, ): input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( @@ -1399,20 +1488,32 @@ class FlashCausalLM(Model): # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) - state = cuda_graph.get("state") with self._forward_context( - block_tables=block_tables, + block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - input_lengths=input_lengths, - state=state, + input_lengths=batch.input_lengths, + input_lengths_tensor=cuda_graph["input_lengths"], + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + state=cuda_graph.get("state"), ): # Replay the graph cuda_graph["graph"].replay() @@ -1610,6 +1711,7 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, + batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, @@ -1627,6 +1729,7 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, + prefix_ids, do_sample, seed, top_n_tokens, @@ -1701,18 +1804,18 @@ class FlashCausalLM(Model): out_end_index = batch.prefill_cu_outlens[i + 1] # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - out_start_index : out_end_index - 1 - ] + request_prefill_logprobs = ( + [float("nan")] * (len(prefix_ids) + 1) + ) + prefill_logprobs[out_start_index : out_end_index - 1] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, + prefix_ids + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) prefill_tokens = Tokens( - prefill_token_ids, + prefix_ids + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], @@ -1794,33 +1897,68 @@ class FlashCausalLM(Model): *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths: torch.Tensor, + input_lengths: List[int], + input_lengths_tensor: torch.Tensor, + prefix_lens: List[int], + prefix_lens_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": return nullcontext() - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( use_decode_state, - use_prefill_state, + use_prefill_with_paged_kv_state, ) + # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) + if cu_seqlen_prefill is not None: - return use_prefill_state( - state=state if state is not None else self.prefill_state, + return use_prefill_with_paged_kv_state( + state=( + state if state is not None else self.prefill_with_paged_kv_state + ), + # block_tables=block_tables_to_ragged( + # block_tables=block_tables, + # input_lengths=input_lengths, + # prefix_lens=prefix_lens, + # ), + block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - ) - else: - assert input_lengths is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths, - block_tables=block_tables.view(-1), + input_lengths=input_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, ) + else: + assert input_lengths_tensor is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths_tensor, + block_tables=block_tables, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) + + +def block_tables_to_ragged( + *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(prefix_lens) + + total_len = sum(input_lengths) + sum(prefix_lens) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + offset = 0 + for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): + seq_len = prefix_len + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index abc35421..d5133f5e 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,9 +5,8 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) -log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") - +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} +log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( @@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer": else: BLOCK_SIZE = 16 - cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7de54aa4..2ed1a119 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, + block_tables_to_ragged, ) +from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen @@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM): trust_remote_code: bool, **kwargs, ): + if PREFIX_CACHING: + raise NotImplementedError("Vlm do not work with prefix caching yet") if processor_kwargs is None: processor_kwargs = {} self.processor = processor_class.from_pretrained( @@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -349,43 +357,68 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - return logits, speculative_logits + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) # Replay the graph cuda_graph["graph"].replay() From f5f11b797e70b2232632d410273c5c4418475dd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 20 Aug 2024 22:07:33 +0200 Subject: [PATCH 03/32] nix: add pure server to flake, add both pure and impure devshells (#2430) * nix: pure server and support both pure and impure devShells * nix: remove unused poetry2nix input It is not wired up and we now have a pure server. * nix: add ipdb to impure devshell --- flake.lock | 115 ------------------------------------------------- flake.nix | 64 +++++++++++---------------- nix/server.nix | 105 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 155 deletions(-) create mode 100644 nix/server.nix diff --git a/flake.lock b/flake.lock index cd5d6d2a..c20dd98a 100644 --- a/flake.lock +++ b/flake.lock @@ -492,24 +492,6 @@ "type": "github" } }, - "flake-utils_7": { - "inputs": { - "systems": "systems_7" - }, - "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, "gitignore": { "inputs": { "nixpkgs": [ @@ -594,27 +576,6 @@ "type": "github" } }, - "nix-github-actions": { - "inputs": { - "nixpkgs": [ - "poetry2nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1703863825, - "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=", - "owner": "nix-community", - "repo": "nix-github-actions", - "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "nix-github-actions", - "type": "github" - } - }, "nix-test-runner": { "flake": false, "locked": { @@ -753,31 +714,6 @@ "type": "github" } }, - "poetry2nix": { - "inputs": { - "flake-utils": "flake-utils_7", - "nix-github-actions": "nix-github-actions", - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ], - "systems": "systems_8", - "treefmt-nix": "treefmt-nix" - }, - "locked": { - "lastModified": 1723854676, - "narHash": "sha256-+BrHfNuXrqeE7PoV6xDaoh0joYiJkvTTCIV0fFR3THw=", - "owner": "nix-community", - "repo": "poetry2nix", - "rev": "d650118bce34c0238b9b54f23f7f173f9e4db867", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "poetry2nix", - "type": "github" - } - }, "pre-commit-hooks": { "inputs": { "flake-compat": [ @@ -887,7 +823,6 @@ "tgi-nix", "nixpkgs" ], - "poetry2nix": "poetry2nix", "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } @@ -1003,35 +938,6 @@ "type": "github" } }, - "systems_7": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - }, - "systems_8": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "id": "systems", - "type": "indirect" - } - }, "tgi-nix": { "inputs": { "flake-compat": "flake-compat_4", @@ -1050,27 +956,6 @@ "repo": "tgi-nix", "type": "github" } - }, - "treefmt-nix": { - "inputs": { - "nixpkgs": [ - "poetry2nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1719749022, - "narHash": "sha256-ddPKHcqaKCIFSFc/cvxS14goUhCOAwsM1PbMr0ZtHMg=", - "owner": "numtide", - "repo": "treefmt-nix", - "rev": "8df5ff62195d4e67e2264df0b7f5e8c9995fd0bd", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "treefmt-nix", - "type": "github" - } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 299e6b3d..adc70fd1 100644 --- a/flake.nix +++ b/flake.nix @@ -8,10 +8,6 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; - poetry2nix = { - url = "github:nix-community/poetry2nix"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; - }; rust-overlay = { url = "github:oxalica/rust-overlay"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; @@ -26,7 +22,6 @@ flake-utils, rust-overlay, tgi-nix, - poetry2nix, }: flake-utils.lib.eachDefaultSystem ( system: @@ -47,14 +42,28 @@ tgi-nix.overlay ]; }; - inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; - text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; + launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override { + inherit crateOverrides; + }; + router = cargoNix.workspaceMembers.text-generation-router-v3.build.override { + inherit crateOverrides; + }; + server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; }; in { - devShells.default = - with pkgs; - mkShell { + devShells = with pkgs; rec { + default = pure; + + pure = mkShell { + buildInputs = [ + launcher + router + server + ]; + }; + + impure = mkShell { buildInputs = [ openssl.dev @@ -65,43 +74,16 @@ "rust-src" ]; }) + protobuf ] ++ (with python3.pkgs; [ venvShellHook pip - - causal-conv1d - click - einops - exllamav2 - fbgemm-gpu - flashinfer - flash-attn - flash-attn-layer-norm - flash-attn-rotary - grpc-interceptor - grpcio-reflection - grpcio-status - grpcio-tools - hf-transfer ipdb - loguru - mamba-ssm - marlin-kernels - opentelemetry-api - opentelemetry-exporter-otlp - opentelemetry-instrumentation-grpc - opentelemetry-semantic-conventions - peft - tokenizers - torch - transformers - vllm - - (cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; }) - (cargoNix.workspaceMembers.text-generation-router-v3.build.override { inherit crateOverrides; }) ]); + inputsFrom = [ server ]; + venvDir = "./.venv"; postVenv = '' @@ -109,8 +91,10 @@ ''; postShellHook = '' unset SOURCE_DATE_EPOCH + export PATH=$PATH:~/.cargo/bin ''; }; + }; } ); } diff --git a/nix/server.nix b/nix/server.nix new file mode 100644 index 00000000..ff40757a --- /dev/null +++ b/nix/server.nix @@ -0,0 +1,105 @@ +{ + nix-filter, + buildPythonPackage, + poetry-core, + mypy-protobuf, + causal-conv1d, + einops, + exllamav2, + fbgemm-gpu, + flashinfer, + flash-attn, + flash-attn-layer-norm, + flash-attn-rotary, + grpc-interceptor, + grpcio-reflection, + grpcio-status, + grpcio-tools, + hf-transfer, + loguru, + mamba-ssm, + marlin-kernels, + opentelemetry-api, + opentelemetry-exporter-otlp, + opentelemetry-instrumentation-grpc, + opentelemetry-semantic-conventions, + peft, + safetensors, + tokenizers, + sentencepiece, + transformers, + typer, + vllm, +}: + +let + filter = nix-filter.lib; +in +buildPythonPackage { + name = "text-generation-server"; + + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi"))) + "server/pyproject.toml" + (and (inDirectory "proto/v3") (matchExt "proto")) + ]; + }; + + pyproject = true; + + build-system = [ poetry-core ]; + + nativeBuildInputs = [ mypy-protobuf ]; + + pythonRelaxDeps = [ + "einops" + "huggingface-hub" + "loguru" + "opentelemetry-instrumentation-grpc" + "sentencepiece" + "typer" + ]; + + pythonRemoveDeps = [ "scipy" ]; + + dependencies = [ + causal-conv1d + einops + exllamav2 + fbgemm-gpu + flashinfer + flash-attn + flash-attn-layer-norm + flash-attn-rotary + grpc-interceptor + grpcio-reflection + grpcio-status + grpcio-tools + hf-transfer + loguru + mamba-ssm + marlin-kernels + opentelemetry-api + opentelemetry-exporter-otlp + opentelemetry-instrumentation-grpc + opentelemetry-semantic-conventions + peft + safetensors + sentencepiece + tokenizers + transformers + typer + vllm + ]; + + prePatch = '' + python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \ + --grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto + find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; + touch server/text_generation_server/pb/__init__.py + cd server + ''; +} From 947441509580c56e9bd8160bf2cf23c653d19abd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 21 Aug 2024 07:48:13 +0200 Subject: [PATCH 04/32] nix: add `text-generation-benchmark` to pure devshell (#2431) nix: add text-generation-benchmark to pure devshell --- flake.nix | 4 +++ nix/crate-overrides.nix | 68 ++++++++++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/flake.nix b/flake.nix index adc70fd1..1c9c77a9 100644 --- a/flake.nix +++ b/flake.nix @@ -43,6 +43,9 @@ ]; }; crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; + benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override { + inherit crateOverrides; + }; launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; }; @@ -57,6 +60,7 @@ pure = mkShell { buildInputs = [ + benchmark launcher router server diff --git a/nix/crate-overrides.nix b/nix/crate-overrides.nix index 343b3b25..cbf366af 100644 --- a/nix/crate-overrides.nix +++ b/nix/crate-overrides.nix @@ -20,34 +20,52 @@ defaultCrateOverrides rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; grpc-metadata = attrs: { - src = - filter { - root = ../backends/grpc-metadata; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + src = filter { + root = ../backends/grpc-metadata; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; - text-generation-launcer = attrs: { - src = - filter { - root = ../launcher; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + text-generation-benchmark = attrs: { + src = filter { + root = ../benchmark; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-client = attrs: { + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "backends/client") (matchExt "rs")) + (and (inDirectory "proto") (matchExt "proto")) + ]; + }; + postPatch = "cd backends/client"; + buildInputs = [ protobuf ]; + }; + text-generation-launcher = attrs: { + src = filter { + root = ../launcher; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; text-generation-router = attrs: { - src = - filter { - root = ../router; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + src = filter { + root = ../router; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; text-generation-router-v3 = attrs: { # We need to do the src/source root dance so that the build From 310778e02a5aa36a4c72601e4e929bbbea0f1e7b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Aug 2024 09:06:33 +0200 Subject: [PATCH 05/32] Adding eetq to flake. (#2438) --- flake.lock | 12 ++++++------ nix/server.nix | 2 ++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flake.lock b/flake.lock index c20dd98a..69f9ef13 100644 --- a/flake.lock +++ b/flake.lock @@ -835,11 +835,11 @@ ] }, "locked": { - "lastModified": 1723602049, - "narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=", + "lastModified": 1724206841, + "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "ea0bf33a11a26a62c60123c49d96011da396602c", + "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61", "type": "github" }, "original": { @@ -944,11 +944,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1723973328, - "narHash": "sha256-q5FmW4YFQcRb6fXHnrxL0uno6xcw9dcg+pFBbVM1xeQ=", + "lastModified": 1724218652, + "narHash": "sha256-Y7Kt+AZRIdo7tr/VhKGzdwYf7stiYQ4JD7flusEpXQw=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "d2038f36589a8a179834e5771ffd081620ba94c3", + "rev": "ab2761aa7b970e737492b8cc41ca580dcb094808", "type": "github" }, "original": { diff --git a/nix/server.nix b/nix/server.nix index ff40757a..1f90e3fd 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -4,6 +4,7 @@ poetry-core, mypy-protobuf, causal-conv1d, + eetq, einops, exllamav2, fbgemm-gpu, @@ -66,6 +67,7 @@ buildPythonPackage { pythonRemoveDeps = [ "scipy" ]; dependencies = [ + eetq causal-conv1d einops exllamav2 From 358ceb67dd343f4022537a1f28bc3fc9baec9102 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 21 Aug 2024 22:20:03 +0200 Subject: [PATCH 06/32] nix: add awq-inference-engine as server dependency (#2442) --- flake.lock | 6 +++--- nix/server.nix | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 69f9ef13..b40f51b3 100644 --- a/flake.lock +++ b/flake.lock @@ -944,11 +944,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1724218652, - "narHash": "sha256-Y7Kt+AZRIdo7tr/VhKGzdwYf7stiYQ4JD7flusEpXQw=", + "lastModified": 1724270760, + "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "ab2761aa7b970e737492b8cc41ca580dcb094808", + "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c", "type": "github" }, "original": { diff --git a/nix/server.nix b/nix/server.nix index 1f90e3fd..4e0fdaa1 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -3,6 +3,7 @@ buildPythonPackage, poetry-core, mypy-protobuf, + awq-inference-engine, causal-conv1d, eetq, einops, @@ -67,6 +68,7 @@ buildPythonPackage { pythonRemoveDeps = [ "scipy" ]; dependencies = [ + awq-inference-engine eetq causal-conv1d einops From f3c5d7d92f005c3cd6a801a33334fb9ba32f55f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 23 Aug 2024 22:06:22 +0200 Subject: [PATCH 07/32] nix: add default package (#2453) The default package wraps the launcher and puts the server/router in the path. As a result, TGI can be started using something like: ``` nix run .# -- \ --model-id hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 \ --port 8080 ``` --- flake.nix | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/flake.nix b/flake.nix index 1c9c77a9..83feb26a 100644 --- a/flake.nix +++ b/flake.nix @@ -99,6 +99,17 @@ ''; }; }; + + packages.default = pkgs.writeShellApplication { + name = "text-generation-inference"; + runtimeInputs = [ + server + router + ]; + text = '' + ${launcher}/bin/text-generation-launcher "$@" + ''; + }; } ); } From 30be188400d27b6fedd88cb3dfd88de45639703c Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 26 Aug 2024 17:04:46 -0400 Subject: [PATCH 08/32] Fix: don't apply post layernorm in SiglipVisionTransformer (#2459) * Fix: don't apply post layernorm in SiglipVisionTransformer This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before layernorm (see original transformers implementation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813). This also makes Siglip consistent with the existing Clip implementation: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613 * fix: adjust pali gemma for post layer norm and small refactors --------- Co-authored-by: Travis Addair --- .../custom_modeling/flash_pali_gemma_modeling.py | 10 +++++++++- .../models/custom_modeling/siglip.py | 13 +------------ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index d10efb41..e08a2aad 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module): config=config.vision_config, weights=weights, ) + self.post_vision_tower_layernorm = nn.LayerNorm.load( + prefix="vision_tower.vision_model.post_layernorm", + weights=weights, + eps=config.vision_config.layer_norm_eps, + ) self.multi_modal_projector = TensorParallelColumnLinear.load( config, @@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module): if pixel_values is not None: pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) image_outputs = self.vision_tower(pixel_values) - image_features = self.multi_modal_projector(image_outputs.last_hidden_state) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) # mask where image or padding tokens mask = input_ids == self.config.image_token_index diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py index 480d0f9f..95ac9ede 100644 --- a/server/text_generation_server/models/custom_modeling/siglip.py +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module): inputs_embeds, attention_mask: Optional[torch.Tensor] = None, ): - hidden_states = inputs_embeds for idx, encoder_layer in enumerate(self.layers): hidden_states, _ = encoder_layer( @@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module): self.encoder = SiglipEncoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) - self.post_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.post_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, ): - r""" - Returns: - - """ if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module): inputs_embeds=hidden_states, ) last_hidden_state = encoder_outputs - post_last_hidden_state = self.post_layernorm(last_hidden_state) return BaseModelOutputWithPooling( - last_hidden_state=post_last_hidden_state, + last_hidden_state=last_hidden_state, # pooler_output=pooled_output, # hidden_states=encoder_outputs, ) From cfa73b5c99bc009903fbc340f8b77a6d4674455d Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 26 Aug 2024 20:19:38 -0400 Subject: [PATCH 09/32] Pr 2451 ci branch (#2454) * fix[router]: Fix tools not passed in chat template Signed-off-by: GitHub * feat: improve default tool serialization and lints * feat: refactor tool logic to include notify_error in prompt and adjust typing * fix: adjust non tool template apply * fix: simplify tool grammar logic and improve schema * feat: avoid skip tool test and avoid empty tool prompts * fix: increase test client timeout for grammar compilation tests --------- Signed-off-by: GitHub Co-authored-by: Simone Rossi --- Cargo.lock | 1 + clients/python/text_generation/client.py | 7 +- docs/openapi.json | 2 +- integration-tests/conftest.py | 2 +- integration-tests/models/test_tools_llama.py | 50 +++--- router/Cargo.toml | 2 +- router/src/infer/chat_template.rs | 57 ++++--- router/src/infer/mod.rs | 6 +- router/src/infer/tool_grammar.rs | 121 +++++++------- router/src/lib.rs | 15 +- router/src/server.rs | 160 ++++++++++++++++--- 11 files changed, 268 insertions(+), 155 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d298c379..aa5cb642 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" dependencies = [ "serde", + "serde_json", ] [[package]] diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 12966747..45301b63 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -757,7 +757,12 @@ class AsyncClient: continue payload = byte_payload.decode("utf-8") if payload.startswith("data:"): - json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + payload_data = ( + payload.lstrip("data:").rstrip("\n").removeprefix(" ") + ) + if payload_data == "[DONE]": + break + json_payload = json.loads(payload_data) try: response = ChatCompletionChunk(**json_payload) yield response diff --git a/docs/openapi.json b/docs/openapi.json index df21e19d..fd64a3ab 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -924,7 +924,7 @@ "tool_prompt": { "type": "string", "description": "A prompt to be appended before the tools", - "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", "nullable": true }, "tools": { diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 15af1cad..a8a77cd2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -257,7 +257,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator): class LauncherHandle: def __init__(self, port: int): - self.client = AsyncClient(f"http://localhost:{port}") + self.client = AsyncClient(f"http://localhost:{port}", timeout=30) def _inner_health(self): raise NotImplementedError diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index f831990a..9855cfda 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -36,6 +36,7 @@ tools = [ }, }, "required": ["location", "format"], + "additionalProperties": False, }, }, }, @@ -62,13 +63,13 @@ tools = [ }, }, "required": ["location", "format", "num_days"], + "additionalProperties": False, }, }, }, ] -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): @@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna max_tokens=100, seed=1, tools=tools, - presence_penalty=-1.1, + temperature=0.0, messages=[ { "role": "system", @@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_auto( @@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="auto", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto( assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_choice( @@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice( assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( @@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 38 + assert count == 48 assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_insufficient_information( @@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information( ): responses = await flash_llama_grammar_tools.chat( max_tokens=100, - seed=8, + seed=24, tools=tools, tool_choice="auto", messages=[ { "role": "system", - "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", }, { "role": "user", @@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information( ) assert responses.choices[0].message.content is None - assert responses.choices[0].message.tool_calls == [ - { - "function": { - "arguments": { - "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." - }, - "description": None, - "name": "notify_error", - }, - "id": 0, - "type": "function", - } - ] - + assert ( + responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error" + ) assert responses == response_snapshot diff --git a/router/Cargo.toml b/router/Cargo.toml index 7773e212..45acab8e 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } -minijinja = { version = "2.0.2" } +minijinja = { version = "2.0.2", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" regex = "1.10.3" diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index a8537818..bfa9421c 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,9 +1,7 @@ use std::collections::HashSet; use crate::infer::InferError; -use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, -}; +use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -32,6 +30,7 @@ impl ChatTemplate { env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); + tracing::debug!("Loading template: {:#?}", template_str); // leaking env and template_str as read-only, static resources for performance. let template = Box::leak(env) @@ -42,6 +41,7 @@ impl ChatTemplate { let variables = template.undeclared_variables(true); // check if the `tools` variable is used in the template let use_default_tool_template = !variables.contains("tools"); + tracing::debug!("Use default tool template: {}", use_default_tool_template); Self { template, @@ -56,25 +56,36 @@ impl ChatTemplate { &self, guideline: Option<&str>, mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - }); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - // check if guideline is expected but not provided if self.variables.contains("guideline") && guideline.is_none() { return Err(InferError::MissingTemplateVariable("guideline".to_string())); } + let tools = match tools_and_prompt { + Some((tools, tool_prompt)) => { + // check if the `tools` variable is used in the template + // if not, we need to append the tools to the last message + let text = if self.use_default_tool_template { + match serde_json::to_string(&tools) { + Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt), + Err(e) => return Err(InferError::ToolError(e.to_string())), + } + } else { + // if the `tools` variable is used in the template, we just append the tool_prompt + format!("\n---\n{}", tool_prompt) + }; + if let Some(last_message) = messages.last_mut() { + last_message.content.push(MessageChunk::Text { text }); + } + Some(tools) + } + None => None, + }; + + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + self.template .render(ChatTemplateInputs { guideline, @@ -82,8 +93,7 @@ impl ChatTemplate { bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), add_generation_prompt: true, - tools: None, - tools_prompt: None, + tools, }) .map_err(InferError::TemplateError) } @@ -95,7 +105,7 @@ mod tests { use crate::infer::chat_template::raise_exception; use crate::infer::ChatTemplate; use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken, + ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, }; use minijinja::Environment; @@ -854,11 +864,12 @@ mod tests { content: MessageContent::SingleText("Just testing".to_string()), }, ]; - let tools = serde_json::json!("[]"); + let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); + let tools: Vec = serde_json::from_str(&tools_string).unwrap(); let tool_prompt = "This default prompt will be used".to_string(); - let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt); - let result = ct.apply(None, msgs, Some(grammer_with_prompt)); - let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string(); + let tools_and_prompt = Some((tools, tool_prompt)); + let result = ct.apply(None, msgs, tools_and_prompt); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index c9354d9a..81c0d38f 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -3,7 +3,7 @@ mod chat_template; pub mod tool_grammar; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; -use crate::GrammarType; +use crate::Tool; use crate::{ ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, Message, PrefillToken, Token, @@ -140,12 +140,12 @@ impl Infer { &self, guideline: Option, messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(guideline.as_deref(), messages, grammar_with_prompt) + .apply(guideline.as_deref(), messages, tools_and_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 05027f30..4fe15720 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,5 +1,8 @@ use crate::infer::InferError; -use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; +use crate::{ + FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, + ToolType, +}; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -16,17 +19,38 @@ impl ToolGrammar { } pub fn apply( - tools: Option>, + tools: Vec, tool_choice: ToolChoice, - ) -> Result, InferError> { + ) -> Result<(Vec, Option), InferError> { // if no tools are provided, we return None - let tools = match tools { - Some(tools) if !tools.is_empty() => tools, - _ => return Ok(None), - }; + if tools.is_empty() { + return Ok((tools, None)); + } let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + let mut tools = tools.clone(); + + // add the notify_error function to the tools + let notify_error = Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "notify_error".to_string(), + description: Some("Notify an error or issue".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "error": { + "type": "string", + "description": "The error or issue to notify" + } + }, + "required": ["error"] + }), + }, + }; + tools.push(notify_error); + // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { ToolType::FunctionName(name) => { @@ -35,87 +59,57 @@ impl ToolGrammar { ToolType::Function { function } => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools, - ToolType::NoTool => return Ok(None), + ToolType::OneOf => tools.clone(), + ToolType::NoTool => return Ok((tools, None)), }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - let functions: HashMap = tools_to_use .iter() .map(|tool| { let func = tool.function.clone(); - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; + let mut params = Map::new(); - // Insert the function's description at the top level, outside of properties params.insert( "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), + Value::String(func.description.unwrap_or_default()), ); - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); + let mut properties = Map::new(); + let mut required = vec![Value::String("_name".to_string())]; - // Insert the constant for the function name inside 'properties' properties.insert( "_name".to_string(), json!({ "type": "string", "const": func.name.clone(), - // "description": "The name of the function" }), ); - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); + if let Value::Object(args) = func.arguments { + if let Some(Value::Object(props)) = args.get("properties") { + properties.extend(props.clone()); + } + if let Some(Value::Array(reqs)) = args.get("required") { + required.extend(reqs.clone()); + } + params.insert( + "additionalProperties".to_string(), + Value::Bool( + args.get("additionalProperties").and_then(|v| v.as_str()) + == Some("true"), + ), + ); } + params.insert("properties".to_string(), Value::Object(properties)); + params.insert("required".to_string(), Value::Array(required)); + (func.name, Value::Object(params)) }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) .collect(); - let tools = Tools { + let tool_schema = JsonSchemaTool { functions_map: FunctionsMap { functions }, properties: Properties { function: tools_to_use @@ -123,13 +117,10 @@ impl ToolGrammar { .map(|tool| FunctionRef { ref_path: format!("#/$functions/{}", tool.function.name.clone()), }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) .collect(), }, }; - Ok(Some(tools)) + Ok((tools, Some(tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 1b2ff153..ce4f7c46 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -840,10 +840,10 @@ pub(crate) struct ChatRequest { pub tools: Option>, /// A prompt to be appended before the tools - #[serde(default = "default_tool_prompt")] + #[serde(default)] #[schema( nullable = true, - example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"" + example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables." )] pub tool_prompt: Option, @@ -865,10 +865,8 @@ pub(crate) struct ChatRequest { pub guideline: Option, } -fn default_tool_prompt() -> Option { - Some( - "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), - ) +pub fn default_tool_prompt() -> String { + "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] @@ -910,7 +908,7 @@ impl From for ToolChoice { } #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] -pub struct Tools { +pub struct JsonSchemaTool { #[serde(flatten)] functions_map: FunctionsMap, properties: Properties, @@ -968,8 +966,7 @@ pub(crate) struct ChatTemplateInputs<'a> { bos_token: Option<&'a str>, eos_token: Option<&'a str>, add_generation_prompt: bool, - tools: Option<&'a str>, - tools_prompt: Option<&'a str>, + tools: Option>, guideline: Option<&'a str>, } diff --git a/router/src/server.rs b/router/src/server.rs index 8ec7a871..8ebd1a33 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,7 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; -use crate::ChatTokenizeResponse; +use crate::{default_tool_prompt, ChatTokenizeResponse}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -23,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -146,7 +146,7 @@ async fn get_chat_tokenize( } = req; let tool_prompt = tool_prompt.unwrap_or_default(); - let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + let (inputs, _grammar, _using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -1158,14 +1158,16 @@ async fn chat_completions( let repetition_penalty = presence_penalty.map(|x| x + 2.0); let max_new_tokens = max_tokens.or(Some(100)); let logprobs = logprobs.unwrap_or(false); - let tool_prompt = tool_prompt.unwrap_or_default(); + let tool_prompt = tool_prompt + .filter(|s| !s.is_empty()) + .unwrap_or_else(default_tool_prompt); let stop = stop.unwrap_or_default(); // enable greedy only when temperature is 0 let (do_sample, temperature) = match temperature { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - let (inputs, grammar, tool_grammar) = prepare_chat_input( + let (inputs, grammar, using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -1221,7 +1223,7 @@ async fn chat_completions( }); // replace the content with the tool calls if grammar is present - let (content, tool_calls) = if tool_grammar.is_some() { + let (content, tool_calls) = if using_tools { (None, Some(vec![stream_token.token.text])) } else { let content = if !stream_token.token.special { @@ -1275,7 +1277,7 @@ async fn chat_completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - let (tool_calls, output) = if tool_grammar.is_some() { + let (tool_calls, output) = if using_tools { let gen_text_value: Value = serde_json::from_str(&generation.generated_text).map_err(|e| { InferError::ToolError(format!( @@ -2539,7 +2541,7 @@ fn create_post_processor( Ok(post_processor) } -type PreparedInput = (String, Option, Option); +type PreparedInput = (String, Option, bool); fn prepare_chat_input( infer: &Infer, @@ -2556,19 +2558,139 @@ fn prepare_chat_input( )); } + // when response_format is set, tools are not included when applying the chat template to generate inputs if let Some(format) = response_format { let inputs = infer.apply_chat_template(guideline, messages, None)?; - return Ok((inputs, Some(format), None)); + return Ok((inputs, Some(format), false)); } - // if tools are set, apply the tool grammar and then the chat template - let tool_grammar: Option = ToolGrammar::apply(tools, tool_choice)?; - let grammar = tool_grammar - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - let tools_grammar_prompt = tool_grammar - .as_ref() - .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); - let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?; - Ok((inputs, grammar, tool_grammar)) + // when no response_format is set and tools are included, apply the chat template with the tools + // to generate inputs + if let Some(tools) = tools { + let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; + + let grammar = tool_schema + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt.into())), + )?; + return Ok((inputs, grammar, tool_schema.is_some())); + } + + // if no response_format or tools are set simply apply the chat template to generate inputs + let inputs = infer.apply_chat_template(guideline, messages, None)?; + Ok((inputs, None, false)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ChatTemplateVersions; + use crate::HubTokenizerConfig; + use crate::TokenizerConfigToken; + use crate::Tool; + + use serde_json::json; + + #[test] + fn test_prepare_chat_input() { + // Mock Backend to avoid network requests + struct MockBackend; + + impl Backend for MockBackend { + fn schedule( + &self, + _request: crate::validation::ValidGenerateRequest, + ) -> Result< + tokio_stream::wrappers::UnboundedReceiverStream< + Result, + >, + InferError, + > { + unimplemented!("Never called in this test"); + } + fn health<'a, 'async_trait>( + &'a self, + _current_health: bool, + ) -> core::pin::Pin< + Box + core::marker::Send + 'async_trait>, + > + where + 'a: 'async_trait, + Self: 'async_trait, + { + unimplemented!("Never called in this test"); + } + } + + let backend = MockBackend {}; + + let mut tokenizer_config = HubTokenizerConfig::default(); + + // mock tokenizer config values + tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.chat_template = Some( + ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) + ); + + let infer = Infer::new( + backend, + Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), + 1, + tokenizer_config, + HubProcessorConfig::default(), + ); + let response_format = None; + let tools = Some(vec![Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "get_current_weather".to_string(), + description: Some("Get the current weather".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + }), + }, + }]); + let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."; + let guideline = None; + let messages = vec![Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "What is the weather like in New York?".to_string(), + ), + }]; + + let result = prepare_chat_input( + &infer, + response_format, + tools, + ToolChoice(None), + tool_prompt, + guideline, + messages, + ); + + assert!(result.is_ok()); + let (inputs, _grammar, using_tools) = result.unwrap(); + assert_eq!(using_tools, true); + assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + } } From 2788d41a76c193d4de7055dc5ef38a97f25c38b5 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 27 Aug 2024 15:33:02 +0200 Subject: [PATCH 10/32] Fixing CI. (#2462) --- .github/workflows/build.yaml | 4 ---- .github/workflows/ci_build.yaml | 3 +++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 1f72c46d..d415f369 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -32,10 +32,6 @@ jobs: permissions: contents: write packages: write - # This is used to complete the identity challenge - # with sigstore/fulcio when running outside of PRs. - id-token: write - security-events: write steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index 6000cec3..5190f321 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -39,6 +39,9 @@ jobs: matrix: hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"] uses: ./.github/workflows/build.yaml # calls the one above ^ + permissions: + contents: write + packages: write with: hardware: ${{ matrix.hardware }} # https://github.com/actions/runner/issues/2206 From 21187c27c90acbec7f912b8af4feaec154de960f Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 27 Aug 2024 13:31:08 -0400 Subject: [PATCH 11/32] fix: bump minijinja version and add test for llama 3.1 tools (#2463) * fix: support tojson and avoid message indexing issue in template * fix: prefer minijinja native methods and prefer workspace level dependency * fix: adjust comment typo --- Cargo.lock | 4 ++-- Cargo.toml | 2 ++ backends/v3/Cargo.toml | 4 ++-- router/Cargo.toml | 4 ++-- router/src/infer/chat_template.rs | 37 +++++++++++++++++++++++++++++-- 5 files changed, 43 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index aa5cb642..02e91bc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2169,9 +2169,9 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" +checksum = "6d7d3e3a3eece1fa4618237ad41e1de855ced47eab705cec1c9a920e1d1c5aad" dependencies = [ "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 8bf75b90..79fda15d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,8 @@ tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } +minijinja = { version = "2.2.0", features = ["json"] } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } [profile.release] incremental = true diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 06a44bec..69dad072 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -53,8 +53,8 @@ utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } -minijinja = { version = "2.0.2" } -minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +minijinja = { workspace = true } +minijinja-contrib = { workspace = true } futures-util = "0.3.30" regex = "1.10.3" once_cell = "1.19.0" diff --git a/router/Cargo.toml b/router/Cargo.toml index 45acab8e..5c328e8a 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,8 +46,8 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } -minijinja = { version = "2.0.2", features = ["json"] } -minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +minijinja = { workspace = true } +minijinja-contrib = { workspace = true } futures-util = "0.3.30" regex = "1.10.3" once_cell = "1.19.0" diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index bfa9421c..a736fc12 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,9 +1,8 @@ -use std::collections::HashSet; - use crate::infer::InferError; use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; +use std::collections::HashSet; /// Raise a exception (custom function) used in the chat templates pub(crate) fn raise_exception(err_text: String) -> Result { @@ -872,4 +871,38 @@ mod tests { let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } + + #[test] + fn test_chat_template_with_custom_tool_template() { + // chat template from meta-llama/Meta-Llama-3.1-8B-Instruct + let ct = ChatTemplate::new( + "{{- bos_token }}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + let msgs: Vec = vec![ + Message { + name: None, + role: "system".to_string(), + content: MessageContent::SingleText( + "Youre a helpful assistant! Answer the users question best you can." + .to_string(), + ), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "What is the weather like in Brooklyn, New York?".to_string(), + ), + }, + ]; + let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); + let tools: Vec = serde_json::from_str(&tools_string).unwrap(); + let tool_prompt = "This default prompt will be used".to_string(); + let tools_and_prompt = Some((tools, tool_prompt)); + let result = ct.apply(None, msgs, tools_and_prompt); + let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); + assert_eq!(result.unwrap(), expected); + } } From 8f99f165ce1a261c89ea2edef437ef23c03a0716 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 28 Aug 2024 13:44:44 -0400 Subject: [PATCH 12/32] fix: improve regex expression (#2468) --- docs/source/basic_tutorials/using_guidance.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md index f09ef348..dfa3f0e4 100644 --- a/docs/source/basic_tutorials/using_guidance.md +++ b/docs/source/basic_tutorials/using_guidance.md @@ -157,7 +157,12 @@ from huggingface_hub import InferenceClient client = InferenceClient("http://localhost:3000") -regexp = "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)" +section_regex = "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)" +regexp = f"HELLO\.{section_regex}\.WORLD\.{section_regex}" + +# This is a more realistic example of an ip address regex +# regexp = f"{section_regex}\.{section_regex}\.{section_regex}\.{section_regex}" + resp = client.text_generation( f"Whats Googles DNS? Please use the following regex: {regexp}", @@ -170,7 +175,7 @@ resp = client.text_generation( print(resp) -# 7.1.1.1 +# HELLO.255.WORLD.255 ``` From 4e821c003a7cb055a358cf142dbf01a2f4c1916f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 29 Aug 2024 16:25:25 +0200 Subject: [PATCH 13/32] nix: build Torch against MKL and various other improvements (#2469) Updates tgi-nix input: - Move Torch closer to upstream by building against MKL. - Remove compute capability 8.7 from Torch (Jetson). - Sync nixpkgs cumpute capabilities with Torch (avoids compiling too mana capabilities for MAGMA). - Use nixpkgs configuration passed through by `tgi-nix`. --- flake.lock | 6 +++--- flake.nix | 9 +++------ nix/server.nix | 1 + 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/flake.lock b/flake.lock index b40f51b3..14011768 100644 --- a/flake.lock +++ b/flake.lock @@ -944,11 +944,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1724270760, - "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=", + "lastModified": 1724784743, + "narHash": "sha256-NdEoWeNwR/ZstYnHaiQWIYZvr7VsrAh7g3+ZHUPrxuI=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c", + "rev": "c9580c3e39a855246bb87b584bbea1885b44f524", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 83feb26a..0739c90a 100644 --- a/flake.nix +++ b/flake.nix @@ -31,15 +31,12 @@ src = ./.; additionalCargoNixArgs = [ "--all-features" ]; }; - config = { - allowUnfree = true; - cudaSupport = true; - }; pkgs = import nixpkgs { - inherit config system; + inherit system; + inherit (tgi-nix.lib) config; overlays = [ rust-overlay.overlays.default - tgi-nix.overlay + tgi-nix.overlays.default ]; }; crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; diff --git a/nix/server.nix b/nix/server.nix index 4e0fdaa1..6ee088e0 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -28,6 +28,7 @@ peft, safetensors, tokenizers, + torch, sentencepiece, transformers, typer, From e415b690a68d7a0e149c996e46def41c867ff421 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 29 Aug 2024 16:29:01 +0200 Subject: [PATCH 14/32] Lots of improvements (Still 2 allocators) (#2449) * Making prefix/flashinfer the default and testing the full release tests. * Include flashinfer in the docker. * Using prebuilt. * Allowing window_left_size (dummy version). * Disabling flashinfer/prefix caching on odd head_dim * Disable prefix caching for lora. * More specific codes. * Update lock * Updating integration tests with new values with FI/FD. Remove paged as a default too, and using FD everywhere. * Update cargo lock ? * Upgrade to 1.80 because of bitstream... * Everywhere 1.80 * Forgot last default place. * Apply suggestions from code review Co-authored-by: drbh * Updated flake lock * Tmp * Upgrade resolution system for less errors in resolution. * Remove lambda for cleaner function. * Handling debugger. * OVerride the env in server tests. * Is this enough to make it work ? * This seems to be working. * Downgrade some logs. * Fixing the default for vlm. * Don't enable prefix caching on VLM just yet. * Change `add_special_tokens` in order to have the correct tokens for chat input and not (since it's super important with the prefixing now) * Fixing prefix caching for flashdecoding. * Update all models. * Fixed flashinfer version. * add_special_tokens is internal only * Fixing seqlen with the new vlms. * Fixing the issue with `add_special_tokens` not being passed around. * Fixing the test. * Removing encoder_decoder (seq2seq). * Update the chat test. * Fixing the batching tokenization in flash causal lm. * Truncating left for radix purposes. * Oops this doesn't belong here. * Put back default pure shell. * Update server tests - Default to throughput test in k6 - Use TGI_WIGGLE_ROOM to adjust wiggle room * Only n_heads / process_group.size() are necessary. * Revert the integrationt tests change (seem linked to head_size modification). * Adding error message when assert is violated. * Fixing the free algorithm to handle times where the common prefix is smaller. * Apply suggestions from code review Co-authored-by: OlivierDehaene * Update server/text_generation_server/layers/attention/common.py Co-authored-by: OlivierDehaene * Fix disabling prefix caching - Fix windowing checks. * Revert the Cohere tokenizer change (for now using a revision instead). * Fmt. --------- Co-authored-by: drbh Co-authored-by: OlivierDehaene --- .github/workflows/tests.yaml | 2 +- Cargo.lock | 437 +++++++++--------- Dockerfile | 9 +- Dockerfile_amd | 2 +- Dockerfile_intel | 2 +- backends/client/src/v3/client.rs | 2 + backends/client/src/v3/sharded_client.rs | 1 + backends/v3/src/backend.rs | 30 +- backends/v3/src/block_allocator.rs | 5 +- backends/v3/src/client/grpc_client.rs | 1 + backends/v3/src/client/sharded_client.rs | 1 + backends/v3/src/queue.rs | 2 + backends/v3/src/radix.rs | 206 ++++++--- benchmark/src/generation.rs | 1 + flake.lock | 6 +- .../test_flash_llama_simple.json | 12 +- ..._llama_completion_many_prompts_stream.json | 172 +++---- .../test_flash_deepseek_v2.json | 44 +- .../test_flash_deepseek_v2_load.json | 152 +++--- .../test_flash_llama_fp8_all_params.json | 58 +-- .../test_flash_starcoder2_default_params.json | 16 +- .../test_flash_idefics2_next_all_params.json | 8 +- integration-tests/models/test_chat_llama.py | 2 +- launcher/src/main.rs | 265 ++++++++--- load_tests/common.js | 26 +- proto/v3/generate.proto | 2 + router/src/infer/mod.rs | 3 +- router/src/lib.rs | 21 + router/src/server.rs | 4 + router/src/validation.rs | 49 +- rust-toolchain.toml | 2 +- server/Makefile | 1 + server/Makefile-flashinfer | 2 + server/tests/conftest.py | 5 +- .../layers/attention/common.py | 39 +- .../layers/attention/cuda.py | 28 +- .../text_generation_server/models/__init__.py | 15 +- .../custom_modeling/flash_cohere_modeling.py | 23 +- .../custom_modeling/flash_dbrx_modeling.py | 27 +- .../flash_deepseek_v2_modeling.py | 24 +- .../custom_modeling/flash_gemma2_modeling.py | 23 +- .../custom_modeling/flash_gemma_modeling.py | 23 +- .../custom_modeling/flash_gpt2_modeling.py | 23 +- .../custom_modeling/flash_gptj_modeling.py | 25 +- .../custom_modeling/flash_llama_modeling.py | 23 +- .../custom_modeling/flash_mistral_modeling.py | 25 +- .../custom_modeling/flash_mixtral_modeling.py | 25 +- .../custom_modeling/flash_neox_modeling.py | 25 +- .../flash_pali_gemma_modeling.py | 5 +- .../custom_modeling/flash_phi_modeling.py | 23 +- .../custom_modeling/flash_qwen2_modeling.py | 25 +- .../custom_modeling/flash_rw_modeling.py | 33 +- .../flash_santacoder_modeling.py | 23 +- .../flash_starcoder2_modeling.py | 25 +- .../models/custom_modeling/idefics2.py | 5 +- .../models/custom_modeling/llava_next.py | 5 +- .../models/flash_causal_lm.py | 105 +++-- .../text_generation_server/models/globals.py | 9 +- .../models/vlm_causal_lm.py | 11 +- 59 files changed, 1234 insertions(+), 934 deletions(-) create mode 100644 server/Makefile-flashinfer diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f983b6ed..6faabe3b 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -35,7 +35,7 @@ jobs: with: # Released on: 02 May, 2024 # https://releases.rs/docs/1.78.0/ - toolchain: 1.79.0 + toolchain: 1.80.0 override: true components: rustfmt, clippy - name: Install Protoc diff --git a/Cargo.lock b/Cargo.lock index 02e91bc1..00c7f005 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -28,7 +34,7 @@ dependencies = [ "once_cell", "serde", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -121,14 +127,14 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "async-rustls" @@ -160,7 +166,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -171,7 +177,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -257,9 +263,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e89b6941c2d1a7045538884d6e760ccfffdf8e1ffc2613d8efa74305e1f3752" +checksum = "0f0e249228c6ad2d240c2dc94b714d711629d52bad946075d8e9b2f5391f0703" dependencies = [ "bindgen", "cc", @@ -402,7 +408,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.4", "object", "rustc-demangle", ] @@ -444,7 +450,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.72", + "syn 2.0.76", "which", ] @@ -483,9 +489,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.5.0" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499" +checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" [[package]] name = "block-buffer" @@ -516,9 +522,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" [[package]] name = "byteorder" @@ -534,15 +540,15 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.1" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "camino" -version = "1.1.7" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0ec6b951b160caa93cc0c7b209e5a3bff7aae9062213451ac99493cd844c239" +checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" dependencies = [ "serde", ] @@ -584,12 +590,13 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.7" +version = "1.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" +checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6" dependencies = [ "jobserver", "libc", + "shlex", ] [[package]] @@ -623,6 +630,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "clang-sys" version = "1.8.1" @@ -647,9 +660,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.11" +version = "4.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" +checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" dependencies = [ "clap_builder", "clap_derive", @@ -657,9 +670,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.11" +version = "4.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" +checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" dependencies = [ "anstream", "anstyle", @@ -669,14 +682,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.11" +version = "4.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" +checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -687,9 +700,9 @@ checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "cmake" -version = "0.1.50" +version = "0.1.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" dependencies = [ "cc", ] @@ -741,15 +754,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" dependencies = [ "libc", ] @@ -897,19 +910,19 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.4.4" +version = "3.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" dependencies = [ - "nix", - "windows-sys 0.52.0", + "nix 0.29.0", + "windows-sys 0.59.0", ] [[package]] name = "cxx" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +checksum = "3c4eae4b7fc8dcb0032eb3b1beee46b38d371cdeaf2d0c64b9944f6f69ad7755" dependencies = [ "cc", "cxxbridge-flags", @@ -919,9 +932,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +checksum = "6c822bf7fb755d97328d6c337120b6f843678178751cba33c9da25cf522272e0" dependencies = [ "cc", "codespan-reporting", @@ -929,24 +942,24 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "cxxbridge-flags" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" +checksum = "719d6197dc016c88744aff3c0d0340a01ecce12e8939fc282e7c8f583ee64bc6" [[package]] name = "cxxbridge-macro" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +checksum = "35de3b547387863c8f82013c4f79f1c2162edee956383e4089e1d04c18c4f16c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -970,7 +983,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -981,7 +994,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1011,7 +1024,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1021,7 +1034,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1057,9 +1070,9 @@ dependencies = [ [[package]] name = "dunce" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "easy-cast" @@ -1126,7 +1139,7 @@ dependencies = [ "flume", "half 2.4.1", "lebe", - "miniz_oxide", + "miniz_oxide 0.7.4", "rayon-core", "smallvec", "zune-inflate", @@ -1144,9 +1157,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fdeflate" @@ -1165,12 +1178,12 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" dependencies = [ "crc32fast", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] @@ -1296,7 +1309,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1405,7 +1418,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.2.6", + "indexmap 2.4.0", "slab", "tokio", "tokio-util", @@ -1414,9 +1427,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" dependencies = [ "atomic-waker", "bytes", @@ -1424,7 +1437,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.2.6", + "indexmap 2.4.0", "slab", "tokio", "tokio-util", @@ -1631,7 +1644,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.5", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -1689,9 +1702,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", @@ -1774,9 +1787,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -1832,7 +1845,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1915,9 +1928,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -1932,7 +1945,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap 4.5.11", + "clap 4.5.16", "fancy-regex", "fraction", "getrandom", @@ -1972,9 +1985,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" [[package]] name = "libfuzzer-sys" @@ -2126,7 +2139,7 @@ dependencies = [ "hyper 1.4.1", "hyper-rustls", "hyper-util", - "indexmap 2.2.6", + "indexmap 2.4.0", "ipnet", "metrics", "metrics-util", @@ -2179,9 +2192,9 @@ dependencies = [ [[package]] name = "minijinja-contrib" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6853ef2340c668281c5ea86b04da2ebb2fc9e98a7185a887591de4cac945d5b5" +checksum = "744a2b84dbd22398e347594ed2aef9d3f1b948934e3e6e94ef69ecd39d597f4b" dependencies = [ "minijinja", "serde", @@ -2203,6 +2216,15 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "0.8.11" @@ -2217,9 +2239,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ "hermit-abi 0.3.9", "libc", @@ -2251,7 +2273,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2341,7 +2363,19 @@ checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ "bitflags 2.6.0", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", + "libc", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases 0.2.1", "libc", ] @@ -2439,7 +2473,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2510,9 +2544,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.2" +version = "0.36.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" +checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" dependencies = [ "memchr", ] @@ -2574,7 +2608,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2613,7 +2647,7 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ "futures-core", "futures-sink", - "indexmap 2.2.6", + "indexmap 2.4.0", "js-sys", "once_cell", "pin-project-lite", @@ -2837,7 +2871,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.4.0", ] [[package]] @@ -2857,7 +2891,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2916,7 +2950,7 @@ dependencies = [ "crc32fast", "fdeflate", "flate2", - "miniz_oxide", + "miniz_oxide 0.7.4", ] [[package]] @@ -2933,21 +2967,21 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy 0.6.6", + "zerocopy", ] [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" dependencies = [ "proc-macro2", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2999,7 +3033,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3039,7 +3073,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.72", + "syn 2.0.76", "tempfile", ] @@ -3066,7 +3100,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3110,9 +3144,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -3201,9 +3235,9 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.9" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5797d09f9bd33604689e87e8380df4951d4912f01b63f71205e2abd4ae25e6b6" +checksum = "a8f0bfd976333248de2078d350bfdf182ff96e168a24d23d2436cef320dd4bdd" dependencies = [ "avif-serialize", "imgref", @@ -3264,9 +3298,9 @@ dependencies = [ [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", @@ -3275,9 +3309,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -3359,9 +3393,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.45" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade4539f42266ded9e755c605bdddf546242b2c961b03b06a7375260788a0523" +checksum = "0f86ae463694029097b846d8f99fd5536740602ae00022c0c50c5600720b2f71" dependencies = [ "bytemuck", ] @@ -3416,7 +3450,7 @@ dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.72", + "syn 2.0.76", "walkdir", ] @@ -3507,12 +3541,12 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" +checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "rustls-pki-types", "schannel", "security-framework", @@ -3529,9 +3563,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -3539,15 +3573,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" -version = "0.102.6" +version = "0.102.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" +checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" dependencies = [ "aws-lc-rs", "ring 0.17.8", @@ -3641,9 +3675,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] @@ -3660,20 +3694,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "serde_json" -version = "1.0.121" +version = "1.0.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" +checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" dependencies = [ "itoa", "memchr", @@ -3875,7 +3909,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3897,9 +3931,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.72" +version = "2.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" +checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" dependencies = [ "proc-macro2", "quote", @@ -3993,20 +4027,21 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4024,7 +4059,7 @@ version = "2.2.1-dev0" dependencies = [ "async-stream", "async-trait", - "clap 4.5.11", + "clap 4.5.16", "cmake", "cxx", "cxx-build", @@ -4046,7 +4081,7 @@ name = "text-generation-benchmark" version = "2.2.1-dev0" dependencies = [ "average", - "clap 4.5.11", + "clap 4.5.16", "crossterm", "float-ord", "hf-hub", @@ -4084,11 +4119,11 @@ dependencies = [ name = "text-generation-launcher" version = "2.2.1-dev0" dependencies = [ - "clap 4.5.11", + "clap 4.5.16", "ctrlc", "float_eq", "hf-hub", - "nix", + "nix 0.28.0", "once_cell", "reqwest", "serde", @@ -4108,7 +4143,7 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap 4.5.11", + "clap 4.5.16", "csv", "futures", "futures-util", @@ -4156,7 +4191,7 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap 4.5.11", + "clap 4.5.16", "criterion", "futures", "futures-util", @@ -4224,7 +4259,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4341,14 +4376,14 @@ dependencies = [ [[package]] name = "tokio" -version = "1.39.2" +version = "1.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" dependencies = [ "backtrace", "bytes", "libc", - "mio 1.0.1", + "mio 1.0.2", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -4375,7 +4410,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4437,9 +4472,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.16" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", @@ -4449,20 +4484,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.17" +version = "0.22.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" +checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.4.0", "serde", "serde_spanned", "toml_datetime", @@ -4534,7 +4569,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4575,15 +4610,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -4605,7 +4640,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4865,7 +4900,7 @@ version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.4.0", "serde", "serde_json", "utoipa-gen", @@ -4881,7 +4916,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4919,7 +4954,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -5000,34 +5035,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" dependencies = [ "cfg-if", "js-sys", @@ -5037,9 +5073,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5047,28 +5083,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", @@ -5149,11 +5185,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5208,6 +5244,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.42.2" @@ -5388,9 +5433,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.16" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" +checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" dependencies = [ "memchr", ] @@ -5405,34 +5450,14 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "zerocopy" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" -dependencies = [ - "byteorder", - "zerocopy-derive 0.6.6", -] - [[package]] name = "zerocopy" version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy-derive" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", + "byteorder", + "zerocopy-derive", ] [[package]] @@ -5443,7 +5468,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -5463,7 +5488,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] diff --git a/Dockerfile b/Dockerfile index 4c64a643..0d0e89b1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -184,6 +184,12 @@ WORKDIR /usr/src COPY server/Makefile-selective-scan Makefile RUN make build-all +# Build flashinfer +FROM kernel-builder AS flashinfer-builder +WORKDIR /usr/src +COPY server/Makefile-flashinfer Makefile +RUN make install-flashinfer + # Text Generation Inference base image FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base @@ -236,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/ # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/Dockerfile_amd b/Dockerfile_amd index cdad0d28..8cb699dd 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/Dockerfile_intel b/Dockerfile_intel index 12480c70..9af6422c 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,6 +1,6 @@ ARG PLATFORM=xpu -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index b321278c..479d31bf 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -153,6 +153,8 @@ impl Client { }), // We truncate the input on the server side to be sure that it has the correct size truncate, + // Most request will have that + add_special_tokens: true, // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 1cc173e3..645c076a 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -221,6 +221,7 @@ impl Health for ShardedClient { chunks: vec![Chunk::Text("liveness".into()).into()], }), truncate: 10, + add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index cbcbff72..05a26370 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,27 +35,15 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { - let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { - matches!(prefix_caching.as_str(), "true" | "1") - } else { - false - }; - let attention = if let Ok(attention) = std::env::var("ATTENTION") { - attention - .parse() - .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) - } else if prefix_caching { - Attention::FlashInfer - } else { - Attention::Paged - }; - let block_size = if attention == Attention::FlashDecoding { - 256 - } else if attention == Attention::FlashInfer { - 1 - } else { - 16 - }; + let prefix_caching = + std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); + let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); + let attention: String = std::env::var("ATTENTION").expect("attention env var"); + + let attention: Attention = attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); + let block_size = attention.block_size(); let queue = Queue::new( requires_padding, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index c5503b8c..4fea172b 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,4 +1,4 @@ -use std::{cmp::min, sync::Arc}; +use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; @@ -137,7 +137,6 @@ pub trait Allocator { fn free(&mut self, blocks: Vec, allocation_id: u64); } - pub struct SimpleAllocator { free_blocks: Vec, block_size: u32, @@ -167,7 +166,7 @@ impl Allocator for SimpleAllocator { None => (tokens, 1), Some(window_size) => { let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); + let tokens = core::cmp::min(tokens, window_size); (tokens, repeats as usize) } }; diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 6282759e..648662db 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -149,6 +149,7 @@ impl Client { requests.push(Request { id: 0, inputs, + add_special_tokens: true, input_chunks: Some(Input { chunks: input_chunks, }), diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index 2f78da03..ea77a696 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -222,6 +222,7 @@ impl Health for ShardedClient { chunks: vec![Chunk::Text("liveness".into()).into()], }), truncate: 10, + add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index faa57c11..2a8c4c53 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -383,6 +383,7 @@ impl State { }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, + add_special_tokens: entry.request.add_special_tokens, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), )), @@ -517,6 +518,7 @@ mod tests { inputs: vec![], input_ids: Some(Arc::new(vec![])), input_length: 0, + add_special_tokens: true, truncate: 0, decoder_input_details: false, parameters: ValidParameters { diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 5bac1a31..b85be00b 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -1,12 +1,10 @@ +use crate::block_allocator::{Allocator, BlockAllocation}; +use slotmap::{DefaultKey, SlotMap}; use std::{ collections::{BTreeSet, HashMap}, sync::Arc, }; -use slotmap::{DefaultKey, SlotMap}; - -use crate::block_allocator::{Allocator, BlockAllocation}; - pub struct RadixAllocator { allocation_id: u64, @@ -16,26 +14,26 @@ pub struct RadixAllocator { /// Blocks that are immediately available for allocation. free_blocks: Vec, + + #[allow(dead_code)] + // This isn't used because the prefix need to match without the windowing + // mecanism. This at worst is overallocating, not necessarily being wrong. + window_size: Option, + + block_size: u32, } impl RadixAllocator { pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { - assert_eq!( - block_size, 1, - "Radix tree allocator only works with block_size=1, was: {}", - block_size - ); - if window_size.is_some() { - unimplemented!("Window size not supported in the prefix-caching block allocator yet"); - } - RadixAllocator { allocation_id: 0, allocations: HashMap::new(), - cache_blocks: RadixTrie::new(), + cache_blocks: RadixTrie::new(block_size as usize), // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), + window_size, + block_size, } } @@ -63,6 +61,7 @@ impl RadixAllocator { } } +// Allocator trait impl Allocator for RadixAllocator { fn allocate( &mut self, @@ -86,10 +85,12 @@ impl Allocator for RadixAllocator { .incref(prefix_node) .expect("Failed to increment refcount"); - let prefix_len = blocks.len(); + let prefix_len = blocks.len() * self.block_size as usize; let suffix_len = tokens - prefix_len as u32; - match self.alloc_or_reclaim(suffix_len as usize) { + let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { self.cache_blocks @@ -100,7 +101,20 @@ impl Allocator for RadixAllocator { } // 1:1 mapping of blocks and slots. - let slots = blocks.clone(); + let slots = if self.block_size == 1 { + blocks.clone() + } else { + let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize); + 'slots: for block_id in &blocks { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() as u32 == tokens { + break 'slots; + } + } + } + slots + }; let allocation = RadixAllocation { prefix_node, @@ -108,6 +122,8 @@ impl Allocator for RadixAllocator { prefill_tokens: prefill_tokens.clone(), }; + tracing::debug!("Blocks {blocks:?}"); + self.allocation_id += 1; self.allocations.insert(self.allocation_id, allocation); @@ -136,27 +152,38 @@ impl Allocator for RadixAllocator { // If there are prefill tokens that did not come from the cache, // add them to the cache. if prefill_tokens.len() > allocation.cached_prefix_len { - let prefix_len = self - .cache_blocks - .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) - // Unwrap, failing is a programming error. - .expect("Failed to store prefill tokens"); - - // We can have a prefill with the following structure: - // - // |---| From the prefix cache. - // A B C D E F G - //|--------| Found in the trie during insertion. - // - // This means that while processing this request there was a - // partially overlapping request that had A..=E in its - // prefill. In this case we need to free the blocks D E. - self.free_blocks - .extend(&blocks[allocation.cached_prefix_len..prefix_len]); + let aligned = + (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize; + if aligned > 0 { + let prefix_len = self + .cache_blocks + .insert( + &prefill_tokens[..aligned], + &blocks[..aligned / self.block_size as usize], + ) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + if prefix_len > allocation.cached_prefix_len { + self.free_blocks.extend( + &blocks[allocation.cached_prefix_len / self.block_size as usize + ..prefix_len / self.block_size as usize], + ); + } + } } // Free non-prefill blocks. - self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + self.free_blocks + .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]); } else { self.free_blocks.extend(blocks); } @@ -204,17 +231,14 @@ pub struct RadixTrie { /// Time as a monotonically increating counter to avoid the system /// call that a real time lookup would require. time: u64, -} -impl Default for RadixTrie { - fn default() -> Self { - Self::new() - } + /// All blocks need to be aligned with this + block_size: usize, } impl RadixTrie { /// Construct a new radix trie. - pub fn new() -> Self { + pub fn new(block_size: usize) -> Self { let root = TrieNode::new(vec![], vec![], 0, None); let mut nodes = SlotMap::new(); let root = nodes.insert(root); @@ -223,13 +247,14 @@ impl RadixTrie { nodes, root, time: 0, + block_size, } } /// Find the prefix of the given tokens. /// /// The blocks corresponding to the part of the prefix that could be found - /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. /// Returns the identifier of the trie node that contains the longest /// prefix. The node identifier can be used by callers to e.g. increase its /// reference count. @@ -247,8 +272,9 @@ impl RadixTrie { if let Some(&child_id) = node.children.get(&key[0]) { self.update_access_time(child_id); let child = self.nodes.get(child_id).expect("Invalid child identifier"); - let shared_prefix_len = child.key.shared_prefix_len(key); - blocks.extend(&child.blocks[..shared_prefix_len]); + let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + assert_eq!(shared_prefix_len % self.block_size, 0); + blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); let key = &key[shared_prefix_len..]; if !key.is_empty() { @@ -349,7 +375,8 @@ impl RadixTrie { /// the first 10 elements of the tree **the blocks are not updated**. pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { self.time += 1; - self.insert_(self.root, tokens, blocks) + let common = self.insert_(self.root, tokens, blocks)?; + Ok(common) } /// Insertion worker. @@ -363,7 +390,7 @@ impl RadixTrie { // the part of the prefix that is already in the trie to detect // mismatches. - if tokens.len() != blocks.len() { + if tokens.len() != blocks.len() * self.block_size { return Err(TrieError::BlockTokenCountMismatch); } @@ -374,10 +401,10 @@ impl RadixTrie { .get_mut(child_id) // Unwrap here, since failure is a bug. .expect("Child node does not exist"); - let shared_prefix_len = child.key.shared_prefix_len(tokens); + let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size); // We are done, the prefix is already in the trie. - if shared_prefix_len == tokens.len() { + if shared_prefix_len == tokens.len() || shared_prefix_len == 0 { return Ok(shared_prefix_len); } @@ -387,7 +414,7 @@ impl RadixTrie { + self.insert_( child_id, &tokens[shared_prefix_len..], - &blocks[shared_prefix_len..], + &blocks[shared_prefix_len / self.block_size..], )?); } @@ -396,7 +423,7 @@ impl RadixTrie { // remainder of the prefix into the node again let child_id = self.split_node(child_id, shared_prefix_len); let key = &tokens[shared_prefix_len..]; - let blocks = &blocks[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len / self.block_size..]; Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) } else { self.add_node(node_id, tokens, blocks); @@ -550,34 +577,53 @@ impl TrieNode { } } -/// Helper trait to get the length of the shared prefix of two sequences. -trait SharedPrefixLen { - fn shared_prefix_len(&self, other: &Self) -> usize; -} - -impl SharedPrefixLen for [T] -where - T: PartialEq, -{ - fn shared_prefix_len(&self, other: &Self) -> usize { - self.iter().zip(other).take_while(|(a, b)| a == b).count() - } +fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { + let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + (full / block_size) * block_size } #[cfg(test)] mod tests { use std::sync::Arc; - use crate::block_allocator::Allocator; + use super::*; - use super::RadixAllocator; + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_block_size_non_aligned() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 2); + } #[test] fn allocator_reuses_prefixes() { let mut cache = RadixAllocator::new(1, 12, None); let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); - assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.blocks, allocation.slots); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); @@ -666,7 +712,7 @@ mod tests { #[test] fn trie_insertions_have_correct_prefix_len() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); @@ -687,9 +733,33 @@ mod tests { ); } + #[test] + fn trie_insertions_block_size() { + let mut trie = RadixTrie::new(2); + + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0); + + // Already exists. + // But needs to be block_size aligned + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3]) + .unwrap(), + 2 + ); + } + #[test] fn trie_get_returns_correct_blocks() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); @@ -723,7 +793,7 @@ mod tests { #[test] fn trie_evict_removes_correct_blocks() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) .unwrap(); diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 7494d5b5..789c7b51 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -148,6 +148,7 @@ async fn prefill( }), inputs: sequence.clone(), truncate: sequence_length, + add_special_tokens: true, parameters: Some(parameters.clone()), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: decode_length, diff --git a/flake.lock b/flake.lock index 14011768..c0a696b1 100644 --- a/flake.lock +++ b/flake.lock @@ -835,11 +835,11 @@ ] }, "locked": { - "lastModified": 1724206841, - "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=", + "lastModified": 1724638882, + "narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61", + "rev": "19b70f147b9c67a759e35824b241f1ed92e46694", "type": "github" }, "original": { diff --git a/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json index 8631c076..5553e17d 100644 --- a/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json +++ b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas", + "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1716553098, + "created": 1724792495, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "2.0.5-dev0-native", + "object": "chat.completion", + "system_fingerprint": "2.2.1-dev0-native", "usage": { "completion_tokens": 100, - "prompt_tokens": 62, - "total_tokens": 162 + "prompt_tokens": 61, + "total_tokens": 161 } } diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json index d87071cf..e7fb5740 100644 --- a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json @@ -8,11 +8,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -23,11 +23,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -38,11 +38,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -53,11 +53,11 @@ "text": "hd" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -68,11 +68,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -83,11 +83,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -98,11 +98,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -113,11 +113,11 @@ "text": "aho" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -128,11 +128,11 @@ "text": "2" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -143,11 +143,11 @@ "text": "2" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -158,11 +158,11 @@ "text": "2" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -173,11 +173,11 @@ "text": "ima" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -188,11 +188,11 @@ "text": "." } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -203,11 +203,11 @@ "text": "." } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -218,11 +218,11 @@ "text": "." } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -233,11 +233,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -248,11 +248,11 @@ "text": " Sarah" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -263,11 +263,11 @@ "text": " Yes" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -278,11 +278,11 @@ "text": " And" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -293,11 +293,11 @@ "text": "i" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -308,11 +308,11 @@ "text": "'" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -323,11 +323,11 @@ "text": "," } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -338,11 +338,11 @@ "text": " what" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -353,11 +353,11 @@ "text": "'" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -368,11 +368,11 @@ "text": "s" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -383,11 +383,11 @@ "text": " Moh" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -398,11 +398,11 @@ "text": " is" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -413,11 +413,11 @@ "text": "m" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -428,11 +428,11 @@ "text": " Room" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -443,11 +443,11 @@ "text": "s" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -458,11 +458,11 @@ "text": " the" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -473,11 +473,11 @@ "text": " tired" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -488,11 +488,11 @@ "text": ":" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -503,11 +503,11 @@ "text": "'" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -518,11 +518,11 @@ "text": " capital" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -530,73 +530,73 @@ "finish_reason": "", "index": 3, "logprobs": null, - "text": " of" + "text": "," } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 0, "logprobs": null, "text": " She" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 1, "logprobs": null, "text": " scale" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 2, "logprobs": null, "text": " of" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 3, "logprobs": null, - "text": " being" + "text": " its" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json index 03f90367..732b0c49 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json @@ -16,7 +16,7 @@ }, { "id": 3102, - "logprob": -11.1875, + "logprob": -11.25, "text": " request" } ], @@ -24,66 +24,66 @@ "tokens": [ { "id": 185, - "logprob": -1.5546875, + "logprob": -1.546875, "special": false, "text": "\n" }, { "id": 549, - "logprob": -2.84375, + "logprob": -2.859375, "special": false, "text": "The" }, { "id": 1727, - "logprob": -2.34375, + "logprob": -2.484375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.8359375, + "logprob": -0.83203125, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.0859375, + "logprob": -1.1484375, "special": false, "text": " is" }, { - "id": 254, - "logprob": -1.5390625, + "id": 245, + "logprob": -1.578125, "special": false, - "text": " the" + "text": " a" }, { - "id": 1022, - "logprob": -1.1875, + "id": 3412, + "logprob": -2.578125, "special": false, - "text": " first" + "text": " document" }, { - "id": 3458, - "logprob": -0.35546875, + "id": 344, + "logprob": -1.125, "special": false, - "text": " step" + "text": " that" }, { - "id": 279, - "logprob": -0.8828125, + "id": 317, + "logprob": -1.6953125, "special": false, - "text": " in" + "text": " is" }, { - "id": 254, - "logprob": -0.71484375, + "id": 1222, + "logprob": -1.71875, "special": false, - "text": " the" + "text": " used" } ], "top_tokens": null }, - "generated_text": "\nThe test request is the first step in the" + "generated_text": "\nThe test request is a document that is used" } diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json index e365829a..f1eeab25 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -37,56 +37,56 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + "logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" }, { "details": { @@ -126,56 +126,56 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + "logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" }, { "details": { @@ -215,56 +215,56 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + "logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" }, { "details": { @@ -304,55 +304,55 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + "logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json index bf981e4f..e39829ec 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "stop_sequence", + "generated_tokens": 5, "prefill": [ { "id": 128000, @@ -16,7 +16,7 @@ }, { "id": 1715, - "logprob": -10.375, + "logprob": -10.4375, "text": " request" } ], @@ -29,61 +29,31 @@ "text": ":" }, { - "id": 2209, - "logprob": -2.78125, + "id": 923, + "logprob": -2.84375, "special": false, - "text": " Is" + "text": " add" }, { - "id": 279, - "logprob": -0.6328125, + "id": 264, + "logprob": 0.0, "special": false, - "text": " the" - }, - { - "id": 734, - "logprob": -2.703125, - "special": false, - "text": " function" + "text": " a" }, { "id": 330, - "logprob": -0.34179688, + "logprob": -0.31640625, "special": false, "text": " \"" }, { - "id": 4110, - "logprob": -2.359375, + "id": 1985, + "logprob": 0.0, "special": false, - "text": "Create" - }, - { - "id": 7575, - "logprob": -2.1875, - "special": false, - "text": "Process" - }, - { - "id": 1, - "logprob": -0.07910156, - "special": false, - "text": "\"" - }, - { - "id": 304, - "logprob": -0.83203125, - "special": false, - "text": " in" - }, - { - "id": 12468, - "logprob": -1.8203125, - "special": false, - "text": " Win" + "text": "test" } ], "top_tokens": null }, - "generated_text": "Test request: Is the function \"CreateProcess\" in Win" + "generated_text": "Test request: add a \"test" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json index d882b82a..412b19b4 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -16,7 +16,7 @@ }, { "id": 100, - "logprob": -0.38549805, + "logprob": -0.38305664, "text": "_" }, { @@ -29,7 +29,7 @@ "tokens": [ { "id": 2284, - "logprob": -0.31323242, + "logprob": -0.296875, "special": false, "text": "():" }, @@ -59,19 +59,19 @@ }, { "id": 10914, - "logprob": -0.7817383, + "logprob": -0.7734375, "special": false, "text": " World" }, { "id": 16013, - "logprob": -0.6328125, + "logprob": -0.61816406, "special": false, "text": "!\")" }, { "id": 222, - "logprob": -0.0619812, + "logprob": -0.054870605, "special": false, "text": "\n" }, @@ -83,7 +83,7 @@ }, { "id": 610, - "logprob": -0.4086914, + "logprob": -0.4152832, "special": false, "text": "def" }, @@ -113,7 +113,7 @@ }, { "id": 444, - "logprob": -0.21826172, + "logprob": -0.21618652, "special": false, "text": "name" }, @@ -173,7 +173,7 @@ }, { "id": 11571, - "logprob": -0.10021973, + "logprob": -0.08892822, "special": false, "text": "!\"" }, diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json index 1fad0b96..dab437b9 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json @@ -30,19 +30,19 @@ }, { "id": 264, - "logprob": -0.37573242, + "logprob": -0.38061523, "special": false, "text": " a" }, { "id": 633, - "logprob": -0.09161377, + "logprob": -0.09301758, "special": false, "text": " new" }, { "id": 4480, - "logprob": -0.26171875, + "logprob": -0.26782227, "special": false, "text": " feature" }, @@ -78,7 +78,7 @@ }, { "id": 13, - "logprob": 0.0, + "logprob": -0.10632324, "special": false, "text": "\n" } diff --git a/integration-tests/models/test_chat_llama.py b/integration-tests/models/test_chat_llama.py index 1f7a4a59..7d24add3 100644 --- a/integration-tests/models/test_chat_llama.py +++ b/integration-tests/models/test_chat_llama.py @@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot): print(repr(response.choices[0].message.content)) assert ( response.choices[0].message.content - == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas" + == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C" ) assert response == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9a90a673..8e5c9dcd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -8,7 +8,7 @@ use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; -use std::io::{BufRead, BufReader, Lines}; +use std::io::{BufRead, BufReader}; use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; use std::process::{Child, Command, ExitStatus, Stdio}; @@ -18,12 +18,103 @@ use std::sync::{mpsc, Arc}; use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; -use std::{fs, io}; +use std::{ + fs, io, + io::{Read, Write}, +}; use thiserror::Error; use tracing_subscriber::{filter::LevelFilter, EnvFilter}; mod env_runtime; +fn get_config( + model_id: &str, + revision: &Option, +) -> Result> { + let mut path = std::path::Path::new(model_id).to_path_buf(); + let model_id = model_id.to_string(); + let filename = if !path.exists() { + // Assume it's a hub id + + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; + let repo = if let Some(ref revision) = revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? + } else { + path.push("config.json"); + path + }; + + let content = std::fs::read_to_string(filename)?; + let config: RawConfig = serde_json::from_str(&content)?; + + let config: Config = config.into(); + Ok(config) +} + +fn resolve_attention(config: &Option, lora_adapters: &Option) -> (String, String) { + let mut prefix_caching: Option = std::env::var("USE_PREFIX_CACHING").ok(); + let mut attention: Option = std::env::var("ATTENTION").ok(); + if let Some(config) = config { + if prefix_caching.is_none() { + if config.vision_config.is_some() { + tracing::info!("Disabling prefix caching because of VLM model"); + prefix_caching = Some("0".to_string()); + } else if config.is_encoder_decoder { + tracing::info!("Disabling prefix caching because of seq2seq model"); + prefix_caching = Some("0".to_string()); + } + } + match config.head_dim { + Some(h) if h == 64 || h == 128 || h == 256 => { + if lora_adapters.is_some() && prefix_caching.is_none() { + tracing::info!("Disabling prefix caching because of lora adapters"); + prefix_caching = Some("0".to_string()); + } + match config.model_type.as_deref() { + Some("gemma2") | Some("falcon") | Some("deepseek_v2") => { + // Required because gemma2 needs bfloat16 which is not supported by + // flashinfer ? + if attention.is_none() { + tracing::info!( + "Forcing flash decoding because model {} requires it", + config.model_type.as_ref().unwrap() + ); + attention = Some("flashdecoding".to_string()); + } + } + Some("t5") => {} + _ => {} + } + } + _ => { + if attention.is_none() { + tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching"); + attention = Some("flashdecoding".to_string()); + } + if prefix_caching.is_none() { + prefix_caching = Some("0".to_string()); + } + } + } + } + let prefix_caching = prefix_caching.unwrap_or("true".to_string()); + let attention = attention.unwrap_or("flashinfer".to_string()); + (prefix_caching, attention) +} + #[derive(Deserialize)] struct RawConfig { max_position_embeddings: Option, @@ -31,6 +122,12 @@ struct RawConfig { model_type: Option, max_seq_len: Option, quantization_config: Option, + n_embd: Option, + hidden_size: Option, + num_attention_heads: Option, + head_dim: Option, + vision_config: Option, + is_encoder_decoder: Option, } #[derive(Deserialize)] @@ -38,10 +135,17 @@ struct QuantizationConfig { quant_method: Option, } +#[derive(Deserialize)] +struct VisionConfig {} + #[derive(Deserialize)] struct Config { max_position_embeddings: Option, quantize: Option, + head_dim: Option, + model_type: Option, + vision_config: Option, + is_encoder_decoder: bool, } impl From for Config { @@ -51,9 +155,32 @@ impl From for Config { .or(other.max_seq_len) .or(other.n_positions); let quantize = other.quantization_config.and_then(|q| q.quant_method); + let head_dim = other.head_dim.or_else(|| { + match (other.hidden_size, other.n_embd, other.num_attention_heads) { + (Some(hidden_size), _, Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + // Legacy + (_, Some(hidden_size), Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + _ => None, + } + }); + let model_type = other.model_type; + let vision_config = other.vision_config; + let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false); Config { max_position_embeddings, quantize, + head_dim, + model_type, + vision_config, + is_encoder_decoder, } } } @@ -731,6 +858,7 @@ fn shard_manager( .args(shard_args) .env_clear() .envs(envs) + .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .process_group(0) @@ -752,12 +880,13 @@ fn shard_manager( }; // Redirect STDOUT to the console + let mut pstdin = p.stdin.take().unwrap(); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); //stdout tracing thread thread::spawn(move || { - log_lines(shard_stdout_reader.lines()); + log_lines(shard_stdout_reader); }); // We read stderr in another thread as it seems that lines() can block in some cases let (err_sender, err_receiver) = mpsc::channel(); @@ -766,6 +895,18 @@ fn shard_manager( err_sender.send(line).unwrap_or(()); } }); + // We read stdin in another thread as it seems that lines() can block in some cases + thread::spawn(move || { + let mut stdin = io::stdin(); // We get `Stdin` here. + loop { + let mut buffer = vec![0; 4096]; + if let Ok(n) = stdin.read(&mut buffer) { + if n > 0 { + let _ = pstdin.write_all(&buffer[..n]); + } + } + } + }); let mut ready = false; let start_time = Instant::now(); @@ -872,19 +1013,36 @@ impl PythonLogMessage { } } -impl TryFrom<&String> for PythonLogMessage { +impl TryFrom<&[u8]> for PythonLogMessage { type Error = serde_json::Error; - fn try_from(value: &String) -> Result { - serde_json::from_str::(value) + fn try_from(value: &[u8]) -> Result { + serde_json::from_slice::(value) } } -fn log_lines(lines: Lines) { - for line in lines.map_while(Result::ok) { - match PythonLogMessage::try_from(&line) { - Ok(log) => log.trace(), - Err(_) => tracing::debug!("{line}"), +fn log_lines(mut bufread: BufReader) { + let mut buffer = vec![0u8; 8 * 4096]; + let mut stdout = std::io::stdout(); + loop { + let n = bufread.read(&mut buffer); + if let Ok(n) = n { + if n > 0 { + let mut lines = buffer[..n].split(|i| *i == b'\n').peekable(); + while let Some(line) = lines.next() { + match PythonLogMessage::try_from(line) { + Ok(log) => log.trace(), + // For interactive debugging ? + Err(_) => { + stdout.write_all(line).unwrap(); + if lines.peek().is_some() { + stdout.write_all(b"\n").unwrap(); + } + stdout.flush().unwrap(); + } + } + } + } } } } @@ -1044,7 +1202,7 @@ fn download_convert_model( let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); thread::spawn(move || { - log_lines(download_stdout.lines()); + log_lines(download_stdout); }); let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); @@ -1439,68 +1597,35 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:#?}", args); - let get_max_positions_quantize = - || -> Result<(usize, Option), Box> { - let model_id = args.model_id.clone(); - let mut path = std::path::Path::new(&args.model_id).to_path_buf(); - let filename = if !path.exists() { - // Assume it's a hub id + let config: Option = get_config(&args.model_id, &args.revision).ok(); + let quantize = config.as_ref().and_then(|c| c.quantize); + // Quantization usually means you're even more RAM constrained. + let max_default = 4096; - let api = if let Ok(token) = std::env::var("HF_TOKEN") { - // env variable has precedence over on file token. - ApiBuilder::new().with_token(Some(token)).build()? - } else { - Api::new()? - }; - let repo = if let Some(ref revision) = args.revision { - api.repo(Repo::with_revision( - model_id, - RepoType::Model, - revision.to_string(), - )) - } else { - api.model(model_id) - }; - repo.get("config.json")? - } else { - path.push("config.json"); - path - }; - - let content = std::fs::read_to_string(filename)?; - let config: RawConfig = serde_json::from_str(&content)?; - - if config.model_type == Some("gemma2".to_string()) { - tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("ATTENTION", "flashdecoding"); - } - let config: Config = config.into(); - let quantize = config.quantize; - - // Quantization usually means you're even more RAM constrained. - let max_default = 4096; - - if let Some(max_position_embeddings) = config.max_position_embeddings { - if max_position_embeddings > max_default { - let max = max_position_embeddings; - if args.max_input_tokens.is_none() - && args.max_total_tokens.is_none() - && args.max_batch_prefill_tokens.is_none() - { - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); - } - Ok((max_default, quantize)) - } else { - Ok((max_position_embeddings, quantize)) + let max_position_embeddings = if let Some(config) = &config { + if let Some(max_position_embeddings) = config.max_position_embeddings { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); } + max_default } else { - Err(Box::new(LauncherError::ArgumentValidation( - "no max defined".to_string(), - ))) + max_position_embeddings } - }; - let (max_position_embeddings, quantize): (usize, Option) = - get_max_positions_quantize().unwrap_or((4096, None)); + } else { + max_default + } + } else { + max_default + }; + let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters); + tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}"); + std::env::set_var("USE_PREFIX_CACHING", prefix_caching); + std::env::set_var("ATTENTION", attention); let max_input_tokens = { match (args.max_input_tokens, args.max_input_length) { diff --git a/load_tests/common.js b/load_tests/common.js index e0a10595..d890bf67 100644 --- a/load_tests/common.js +++ b/load_tests/common.js @@ -33,13 +33,13 @@ export function get_options() { // rate: 20, // timeUnit: '1s', // }, - load_test: { - executor: 'constant-arrival-rate', - duration: '60s', - preAllocatedVUs: 100, - rate: 1, - timeUnit: '1s', - }, + // load_test: { + // executor: 'constant-arrival-rate', + // duration: '60s', + // preAllocatedVUs: 100, + // rate: 1, + // timeUnit: '1s', + // }, // breakpoint: { // executor: 'ramping-arrival-rate', //Assure load increase if the system slows // preAllocatedVUs: 300, @@ -47,12 +47,12 @@ export function get_options() { // { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load // ], // }, - // throughput: { - // executor: 'shared-iterations', - // vus: 100, - // iterations: 200, - // maxDuration: '40s', - // }, + throughput: { + executor: 'shared-iterations', + vus: 100, + iterations: 200, + maxDuration: '40s', + }, }, }; } diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 68eea7ac..34894bda 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -137,6 +137,8 @@ message Request { optional string adapter_id = 11; /// Prefix length that can be retrieved from the KV cache. uint32 prefix_len = 12; + /// Context truncation + bool add_special_tokens = 13; } message Batch { diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 81c0d38f..240282d9 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -120,10 +120,11 @@ impl Infer { ) -> Result, InferError> { // Tokenize request let inputs = request.inputs; + let add_special_tokens = request.add_special_tokens; let truncate = request.parameters.truncate; let encoding = self .validation - .tokenize(inputs, truncate) + .tokenize(inputs, add_special_tokens, truncate) .await .map_err(|err| { tracing::error!("Tokenization {err}"); diff --git a/router/src/lib.rs b/router/src/lib.rs index ce4f7c46..979f6dd1 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -22,6 +22,16 @@ pub enum Attention { FlashInfer, } +impl Attention { + pub fn block_size(&self) -> u32 { + match self { + Attention::FlashDecoding => 256, + Attention::FlashInfer => 1, + Attention::Paged => 16, + } + } +} + #[derive(Debug)] pub struct ParseError; @@ -1072,6 +1082,16 @@ pub(crate) struct GenerateRequest { pub inputs: String, #[serde(default = "default_parameters")] pub parameters: GenerateParameters, + + /// This is used internally because some requests + /// already contain the templated input therefore + /// we shouldn't add the special tokens. + #[serde(default = "default_true", skip)] + pub add_special_tokens: bool, +} + +fn default_true() -> bool { + true } #[derive(Clone, Debug, Deserialize, ToSchema)] @@ -1089,6 +1109,7 @@ impl From for GenerateRequest { fn from(req: CompatGenerateRequest) -> Self { Self { inputs: req.inputs, + add_special_tokens: true, parameters: req.parameters, } } diff --git a/router/src/server.rs b/router/src/server.rs index 8ebd1a33..f273a786 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -158,6 +158,7 @@ async fn get_chat_tokenize( let generate_request = GenerateRequest { inputs, + add_special_tokens: false, parameters: GenerateParameters { best_of: None, temperature, @@ -754,6 +755,7 @@ async fn completions( .iter() .map(|prompt| GenerateRequest { inputs: prompt.to_string(), + add_special_tokens: true, parameters: GenerateParameters { best_of: None, temperature, @@ -1180,6 +1182,7 @@ async fn chat_completions( // build the request passing some parameters let generate_request = GenerateRequest { inputs: inputs.to_string(), + add_special_tokens: false, parameters: GenerateParameters { best_of: None, temperature, @@ -1386,6 +1389,7 @@ async fn vertex_compatibility( .map(|instance| { let generate_request = GenerateRequest { inputs: instance.inputs.clone(), + add_special_tokens: true, parameters: GenerateParameters { do_sample: true, max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), diff --git a/router/src/validation.rs b/router/src/validation.rs index 0024723c..92491d88 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -95,6 +95,7 @@ impl Validation { pub async fn tokenize( &self, inputs: String, + add_special_tokens: bool, truncate: Option, ) -> Result)>, ValidationError> { // If we have a fast tokenizer @@ -104,7 +105,11 @@ impl Validation { // Send request to the background validation task // Unwrap is safe here sender - .send(((inputs, truncate), response_sender, Span::current())) + .send(( + (inputs, add_special_tokens, truncate), + response_sender, + Span::current(), + )) .unwrap(); // Await on response channel @@ -121,11 +126,15 @@ impl Validation { async fn validate_input( &self, inputs: String, + add_special_tokens: bool, truncate: Option, max_new_tokens: Option, ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer - if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { + if let Some((encoding, inputs)) = self + .tokenize(inputs.clone(), add_special_tokens, truncate) + .await? + { // Create response channel let input_length = if let Some(truncate) = truncate { std::cmp::min(encoding.len(), truncate) @@ -158,7 +167,8 @@ impl Validation { )); } - let input_ids = encoding.get_ids()[..input_length].to_owned(); + let ids = encoding.get_ids(); + let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned(); metrics::histogram!("tgi_request_input_length").record(input_length as f64); Ok((inputs, Some(input_ids), input_length, max_new_tokens)) @@ -324,7 +334,12 @@ impl Validation { // Validate inputs let (inputs, input_ids, input_length, max_new_tokens) = self - .validate_input(request.inputs, truncate, max_new_tokens) + .validate_input( + request.inputs, + request.add_special_tokens, + truncate, + max_new_tokens, + ) .await?; // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar @@ -401,6 +416,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, input_ids: input_ids.map(Arc::new), + add_special_tokens: request.add_special_tokens, decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, @@ -449,12 +465,15 @@ fn tokenizer_worker( mut receiver: mpsc::UnboundedReceiver, ) { // Loop over requests - while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { + while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = + receiver.blocking_recv() + { parent_span.in_scope(|| { response_tx .send(prepare_input( inputs, truncate, + add_special_tokens, &tokenizer, config.as_ref(), preprocessor_config.as_ref(), @@ -591,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String { fn prepare_input( inputs: String, _truncate: Option, + add_special_tokens: bool, tokenizer: &Tokenizer, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, @@ -628,14 +648,14 @@ fn prepare_input( // Get the number of tokens in the input let encoding = tokenizer - .encode(tokenizer_query, true) + .encode(tokenizer_query, add_special_tokens) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; Ok((encoding, input_chunks)) } type TokenizerRequest = ( - (String, Option), + (String, bool, Option), oneshot::Sender), ValidationError>>, Span, ); @@ -720,6 +740,7 @@ pub struct ValidGenerateRequest { pub input_ids: Option>>, pub input_length: u32, pub truncate: u32, + pub add_special_tokens: bool, pub decoder_input_details: bool, pub parameters: ValidParameters, pub stopping_parameters: ValidStoppingParameters, @@ -826,7 +847,7 @@ mod tests { let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), None, Some(max_new_tokens)) + .validate_input("Hello".to_string(), true, None, Some(max_new_tokens)) .await { // Err(ValidationError::MaxNewTokens(1, 10)) => (), @@ -861,7 +882,7 @@ mod tests { let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), None, Some(max_new_tokens)) + .validate_input("Hello".to_string(), true, None, Some(max_new_tokens)) .await { Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), @@ -895,6 +916,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { best_of: Some(2), do_sample: false, @@ -934,6 +956,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_p: Some(1.0), max_new_tokens: Some(5), @@ -949,6 +972,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_p: Some(0.99), max_new_tokens: Some(5), @@ -964,6 +988,7 @@ mod tests { let valid_request = validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_p: None, max_new_tokens: Some(5), @@ -1002,6 +1027,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: Some(5), max_new_tokens: Some(5), @@ -1017,6 +1043,7 @@ mod tests { validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: Some(4), max_new_tokens: Some(5), @@ -1029,6 +1056,7 @@ mod tests { validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: Some(0), max_new_tokens: Some(5), @@ -1041,6 +1069,7 @@ mod tests { let valid_request = validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: None, max_new_tokens: Some(5), @@ -1089,6 +1118,7 @@ mod tests { let chunks = match validation .tokenize( format!("test![](data:image/gif;base64,{})", PIXEL_GIF), + true, None, ) .await @@ -1148,6 +1178,7 @@ mod tests { "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})", PIXEL_GIF, PIXEL_GIF ), + true, None, ) .await diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 8c77896e..f392b161 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] # Released on: June 13, 2024 # https://releases.rs/docs/1.79.0/ -channel = "1.79.0" +channel = "1.80.0" components = ["rustfmt", "clippy"] diff --git a/server/Makefile b/server/Makefile index 51ea8b32..9338b299 100644 --- a/server/Makefile +++ b/server/Makefile @@ -7,6 +7,7 @@ include Makefile-selective-scan include Makefile-lorax-punica include Makefile-fbgemm include Makefile-exllamav2 +include Makefile-flashinfer unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-flashinfer b/server/Makefile-flashinfer new file mode 100644 index 00000000..3abb0491 --- /dev/null +++ b/server/Makefile-flashinfer @@ -0,0 +1,2 @@ +install-flashinfer: + pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4 diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 16d2c408..d99771f8 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,7 +1,10 @@ import pytest - +import os from text_generation_server.pb import generate_pb2 +os.environ["USE_PREFIX_CACHING"] = "1" +os.environ["ATTENTION"] = "flashinfer" + @pytest.fixture def default_pb_parameters(): diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index f162230c..855f4dfc 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -9,26 +9,46 @@ if ATTENTION in {"flashinfer", "flashdecoding"}: @dataclass class Seqlen: input_lengths: torch.Tensor + prefix_lengths: torch.Tensor cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor] + max_q: int + max_k: int - def __init__(self, input_lengths): + def __init__( + self, + input_lengths, + prefix_lengths, + cu_seqlen_q=None, + max_q=None, + max_k=None, + ): self.input_lengths = input_lengths + self.prefix_lengths = prefix_lengths device = self.input_lengths.device shape = self.input_lengths.shape - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + max_q = 1 + else: + assert max_q is not None + assert max_k is not None cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + # cuda graphs don't like this and this is necessary to clamp within mistral # Although FA2 might not want the clamping # cu_seqlen_k[0] = 0 - torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:]) + total = self.input_lengths + self.prefix_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) self.cu_seqlen_q = cu_seqlen_q self.cu_seqlen_k = cu_seqlen_k + self.max_q = max_q + self.max_k = max_k def clamp(self, max): # Flash decoding doesn't need to clamp @@ -39,6 +59,11 @@ else: @dataclass class Seqlen: input_lengths: torch.Tensor + prefix_lengths: torch.Tensor + cu_seqlen_q: torch.Tensor + max_q: int + max_k: int def clamp(self, max): + raise NotImplementedError("Not implemented seqlen for paged") return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index b3b7ea4f..4b588b5c 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -222,18 +222,15 @@ if ATTENTION == "flashinfer": def attention( q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - cu_seqlens, - max_s, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, softcap=0.0, ): - assert window_size_left == -1, "Windowing is not supported with flash infer" from text_generation_server.layers.attention.flashinfer import ( prefill_with_paged_kv_state, ) @@ -244,18 +241,17 @@ if ATTENTION == "flashinfer": paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, + window_left=window_size_left, ) elif V2: def attention( q, - k, - v, key_cache: torch.Tensor, value_cache: torch.Tensor, - cu_seqlens, - max_s, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, @@ -266,17 +262,17 @@ elif V2: raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, None, None, + block_tables, None, - None, - max_s, - max_s, + seqlen.max_q, + seqlen.max_k, 0.0, softmax_scale, False, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 4fa9e66d..e03cc30d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -497,15 +497,14 @@ def get_model( else -1 ) - should_use_sliding_window = ( - sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING + use_sliding_window = sliding_window is not None and sliding_window != -1 + needs_sliding_window = ( + max_input_tokens is not None and max_input_tokens > sliding_window ) - - if should_use_sliding_window: - if max_input_tokens is not None and max_input_tokens > sliding_window: - raise ValueError( - f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." - ) + if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING: + raise ValueError( + f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." + ) if model_type == DEEPSEEK_V2: if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 1eb8c6c3..fe19180a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( @@ -264,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -296,12 +297,10 @@ class FlashCohereAttention(torch.nn.Module): # flash attention attn_output = attention( query, - key, - value, kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -313,7 +312,7 @@ class FlashCohereAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -388,7 +387,7 @@ class FlashCohereLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -402,7 +401,7 @@ class FlashCohereLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -454,7 +453,7 @@ class FlashCohereModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: torch.Tensor, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -477,7 +476,7 @@ class FlashCohereModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -518,7 +517,7 @@ class FlashCohereForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -531,7 +530,7 @@ class FlashCohereForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index fc0dca5b..b82b5473 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( FastLinear, @@ -309,7 +310,7 @@ class DbrxAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -335,12 +336,10 @@ class DbrxAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -352,7 +351,7 @@ class DbrxAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -389,7 +388,7 @@ class DbrxNormAttentionNorm(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -403,7 +402,7 @@ class DbrxNormAttentionNorm(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -622,7 +621,7 @@ class DbrxLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): # Self Attention @@ -635,7 +634,7 @@ class DbrxLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -679,7 +678,7 @@ class DbrxModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -701,7 +700,7 @@ class DbrxModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -734,7 +733,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -747,7 +746,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index b25becd5..0585b40e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -29,8 +29,8 @@ from text_generation_server.layers.attention import ( attention, paged_attention, reshape_and_cache, + Seqlen, ) -from text_generation_server.layers.attention.common import Seqlen from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.import_utils import SYSTEM @@ -298,7 +298,7 @@ class DeepseekV2Attention(torch.nn.Module): kv_cache: Tuple[torch.Tensor, torch.Tensor], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: Seqlen, + seqlen: Seqlen, max_s: int, ): if self.q_lora_rank is None: @@ -363,12 +363,10 @@ class DeepseekV2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - key, - value, kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -380,7 +378,7 @@ class DeepseekV2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -666,7 +664,7 @@ class DeepseekV2Layer(nn.Module): kv_cache, block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: Seqlen, + seqlen: Seqlen, max_s: int, ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -680,7 +678,7 @@ class DeepseekV2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -729,7 +727,7 @@ class DeepseekV2Model(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -751,7 +749,7 @@ class DeepseekV2Model(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -781,7 +779,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -794,7 +792,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index faf0f325..d16e805f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, causal=self.causal, window_size_left=self.window_size, @@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, softcap=self.softcap, ) @@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 33738a59..34be4cb8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, causal=self.causal, ) @@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d30b5a0a..403fa908 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): query, key, value = self.query_key_value(hidden_states).split( @@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - key, - value, kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): residual = hidden_states @@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index eb667384..35ab2791 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): query, key, value = self.query_key_value(hidden_states).split( @@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module): # flash attention attn_output = attention( query, - key, - value, - cu_seqlen_prefill, - max_s, + kv_cache[0], + kv_cache[1], + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: @@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices=prefill_cache_indices, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3253d2dc..5b228f9f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -32,6 +32,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 5a150267..30ca3faf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -31,6 +31,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -356,7 +355,7 @@ class MistralLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -372,7 +371,7 @@ class MistralLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index ad426ffe..c5d60af1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -35,6 +35,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( FastLinear, @@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -498,7 +497,7 @@ class MixtralLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -513,7 +512,7 @@ class MixtralLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, @@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b684e035..fda648f9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -31,6 +31,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module): # flash attention attn_output = attention( qkv[:, 0], - qkv[:, 1], - qkv[:, 2], kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): if self.use_parallel_residual: @@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) @@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index e08a2aad..d044b492 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -19,6 +19,7 @@ from torch import nn from typing import Optional, List, Tuple from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear +from text_generation_server.layers.attention import Seqlen from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index efe27c13..37adb8be 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -10,6 +10,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): # Compute query, key, value and split @@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module): if cu_seqlen_prefill is not None: attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 879b8abd..5aac28a3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, @@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index c72a9b90..1c55dd91 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -19,6 +19,7 @@ from text_generation_server.layers.attention import ( attention, paged_attention, reshape_and_cache, + Seqlen, ) @@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -343,7 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -429,7 +428,7 @@ class FlashRWLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): if self.parallel_attn: @@ -443,7 +442,7 @@ class FlashRWLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -465,7 +464,7 @@ class FlashRWLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -552,7 +551,7 @@ class FlashRWLargeLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): # Layer norm. @@ -567,7 +566,7 @@ class FlashRWLargeLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -628,7 +627,7 @@ class FlashRWModel(FlashRWPreTrainedModel): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) @@ -650,7 +649,7 @@ class FlashRWModel(FlashRWPreTrainedModel): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -680,7 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -693,7 +692,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 109304be..19025c4c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.c_attn(hidden_states) @@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(key_value, dim=1, index=0), - torch.select(key_value, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -373,7 +372,7 @@ class Block(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) @@ -383,7 +382,7 @@ class Block(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 200d4ef0..2f9ecd0d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -209,7 +210,7 @@ class Starcoder2Attention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -240,12 +241,10 @@ class Starcoder2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -258,7 +257,7 @@ class Starcoder2Attention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -381,7 +380,7 @@ class Starcoder2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -396,7 +395,7 @@ class Starcoder2Layer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -449,7 +448,7 @@ class Starcoder2Model(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -473,7 +472,7 @@ class Starcoder2Model(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -521,7 +520,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -534,7 +533,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, @@ -543,7 +542,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 7e4deaf8..a829c374 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -25,6 +25,7 @@ from transformers.activations import ACT2FN from text_generation_server.models.custom_modeling.vlm import ( load_text_model, ) +from text_generation_server.layers.attention import Seqlen from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from text_generation_server.layers import ( @@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 29f5b9c7..32e9d334 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -23,6 +23,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.image_processing_utils import select_best_resolution +from text_generation_server.layers.attention import Seqlen from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index dd4203e0..9a60d06c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,7 +43,7 @@ from text_generation_server.models.globals import ( ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, - PREFIX_CACHING, + TGI_WIGGLE_ROOM, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -189,16 +189,21 @@ class FlashCausalLMBatch(Batch): def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer ): - batch_inputs = [] - max_truncation = 0 + max_length = 0 + all_input_ids = [] + batch_size = 0 for r in requests: - batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, truncation=True, max_length=max_truncation - )["input_ids"] - return batch_tokenized_inputs + batch_size += 1 + inputs = concat_text_chunks(r.input_chunks.chunks) + input_ids = tokenizer( + inputs, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + all_input_ids.append(input_ids) + return all_input_ids @classmethod def from_tokenized( @@ -257,22 +262,15 @@ class FlashCausalLMBatch(Batch): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - tokenized_input = tokenized_input[-r.truncate :] - if ( - tokenized_input[0] == tokenizer.bos_token_id - and tokenized_input[1] == tokenizer.bos_token_id - ): - tokenized_input = tokenized_input[1:] - orig_input_length = len(tokenized_input) - if PREFIX_CACHING: - prefix_len = r.prefix_len - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 - else: - prefix_len = 0 + prefix_len = r.prefix_len + assert ( + prefix_len <= orig_input_length + ), f"Prefix {prefix_len} vs input {orig_input_length}" + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 prefix_ids.append(tokenized_input[:prefix_len]) tokenized_input = tokenized_input[prefix_len:] @@ -998,7 +996,7 @@ class FlashCausalLM(Model): config.sliding_window = None self.num_layers = config.num_hidden_layers - self.num_heads = config.num_attention_heads + self.num_heads = config.num_attention_heads // self.process_group.size() # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -1160,8 +1158,15 @@ class FlashCausalLM(Model): "block_tables": block_tables, "slots": slots, "input_lengths": input_lengths_tensor, + "prefix_lengths": prefix_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1204,7 +1209,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths_, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1213,7 +1218,13 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1221,7 +1232,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths_tensor, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1268,7 +1279,7 @@ class FlashCausalLM(Model): num_blocks = ( # Leave 5% for some wiggle room - int((free_memory * 0.95) // total_cache_size) + int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. + batch_num_blocks ) @@ -1360,18 +1371,26 @@ class FlashCausalLM(Model): # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - input_lengths = Seqlen(input_lengths=input_lengths) + prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.tensor( + [0, seqlen], device=self.device, dtype=torch.int32 + ) + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=1, + max_k=seqlen, + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( input_ids=input_ids, position_ids=position_ids, - cu_seqlen_prefill=torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 - ), + cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, block_tables=None, - input_lengths=input_lengths, + seqlen=seqlen, slots=slots, max_s=seqlen, lm_head_indices=None, @@ -1451,8 +1470,7 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: + if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, @@ -1462,11 +1480,18 @@ class FlashCausalLM(Model): block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=batch.input_lengths, - input_lengths_tensor=input_lengths, + input_lengths_tensor=input_lengths + prefix_lens_tensor, prefix_lens=batch.prefix_lens, prefix_lens_tensor=prefix_lens_tensor, ): - input_lengths = Seqlen(input_lengths=input_lengths) + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1474,7 +1499,7 @@ class FlashCausalLM(Model): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index d5133f5e..6c518c2c 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,19 +5,22 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") +ATTENTION = os.getenv("ATTENTION") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") -if PREFIX_CACHING and ATTENTION != "flashinfer": +if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: raise RuntimeError("Prefix caching is only supported with flashinfer") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None +TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) +assert TGI_WIGGLE_ROOM > 0 +assert TGI_WIGGLE_ROOM < 1 # This is overridden by the cli BLOCK_SIZE: int diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 2ed1a119..d6cb36fa 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM): prefix_lens=batch.prefix_lens, prefix_lens_tensor=prefix_lens_tensor, ): - input_lengths = Seqlen(input_lengths=input_lengths) + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, From d5202c46f7c42ea81f75dbc99f6eb1c5697c6b40 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 29 Aug 2024 10:32:38 -0400 Subject: [PATCH 15/32] feat: add /v1/models endpoint (#2433) * feat: add /v1/models endpoint * feat: add /v1/models endpoint * fix: remove unused type import * fix: revert route typo * fix: update docs with new endpoint * fix: add to redocly ignore and lint --- .redocly.lint-ignore.yaml | 1 + docs/openapi.json | 60 +++++++++++++++++++++++++++++++++++++++ router/src/lib.rs | 28 ++++++++++++++++++ router/src/server.rs | 29 ++++++++++++++++++- 4 files changed, 117 insertions(+), 1 deletion(-) diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml index 382c9ab6..13b80497 100644 --- a/.redocly.lint-ignore.yaml +++ b/.redocly.lint-ignore.yaml @@ -77,3 +77,4 @@ docs/openapi.json: - '#/paths/~1tokenize/post' - '#/paths/~1v1~1chat~1completions/post' - '#/paths/~1v1~1completions/post' + - '#/paths/~1v1~1models/get' diff --git a/docs/openapi.json b/docs/openapi.json index fd64a3ab..691705f2 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -556,6 +556,37 @@ } } } + }, + "/v1/models": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Get model info", + "operationId": "openai_get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ModelInfo" + } + } + } + }, + "404": { + "description": "Model not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } } }, "components": { @@ -1747,6 +1778,35 @@ } ] }, + "ModelInfo": { + "type": "object", + "required": [ + "id", + "object", + "created", + "owned_by" + ], + "properties": { + "created": { + "type": "integer", + "format": "int64", + "example": 1686935002, + "minimum": 0 + }, + "id": { + "type": "string", + "example": "gpt2" + }, + "object": { + "type": "string", + "example": "model" + }, + "owned_by": { + "type": "string", + "example": "openai" + } + } + }, "OutputMessage": { "oneOf": [ { diff --git a/router/src/lib.rs b/router/src/lib.rs index 979f6dd1..a1e1dadf 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1261,6 +1261,34 @@ pub(crate) struct ErrorResponse { pub error_type: String, } +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelInfo { + #[schema(example = "gpt2")] + pub id: String, + #[schema(example = "model")] + pub object: String, + #[schema(example = 1686935002)] + pub created: u64, + #[schema(example = "openai")] + pub owned_by: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelsInfo { + #[schema(example = "list")] + pub object: String, + pub data: Vec, +} + +impl Default for ModelsInfo { + fn default() -> Self { + ModelsInfo { + object: "list".to_string(), + data: Vec::new(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/router/src/server.rs b/router/src/server.rs index f273a786..d3d34215 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -24,6 +24,7 @@ use crate::{ VertexResponse, }; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -116,6 +117,29 @@ async fn get_model_info(info: Extension) -> Json { Json(info.0) } +#[utoipa::path( +get, +tag = "Text Generation Inference", +path = "/v1/models", +responses( +(status = 200, description = "Served model info", body = ModelInfo), +(status = 404, description = "Model not found", body = ErrorResponse), +) +)] +#[instrument(skip(info))] +/// Get model info +async fn openai_get_model_info(info: Extension) -> Json { + Json(ModelsInfo { + data: vec![ModelInfo { + id: info.0.model_id.clone(), + object: "model".to_string(), + created: 0, // TODO: determine how to get this + owned_by: info.0.model_id.clone(), + }], + ..Default::default() + }) +} + #[utoipa::path( post, tag = "Text Generation Inference", @@ -1505,6 +1529,7 @@ chat_completions, completions, tokenize, metrics, +openai_get_model_info, ), components( schemas( @@ -1557,6 +1582,7 @@ ToolCall, Function, FunctionDefinition, ToolChoice, +ModelInfo, ) ), tags( @@ -2250,7 +2276,8 @@ async fn start( .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) - .route("/metrics", get(metrics)); + .route("/metrics", get(metrics)) + .route("/v1/models", get(openai_get_model_info)); // Conditional AWS Sagemaker route let aws_sagemaker_route = if messages_api_enabled { From 9883f3b40e8e76cecc9807c438fc08c11e602b7a Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Thu, 29 Aug 2024 23:42:02 +0800 Subject: [PATCH 16/32] update doc with intel cpu part (#2420) * update doc with intel cpu part Signed-off-by: Wang, Yi A * Apply suggestions from code review we do not use latest ever in documentation, it causes too many issues for users. Release number get update on every release. --------- Signed-off-by: Wang, Yi A Co-authored-by: Nicolas Patry --- docs/source/installation_intel.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md index b3843490..3084a436 100644 --- a/docs/source/installation_intel.md +++ b/docs/source/installation_intel.md @@ -12,7 +12,24 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.2.0-intel \ + ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \ + --model-id $model --cuda-graphs 0 +``` + +# Using TGI with Intel CPUs + +Intel® Extension for PyTorch (IPEX) also provides further optimizations for Intel CPUs. The IPEX provides optimization operations such as flash attention, page attention, Add + LayerNorm, ROPE and more. + +On a server powered by Intel CPU, TGI can be launched with the following command: + +```bash +model=teknium/OpenHermes-2.5-Mistral-7B +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --rm --privileged --cap-add=sys_nice \ + --device=/dev/dri \ + --ipc=host --shm-size 1g --net host -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \ --model-id $model --cuda-graphs 0 ``` From d9fbbaafb046bb423e31edaf9ccf8eecc2d5c33d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 29 Aug 2024 17:44:54 +0200 Subject: [PATCH 17/32] Tied embeddings in MLP speculator. (#2473) * Tied embeddings in MLP speculator. * Fixing the scale_weight when users decide to not use the speculation as much as defined in the config. * Adding scaling support + optimize some ops. --- server/text_generation_server/layers/mlp.py | 120 +++++++++++++++++- .../text_generation_server/models/__init__.py | 5 + 2 files changed, 118 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/layers/mlp.py b/server/text_generation_server/layers/mlp.py index f08cb673..d33b41f3 100644 --- a/server/text_generation_server/layers/mlp.py +++ b/server/text_generation_server/layers/mlp.py @@ -45,12 +45,107 @@ class MLPSpeculatorLayerNorm(nn.Module): return x +INV_SQRT2 = 2**-0.5 + + +def simple_norm(x: torch.Tensor, eps=1e-06): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps) + x = xf.type_as(x) + return x * INV_SQRT2 + + +class MLPSpeculatorModelTied(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.config = config + self.n_predict = get_speculate() + self.hidden_size = config.hidden_size + + self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights) + self.proj0 = FastLinear.load( + config, + prefix=f"{prefix}.proj.0", + weights=weights, + bias=False, + ) + self.proj1 = FastLinear.load( + config, + prefix=f"{prefix}.proj.1", + weights=weights, + bias=False, + ) + self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False) + self.ln = MLPSpeculatorLayerNorm( + prefix=f"{prefix}.ln.0", + config=config, + weights=weights, + ) + + # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation + self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1 + self.activation = nn.GELU() + self.vsize = config.vocab_size + self.inner_dim = config.speculator_config["inner_dim"] + self.top_k_tokens_per_head = [1] * self.n_predict + self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt( + self.inner_dim / 2 + ) + self.emb.weight *= self.emb_weight + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + ): + top_k_tokens_per_head = self.top_k_tokens_per_head + + # k indicates # of candidates + # h indicates # of generated tokens + state = hidden_states + b = state.size(0) + ind = input_ids.unsqueeze(0) + all_probs = torch.empty( + b, self.n_predict, self.vsize, device=state.device + ) # b k h v + assert ( + len(top_k_tokens_per_head) == self.n_predict + ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)" + for i in range(self.n_predict): + # Project and predict + z = self.emb(ind) + # z = z.mul(self.emb_weight) # b k d + if i == 0: + state = self.proj0(state) * self.state_weight + z + else: + state = self.proj1(state) * self.state_weight + z + state = self.activation(self.ln(state)) # b k d + probs = F.log_softmax(self.head(state), dim=-1) # b k v + _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k' + + # Update candidate set with new predictions + + # Update distribution set with new logits + all_probs[:, i] = probs.exp() + + # Update state, log_probs and ind for new predictions + state = state.unsqueeze(2).expand( + -1, -1, top_k_tokens_per_head[i], -1 + ) # b k k' d + state = state.reshape(-1, b, state.size(3)) # b kk' d + ind = preds.view(-1, b) # b kk' + + speculative_logits = all_probs + return speculative_logits + + class MLPSpeculatorModel(torch.nn.Module): def __init__(self, config, prefix, weights): super().__init__() self.config = config self.n_predict = get_speculate() self.hidden_size = config.hidden_size + self.emb = nn.ModuleList( [ TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) @@ -84,13 +179,15 @@ class MLPSpeculatorModel(torch.nn.Module): ) # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation - self.state_weight = 0.5 ** (0.5 / self.n_predict) - self.emb_weight = math.sqrt(1 - self.state_weight**2) + self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1 self.activation = nn.GELU() - # TODO self.vsize = config.vocab_size self.inner_dim = config.speculator_config["inner_dim"] self.top_k_tokens_per_head = [1] * self.n_predict + self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt( + self.inner_dim / 2 + ) + self.emb.weight *= self.emb_weight def forward( self, @@ -113,7 +210,7 @@ class MLPSpeculatorModel(torch.nn.Module): for i in range(self.n_predict): # Project and predict z = self.emb[i](ind) - z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d + # z = z.mul(self.emb_weight) # b k d state = self.proj[i](state) * self.state_weight + z state = self.activation(self.ln[i](state)) # b k d probs = F.log_softmax(self.head[i](state), dim=-1) # b k v @@ -136,10 +233,11 @@ class MLPSpeculatorModel(torch.nn.Module): class MLPSpeculatorHead(nn.Module): - def __init__(self, lm_head, mlp_speculator): + def __init__(self, lm_head, mlp_speculator, scale_input: bool): super().__init__() self.lm_head = lm_head self.mlp_speculator = mlp_speculator + self.scale_input = scale_input def forward( self, input: torch.Tensor @@ -150,6 +248,8 @@ class MLPSpeculatorHead(nn.Module): return logits, None input_ids = logits.argmax(dim=-1) + if self.scale_input: + input = simple_norm(input) speculative_logits = self.mlp_speculator(input, input_ids) return logits, speculative_logits @@ -171,6 +271,12 @@ class MLPSpeculatorHead(nn.Module): ) routing[k] = filename - mlp_speculator = MLPSpeculatorModel(config, "speculator", weights) + tie_weights = config.speculator_config.get("tie_weights", False) + if tie_weights: + mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights) + else: + mlp_speculator = MLPSpeculatorModel(config, "speculator", weights) + # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator + scale_input = config.speculator_config.get("scale_input", False) lm_head = TensorParallelHead.load(config, prefix, weights) - return MLPSpeculatorHead(lm_head, mlp_speculator) + return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e03cc30d..52f332c1 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -458,6 +458,11 @@ def get_model( revision=mlp_revision, filename=filename, ) + speculator_dir_path = Path(mlp_speculator_config).parent + # if these are downloaded, they get converted to safetensors + filenames.extend( + [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)] + ) speculator = { "path": Path(mlp_speculator_config).parent, "model_paths": filenames, From e4ab8554803eb1b6f0cf2b546dd5b4e7176bd99d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 2 Sep 2024 09:27:10 +0200 Subject: [PATCH 18/32] nix: improve impure devshell (#2478) - Add some test dependencies. - Install server in venv. - Install Python client in venv. --- flake.nix | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/flake.nix b/flake.nix index 0739c90a..82f9ab1d 100644 --- a/flake.nix +++ b/flake.nix @@ -79,16 +79,22 @@ ] ++ (with python3.pkgs; [ venvShellHook + docker pip ipdb + pytest + pytest-asyncio + syrupy ]); inputsFrom = [ server ]; venvDir = "./.venv"; - postVenv = '' + postVenvCreation = '' unset SOURCE_DATE_EPOCH + ( cd server ; python -m pip install --no-dependencies -e . ) + ( cd clients/python ; python -m pip install --no-dependencies -e . ) ''; postShellHook = '' unset SOURCE_DATE_EPOCH From de2cdeca530f54f5605c71452f9087c1d1a9fd5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 2 Sep 2024 11:31:36 +0200 Subject: [PATCH 19/32] nix: add punica-kernels (#2477) Enables LoRA support. --- flake.lock | 6 +++--- nix/server.nix | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index c0a696b1..c5b3b1ff 100644 --- a/flake.lock +++ b/flake.lock @@ -944,11 +944,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1724784743, - "narHash": "sha256-NdEoWeNwR/ZstYnHaiQWIYZvr7VsrAh7g3+ZHUPrxuI=", + "lastModified": 1725011596, + "narHash": "sha256-zfq8lOXFgJnKxxsqSelHuKUvhxgH3cEmLoAgsOO62Cg=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "c9580c3e39a855246bb87b584bbea1885b44f524", + "rev": "717c2b07e38538abf05237cca65b2d1363c2c9af", "type": "github" }, "original": { diff --git a/nix/server.nix b/nix/server.nix index 6ee088e0..cfdb3f01 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -26,6 +26,7 @@ opentelemetry-instrumentation-grpc, opentelemetry-semantic-conventions, peft, + punica-kernels, safetensors, tokenizers, torch, @@ -92,6 +93,7 @@ buildPythonPackage { opentelemetry-instrumentation-grpc opentelemetry-semantic-conventions peft + punica-kernels safetensors sentencepiece tokenizers From 47d7e344587198ded8a6c89e481b35f4d847fbcf Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 2 Sep 2024 10:00:52 -0400 Subject: [PATCH 20/32] fix: enable chat requests in vertex endpoint (#2481) * fix: enable chat requests in vertex endpoint * feat: avoid unwrap and pre allocate future vec --- router/src/lib.rs | 9 ++- router/src/server.rs | 146 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 124 insertions(+), 31 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index a1e1dadf..d8029c72 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -55,13 +55,20 @@ impl std::str::FromStr for Attention { } #[derive(Clone, Deserialize, ToSchema)] -pub(crate) struct VertexInstance { +pub(crate) struct GenerateVertexInstance { #[schema(example = "What is Deep Learning?")] pub inputs: String, #[schema(nullable = true, default = "null", example = "null")] pub parameters: Option, } +#[derive(Clone, Deserialize, ToSchema)] +#[serde(untagged)] +enum VertexInstance { + Generate(GenerateVertexInstance), + Chat(ChatRequest), +} + #[derive(Deserialize, ToSchema)] pub(crate) struct VertexRequest { #[serde(rename = "instances")] diff --git a/router/src/server.rs b/router/src/server.rs index d3d34215..fac56a77 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,7 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; -use crate::{default_tool_prompt, ChatTokenizeResponse}; +use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -1406,12 +1406,12 @@ async fn vertex_compatibility( )); } - // Process all instances - let predictions = req - .instances - .iter() - .map(|instance| { - let generate_request = GenerateRequest { + // Prepare futures for all instances + let mut futures = Vec::with_capacity(req.instances.len()); + + for instance in req.instances.iter() { + let generate_request = match instance { + VertexInstance::Generate(instance) => GenerateRequest { inputs: instance.inputs.clone(), add_special_tokens: true, parameters: GenerateParameters { @@ -1422,31 +1422,117 @@ async fn vertex_compatibility( decoder_input_details: true, ..Default::default() }, - }; + }, + VertexInstance::Chat(instance) => { + let ChatRequest { + model, + max_tokens, + messages, + seed, + stop, + stream, + tools, + tool_choice, + tool_prompt, + temperature, + response_format, + guideline, + presence_penalty, + frequency_penalty, + top_p, + top_logprobs, + .. + } = instance.clone(); - async { - generate_internal( - Extension(infer.clone()), - compute_type.clone(), - Json(generate_request), - span.clone(), - ) - .await - .map(|(_, Json(generation))| generation.generated_text) - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Incomplete generation".into(), - error_type: "Incomplete generation".into(), - }), - ) - }) + let repetition_penalty = presence_penalty.map(|x| x + 2.0); + let max_new_tokens = max_tokens.or(Some(100)); + let tool_prompt = tool_prompt + .filter(|s| !s.is_empty()) + .unwrap_or_else(default_tool_prompt); + let stop = stop.unwrap_or_default(); + // enable greedy only when temperature is 0 + let (do_sample, temperature) = match temperature { + Some(temperature) if temperature == 0.0 => (false, None), + other => (true, other), + }; + let (inputs, grammar, _using_tools) = match prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + guideline, + messages, + ) { + Ok(result) => result, + Err(e) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: format!("Failed to prepare chat input: {}", e), + error_type: "Input preparation error".to_string(), + }), + )); + } + }; + + GenerateRequest { + inputs: inputs.to_string(), + add_special_tokens: false, + parameters: GenerateParameters { + best_of: None, + temperature, + repetition_penalty, + frequency_penalty, + top_k: None, + top_p, + typical_p: None, + do_sample, + max_new_tokens, + return_full_text: None, + stop, + truncate: None, + watermark: false, + details: true, + decoder_input_details: !stream, + seed, + top_n_tokens: top_logprobs, + grammar, + adapter_id: model.filter(|m| *m != "tgi").map(String::from), + }, + } } - }) - .collect::>() - .try_collect::>() - .await?; + }; + + let infer_clone = infer.clone(); + let compute_type_clone = compute_type.clone(); + let span_clone = span.clone(); + + futures.push(async move { + generate_internal( + Extension(infer_clone), + compute_type_clone, + Json(generate_request), + span_clone, + ) + .await + .map(|(_, Json(generation))| generation.generated_text) + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Incomplete generation".into(), + error_type: "Incomplete generation".into(), + }), + ) + }) + }); + } + + // execute all futures in parallel, collect results, returning early if any error occurs + let results = futures::future::join_all(futures).await; + let predictions: Result, _> = results.into_iter().collect(); + let predictions = predictions?; let response = VertexResponse { predictions }; Ok((HeaderMap::new(), Json(response)).into_response()) From 6cb42f49ae47a117e8f1bdfcdb5cbe42332dc360 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 2 Sep 2024 13:09:06 -0400 Subject: [PATCH 21/32] feat: support lora revisions and qkv_proj weights (#2482) * feat: support lora revisions and qkv_proj weights * fix: add qkv_proj weights to weight test --- server/tests/utils/test_adapter.py | 62 ++++++++++++++++++- .../text_generation_server/models/__init__.py | 3 +- .../custom_modeling/flash_llama_modeling.py | 5 +- .../text_generation_server/utils/adapter.py | 41 +++++++++--- 4 files changed, 100 insertions(+), 11 deletions(-) diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py index cc1b076d..a27c1055 100644 --- a/server/tests/utils/test_adapter.py +++ b/server/tests/utils/test_adapter.py @@ -1,6 +1,54 @@ import pytest from unittest.mock import Mock -from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights +from text_generation_server.utils.adapter import ( + get_attn_weights, + get_mlp_weights, + parse_lora_adapters, + AdapterInfo, +) + + +def test_parse_lora_adapters_empty(): + assert parse_lora_adapters(None) == [] + assert parse_lora_adapters("") == [] + + +def test_parse_lora_adapters_single(): + result = parse_lora_adapters("adapter1") + assert result == [AdapterInfo(id="adapter1", path=None, revision=None)] + + +def test_parse_lora_adapters_with_path(): + result = parse_lora_adapters("adapter1=path/to/adapter1") + assert result == [ + AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None) + ] + + +def test_parse_lora_adapters_with_path_and_revision(): + result = parse_lora_adapters("adapter1=path/to/adapter1@main") + assert result == [ + AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main") + ] + + +def test_parse_lora_adapters_multiple(): + result = parse_lora_adapters( + "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev" + ) + assert result == [ + AdapterInfo(id="adapter1", path=None, revision=None), + AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None), + AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"), + ] + + +def test_parse_lora_adapters_invalid_format(): + try: + parse_lora_adapters("adapter1,invalid=format=test,adapter3") + assert False, "Should have raised ValueError" + except ValueError as e: + assert str(e) == "Invalid LoRA adapter format: invalid=format=test" def test_get_attn_weights(): @@ -22,6 +70,10 @@ def test_get_attn_weights(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, @@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, @@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 52f332c1..fc530b38 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1259,6 +1259,7 @@ def get_model_with_lora_adapters( "gate_proj", "up_proj", "down_proj", + "qkv_proj", ] for layer_name in adapter_layers: @@ -1286,7 +1287,7 @@ def get_model_with_lora_adapters( if len(unused_weight_names) > 0: logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}" ) if adapter_tokenizer is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 5b228f9f..ae981c9a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -66,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id): prefixes = None if config.model_type == "phi3": - prefix = f"{prefix}.qkv_proj" base_layer = TensorParallelColumnLinear.load_qkv( config, - prefix=prefix, + prefix=f"{prefix}.qkv_proj", weights=weights, bias=bias, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) + prefixes = ["qkv_proj"] elif config.model_type == "baichuan": prefix = f"{prefix}.W_pack" base_layer = TensorParallelColumnLinear.load_qkv( @@ -85,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id): num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) + prefixes = [prefix] else: prefixes = ["q_proj", "k_proj", "v_proj"] sizes = [ diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 1db5f77b..b7fc89df 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -3,6 +3,7 @@ # License: Apache License Version 2.0, January 2004 import warnings +import re from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Set, Tuple, Optional, List @@ -27,6 +28,7 @@ BASE_MODEL_ADAPTER_ID = "__base_model__" class AdapterInfo: id: str path: Optional[str] + revision: Optional[str] = None @dataclass @@ -51,11 +53,16 @@ def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]: adapter_list = [] for adapter in lora_adapters.split(","): - parts = adapter.strip().split("=") - if len(parts) == 1: - adapter_list.append(AdapterInfo(id=parts[0], path=None)) - elif len(parts) == 2: - adapter_list.append(AdapterInfo(id=parts[0], path=parts[1])) + adapter = adapter.strip() + if adapter.count("=") > 1 or adapter.count("@") > 1: + raise ValueError(f"Invalid LoRA adapter format: {adapter}") + match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter) + + if match: + adapter_id, path, revision = match.groups() + adapter_list.append( + AdapterInfo(id=adapter_id, path=path, revision=revision) + ) else: raise ValueError(f"Invalid LoRA adapter format: {adapter}") return adapter_list @@ -73,6 +80,7 @@ def load_and_merge_adapters( adapter_info = next(iter(adapter_parameters.adapter_info)) return load_module_map( model_id, + adapter_info.revision, adapter_info.id, adapter_info.path, weight_names, @@ -80,7 +88,13 @@ def load_and_merge_adapters( ) adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) - return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code) + return _load_and_merge( + model_id, + adapter_params.revision, + adapter_params, + weight_names, + trust_remote_code, + ) @dataclass @@ -95,6 +109,7 @@ class AdapterParametersContainer: @lru_cache(maxsize=32) def _load_and_merge( model_id: str, + revision: str, adapter_params: AdapterParametersContainer, weight_names: Tuple[str], trust_remote_code: bool = False, @@ -171,12 +186,12 @@ def check_architectures( @lru_cache(maxsize=128) def load_module_map( model_id: str, + revision: str, adapter_id: str, adapter_path: Optional[str], weight_names: Tuple[str], trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: - revision = "main" adapter_config = LoraConfig.load(adapter_path or adapter_id, None) @@ -191,6 +206,12 @@ def load_module_map( ) ) + # throw an error if no adapter weights are found + if not adapter_filenames: + raise FileNotFoundError( + f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'." + ) + try: adapter_tokenizer = AutoTokenizer.from_pretrained( adapter_config.config_path, @@ -221,6 +242,12 @@ def get_attn_weights(i, layer): value = (f"model.layers.{i}.self_attn.{k}_proj", qkv) weights[key] = value + # also add the qkv_proj weight for the adapter + weights[(i, "qkv_proj")] = ( + f"model.layers.{i}.self_attn.qkv_proj", + qkv, + ) + weights[(i, "o_proj")] = ( f"model.layers.{i}.self_attn.o_proj", layer.self_attn.o_proj, From deec30f89307c7e51ccc609fb2d0ce1e920505b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 5 Sep 2024 15:09:29 +0200 Subject: [PATCH 22/32] hotfix: avoid non-prefilled block use when using prefix caching (#2489) The minimum batch size logic could cause prefix blocks to be deallocated without prefill. The next allocation of the same prefix would then use garbage blocks. --- backends/v3/src/backend.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 05a26370..a47e62dc 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -122,7 +122,7 @@ impl Backend for BackendV3 { #[allow(clippy::too_many_arguments)] pub(crate) async fn batching_task( mut client: ShardedClient, - waiting_served_ratio: f32, + _waiting_served_ratio: f32, max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, @@ -168,7 +168,10 @@ pub(crate) async fn batching_task( None } else { // Minimum batch size - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. + // Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + None }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); From 8b96a18265bec0633e39f5930e81afe3a3bb1463 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 5 Sep 2024 16:11:52 +0200 Subject: [PATCH 23/32] Adding links to Adyen blogpost. (#2492) --- README.md | 2 ++ docs/source/conceptual/streaming.md | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/README.md b/README.md index 803e9172..cf6a30db 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,8 @@ overridden with the `--otlp-service-name` argument ![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) +Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi) + ### Local install You can also opt to install `text-generation-inference` locally. diff --git a/docs/source/conceptual/streaming.md b/docs/source/conceptual/streaming.md index 49c48fa0..f1f37f2a 100644 --- a/docs/source/conceptual/streaming.md +++ b/docs/source/conceptual/streaming.md @@ -1,5 +1,6 @@ # Streaming + ## What is Streaming? Token streaming is the mode in which the server returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation. Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience. @@ -154,3 +155,7 @@ SSEs are different than: * Webhooks: where there is a bi-directional connection. The server can send information to the client, but the client can also send data to the server after the first request. Webhooks are more complex to operate as they don’t only use HTTP. If there are too many requests at the same time, TGI returns an HTTP Error with an `overloaded` error type (`huggingface_hub` returns `OverloadedError`). This allows the client to manage the overloaded server (e.g., it could display a busy error to the user or retry with a new request). To configure the maximum number of concurrent requests, you can specify `--max_concurrent_requests`, allowing clients to handle backpressure. + +## External sources + +Adyen wrote a nice recap of how TGI streaming feature works. [LLM inference at scale with TGI](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi) From e279b38aca90cddc0ab654e18c369d9c462ebc0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 5 Sep 2024 17:06:54 +0200 Subject: [PATCH 24/32] Add two handy gitignores for Nix environments (#2484) --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index f79d8faa..edcc2f89 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp data/ load_tests/*.json server/fbgemmm + +.direnv/ +.venv/ From 5cd8025f1849bd4c13edcf9eb4f72e199e6a5c37 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Thu, 5 Sep 2024 23:41:39 +0800 Subject: [PATCH 25/32] hotfix: fix regression of attention api change in intel platform (#2439) fix regression caused by attention api change. ipex.varlen_attention does not support paged-cache format kv input now. Signed-off-by: Wang, Yi A --- Dockerfile_intel | 3 +++ .../layers/attention/ipex.py | 22 +++++++++---------- .../custom_modeling/flash_cohere_modeling.py | 4 ++-- .../custom_modeling/flash_dbrx_modeling.py | 4 ++-- .../flash_deepseek_v2_modeling.py | 4 ++-- .../custom_modeling/flash_gemma2_modeling.py | 6 ++--- .../custom_modeling/flash_gemma_modeling.py | 6 ++--- .../custom_modeling/flash_gpt2_modeling.py | 6 ++--- .../custom_modeling/flash_gptj_modeling.py | 7 +++--- .../custom_modeling/flash_llama_modeling.py | 4 ++-- .../custom_modeling/flash_mistral_modeling.py | 4 ++-- .../custom_modeling/flash_mixtral_modeling.py | 4 ++-- .../custom_modeling/flash_neox_modeling.py | 6 ++--- .../custom_modeling/flash_phi_modeling.py | 5 +++-- .../custom_modeling/flash_qwen2_modeling.py | 5 +++-- .../custom_modeling/flash_rw_modeling.py | 16 ++++++-------- .../flash_santacoder_modeling.py | 5 +++-- .../flash_starcoder2_modeling.py | 5 +++-- 18 files changed, 60 insertions(+), 56 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index 9af6422c..0cda4d4b 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher FROM ${PLATFORM} AS final +ENV ATTENTION=paged +ENV USE_PREFIX_CACHING=0 +ENV CUDA_GRAPHS=0 ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index d7cf780a..2d1427ae 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False def attention( - q, - k, - v, - cu_seqlens, - max_s, + q: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, @@ -23,13 +23,13 @@ def attention( # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. ipex.llm.functional.varlen_attention( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_q, 0.0, softmax_scale, False, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index fe19180a..374ccb10 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else key, + kv_cache[1] if SYSTEM != "ipex" else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index b82b5473..0dc88098 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0585b40e..f62dfe66 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else key, + kv_cache[1] if SYSTEM != "ipex" else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index d16e805f..e12bff00 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -25,7 +25,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 34be4cb8..77ae4b35 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -25,7 +25,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 403fa908..411c4ce1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -24,7 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else key, + kv_cache[1] if SYSTEM != "ipex" else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 35ab2791..ef071d46 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -24,7 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import ( from text_generation_server.layers.layernorm import ( FastLayerNorm, ) -from text_generation_server.utils.import_utils import SYSTEM def load_attention(config, prefix: str, weights): @@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else key, + kv_cache[1] if SYSTEM != "ipex" else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index ae981c9a..7d639e35 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 30ca3faf..cdd23796 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index c5d60af1..c36e97f6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index fda648f9..454e45eb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -26,7 +26,7 @@ from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module): # flash attention attn_output = attention( qkv[:, 0], - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1], + kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 37adb8be..e2d9bbbc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.import_utils import SYSTEM class PhiConfig(PretrainedConfig): @@ -193,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module): if cu_seqlen_prefill is not None: attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 5aac28a3..999b72e7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.import_utils import SYSTEM def load_attention(config, prefix, weights): @@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 1c55dd91..edc54c09 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -5,7 +5,7 @@ import torch.distributed from torch import nn from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( SpeculativeHead, TensorParallelColumnLinear, @@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=2, index=0), - torch.select(kv, dim=2, index=1), - kv_cache[0], - kv_cache[1], - cu_seqlen_prefill, - max_s, + kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(), + kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(), + seqlen, + block_tables, self.softmax_scale, ) # Decode diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 19025c4c..f97b4409 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from text_generation_server.utils.import_utils import SYSTEM def load_multi_mqa( @@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0], + kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 2f9ecd0d..6aa7fa21 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight +from text_generation_server.utils.import_utils import SYSTEM class Starcoder2Config(PretrainedConfig): @@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0], - kv_cache[1], + kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], + kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, From 0424e27f651bf6df492c9ad0ba7c7e9def60f224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 6 Sep 2024 10:19:04 +0200 Subject: [PATCH 26/32] nix: add pyright/ruff for proper LSP in the impure devshell (#2496) We need this to ensure that pyright/ruff are part of the same interpreter/venv. --- flake.nix | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flake.nix b/flake.nix index 82f9ab1d..58a8e311 100644 --- a/flake.nix +++ b/flake.nix @@ -82,8 +82,10 @@ docker pip ipdb + pyright pytest pytest-asyncio + ruff syrupy ]); From 2eb57a15ecc39bdaebef47dbad30293ac82e6a25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 6 Sep 2024 11:00:52 +0200 Subject: [PATCH 27/32] Fix incompatibility with latest `syrupy` and update in Poetry (#2497) --- integration-tests/conftest.py | 8 +++++++- server/poetry.lock | 14 ++++++++++++++ server/pyproject.toml | 1 + server/requirements_cuda.txt | 6 ++++++ server/requirements_intel.txt | 6 ++++++ server/requirements_rocm.txt | 6 ++++++ 6 files changed, 40 insertions(+), 1 deletion(-) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index a8a77cd2..f58f5fdf 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -64,6 +64,7 @@ class ResponseComparator(JSONSnapshotExtension): self, data, *, + include=None, exclude=None, matcher=None, ): @@ -79,7 +80,12 @@ class ResponseComparator(JSONSnapshotExtension): data = [d.model_dump() for d in data] data = self._filter( - data=data, depth=0, path=(), exclude=exclude, matcher=matcher + data=data, + depth=0, + path=(), + exclude=exclude, + include=include, + matcher=matcher, ) return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n" diff --git a/server/poetry.lock b/server/poetry.lock index fc1a54a3..49276807 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -2945,6 +2945,20 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] +[[package]] +name = "syrupy" +version = "4.7.1" +description = "Pytest Snapshot Test Utility" +optional = false +python-versions = ">=3.8.1" +files = [ + {file = "syrupy-4.7.1-py3-none-any.whl", hash = "sha256:be002267a512a4bedddfae2e026c93df1ea928ae10baadc09640516923376d41"}, + {file = "syrupy-4.7.1.tar.gz", hash = "sha256:f9d4485f3f27d0e5df6ed299cac6fa32eb40a441915d988e82be5a4bdda335c8"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9.0.0" + [[package]] name = "texttable" version = "1.7.0" diff --git a/server/pyproject.toml b/server/pyproject.toml index 57deb1b8..66844402 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -47,6 +47,7 @@ marlin-kernels = [ { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] rich = "^13.7.1" +syrupy = "^4.7.1" [tool.poetry.extras] torch = ["torch"] diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index eb521bd6..71291f7b 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -4,6 +4,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11" filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" @@ -15,6 +16,7 @@ hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" +iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" @@ -31,10 +33,12 @@ opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" +pluggy==1.5.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" +pytest==7.4.4 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" @@ -43,7 +47,9 @@ safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" +syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" +tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index eb521bd6..71291f7b 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -4,6 +4,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11" filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" @@ -15,6 +16,7 @@ hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" +iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" @@ -31,10 +33,12 @@ opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" +pluggy==1.5.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" +pytest==7.4.4 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" @@ -43,7 +47,9 @@ safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" +syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" +tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index eb521bd6..71291f7b 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -4,6 +4,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11" filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" @@ -15,6 +16,7 @@ hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" +iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" @@ -31,10 +33,12 @@ opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" +pluggy==1.5.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" +pytest==7.4.4 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" @@ -43,7 +47,9 @@ safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" +syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" +tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" From 379472c4c2e401b1efd66d7d47edc00b96f5ce14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 6 Sep 2024 11:55:23 +0200 Subject: [PATCH 28/32] radix trie: add assertions (#2491) These should all be cheap assertions. Also: * Fixup some comments. * Delete a `remove` that was done unnecessarily twice. --- backends/v3/src/radix.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index b85be00b..bb6582b0 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -73,14 +73,13 @@ impl Allocator for RadixAllocator { let node_id = self .cache_blocks .find(prefill_tokens.as_slice(), &mut blocks); - // Even if this allocation fails below, we need to increase he - // refcount to ensure that the prefix that was found is not evicted. - node_id } else { self.cache_blocks.root_id() }; + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. self.cache_blocks .incref(prefix_node) .expect("Failed to increment refcount"); @@ -303,6 +302,11 @@ impl RadixTrie { node.ref_count -= 1; if node.ref_count == 0 { + assert!( + node.children.is_empty(), + "Nodes with children must have refcount > 0" + ); + self.leaves.insert((node.last_accessed, node_id)); } @@ -330,7 +334,7 @@ impl RadixTrie { /// Evict `n_blocks` from the trie. /// /// Returns the evicted blocks. When the length is less than `n_blocks`, - /// not enough blocks could beevicted. + /// not enough blocks could be evicted. pub fn evict(&mut self, n_blocks: usize) -> Vec { // NOTE: we don't return Result here. If any of the unwrapping fails, // it's a programming error in the trie implementation, not a user @@ -345,6 +349,12 @@ impl RadixTrie { let blocks_needed = n_blocks - evicted.len(); let node = self.nodes.get(node_id).expect("Leave does not exist"); + assert_eq!( + node.ref_count, 0, + "Leaf must have refcount of 0, got {}", + node.ref_count + ); + if blocks_needed >= node.blocks.len() { // We need to evict the whole node if we need more blocks than it has. let node = self.remove_node(node_id); @@ -500,12 +510,16 @@ impl RadixTrie { fn remove_node(&mut self, node_id: NodeId) -> TrieNode { // Unwrap here, passing in an unknown id is a programming error. let node = self.nodes.remove(node_id).expect("Unknown node"); + assert!( + node.children.is_empty(), + "Tried to remove a node with {} children", + node.children.len() + ); let parent_id = node.parent.expect("Attempted to remove root node"); let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); parent.children.remove(&node.key[0]); self.decref(parent_id) .expect("Failed to decrease parent refcount"); - self.nodes.remove(node_id); node } @@ -579,6 +593,9 @@ impl TrieNode { fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + // NOTE: this is the case because the child node was chosen based on + // matching the first character of the key/prefix. + assert!(full > 0, "Prefixes must at least share 1 token"); (full / block_size) * block_size } From a3c9c62dc07a044aeea99b6f80b62a77e3ec384f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 6 Sep 2024 12:47:06 +0200 Subject: [PATCH 29/32] hotfix: add syrupy to the right subproject (#2499) --- integration-tests/poetry.lock | 25 +++++++------------------ integration-tests/pyproject.toml | 2 +- integration-tests/requirements.txt | 3 +-- server/poetry.lock | 19 ------------------- server/pyproject.toml | 1 - server/requirements_cuda.txt | 6 ------ server/requirements_intel.txt | 6 ------ server/requirements_rocm.txt | 6 ------ 8 files changed, 9 insertions(+), 59 deletions(-) diff --git a/integration-tests/poetry.lock b/integration-tests/poetry.lock index 3af99942..8398160e 100644 --- a/integration-tests/poetry.lock +++ b/integration-tests/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohttp" @@ -268,16 +268,6 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -[[package]] -name = "colored" -version = "1.4.4" -description = "Simple library for color and formatting to terminal" -optional = false -python-versions = "*" -files = [ - {file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"}, -] - [[package]] name = "docker" version = "6.1.3" @@ -855,18 +845,17 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "syrupy" -version = "4.0.1" +version = "4.7.1" description = "Pytest Snapshot Test Utility" optional = false -python-versions = ">=3.8.1,<4" +python-versions = ">=3.8.1" files = [ - {file = "syrupy-4.0.1-py3-none-any.whl", hash = "sha256:53d3107cc5e18a5def189c721879cea2cdafdee34b879f602133ca08837d0e4b"}, - {file = "syrupy-4.0.1.tar.gz", hash = "sha256:60e3e94782444e0f978cd3b207de32f6da3199b15a2db32eab02f83cebb63ae8"}, + {file = "syrupy-4.7.1-py3-none-any.whl", hash = "sha256:be002267a512a4bedddfae2e026c93df1ea928ae10baadc09640516923376d41"}, + {file = "syrupy-4.7.1.tar.gz", hash = "sha256:f9d4485f3f27d0e5df6ed299cac6fa32eb40a441915d988e82be5a4bdda335c8"}, ] [package.dependencies] -colored = ">=1.3.92,<2.0.0" -pytest = ">=7.0.0,<8.0.0" +pytest = ">=7.0.0,<9.0.0" [[package]] name = "text-generation" @@ -1049,4 +1038,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "421fbce065cb1499c666599cf0fd83a5ce8fb3bed09e83c16c3a3d6953b34026" +content-hash = "f5c65e704b02250d73055cd04efcc22f8fc36144eddfc447a71c3a061748db80" diff --git a/integration-tests/pyproject.toml b/integration-tests/pyproject.toml index 88e9761a..123c1167 100644 --- a/integration-tests/pyproject.toml +++ b/integration-tests/pyproject.toml @@ -7,7 +7,7 @@ authors = ["Nicolas Patry "] [tool.poetry.dependencies] pydantic = "> 2, < 3" python = ">=3.9,<3.13" -syrupy = "4.0.1" +syrupy = "^4.7.1" text-generation = "^0.6.0" pytest = "^7.4.0" pytest-asyncio = "^0.21.1" diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt index 3c2ce11b..f3f0569b 100644 --- a/integration-tests/requirements.txt +++ b/integration-tests/requirements.txt @@ -6,7 +6,6 @@ attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13" certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") -colored==1.4.4 ; python_version >= "3.9" and python_version < "3.13" docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13" exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11" filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" @@ -25,7 +24,7 @@ pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13" pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" -syrupy==4.0.1 ; python_version >= "3.9" and python_version < "3.13" +syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13" tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/poetry.lock b/server/poetry.lock index 49276807..ce5b8a6c 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -2945,20 +2945,6 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] -[[package]] -name = "syrupy" -version = "4.7.1" -description = "Pytest Snapshot Test Utility" -optional = false -python-versions = ">=3.8.1" -files = [ - {file = "syrupy-4.7.1-py3-none-any.whl", hash = "sha256:be002267a512a4bedddfae2e026c93df1ea928ae10baadc09640516923376d41"}, - {file = "syrupy-4.7.1.tar.gz", hash = "sha256:f9d4485f3f27d0e5df6ed299cac6fa32eb40a441915d988e82be5a4bdda335c8"}, -] - -[package.dependencies] -pytest = ">=7.0.0,<9.0.0" - [[package]] name = "texttable" version = "1.7.0" @@ -3251,11 +3237,6 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, - {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, - {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, - {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, - {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, - {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] diff --git a/server/pyproject.toml b/server/pyproject.toml index 66844402..57deb1b8 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -47,7 +47,6 @@ marlin-kernels = [ { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] rich = "^13.7.1" -syrupy = "^4.7.1" [tool.poetry.extras] torch = ["torch"] diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 71291f7b..eb521bd6 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -4,7 +4,6 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11" filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" @@ -16,7 +15,6 @@ hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" -iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" @@ -33,12 +31,10 @@ opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" -pluggy==1.5.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" -pytest==7.4.4 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" @@ -47,9 +43,7 @@ safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" -syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" -tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index 71291f7b..eb521bd6 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -4,7 +4,6 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11" filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" @@ -16,7 +15,6 @@ hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" -iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" @@ -33,12 +31,10 @@ opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" -pluggy==1.5.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" -pytest==7.4.4 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" @@ -47,9 +43,7 @@ safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" -syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" -tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 71291f7b..eb521bd6 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -4,7 +4,6 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11" filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" @@ -16,7 +15,6 @@ hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" -iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" @@ -33,12 +31,10 @@ opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" -pluggy==1.5.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" -pytest==7.4.4 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" @@ -47,9 +43,7 @@ safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" -syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" -tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" From aaea212d0f53929cd3775af3eaf06f4af0a868a5 Mon Sep 17 00:00:00 2001 From: Martin Iglesias Goyanes Date: Fri, 6 Sep 2024 17:00:54 +0200 Subject: [PATCH 30/32] Add links to Adyen blogpost (#2500) * Add links to Adyen blogpost * Adding to toctree. * Update external.md * Update _toctree.yml --------- Co-authored-by: Nicolas Patry --- README.md | 2 +- docs/source/_toctree.yml | 2 ++ docs/source/conceptual/external.md | 4 ++++ docs/source/conceptual/streaming.md | 4 ---- 4 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 docs/source/conceptual/external.md diff --git a/README.md b/README.md index cf6a30db..cc9d523f 100644 --- a/README.md +++ b/README.md @@ -189,7 +189,7 @@ overridden with the `--otlp-service-name` argument ![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) -Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi) +Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi) ### Local install diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f52fa2ec..b883b36d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -71,6 +71,8 @@ title: How Guidance Works (via outlines) - local: conceptual/lora title: LoRA (Low-Rank Adaptation) + - local: conceptual/external + title: External Resources title: Conceptual Guides diff --git a/docs/source/conceptual/external.md b/docs/source/conceptual/external.md new file mode 100644 index 00000000..9cbe1b5a --- /dev/null +++ b/docs/source/conceptual/external.md @@ -0,0 +1,4 @@ +# External Resources + +- Adyen wrote a detailed article about the interplay between TGI's main components: router and server. +[LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi) diff --git a/docs/source/conceptual/streaming.md b/docs/source/conceptual/streaming.md index f1f37f2a..b8154ba4 100644 --- a/docs/source/conceptual/streaming.md +++ b/docs/source/conceptual/streaming.md @@ -155,7 +155,3 @@ SSEs are different than: * Webhooks: where there is a bi-directional connection. The server can send information to the client, but the client can also send data to the server after the first request. Webhooks are more complex to operate as they don’t only use HTTP. If there are too many requests at the same time, TGI returns an HTTP Error with an `overloaded` error type (`huggingface_hub` returns `OverloadedError`). This allows the client to manage the overloaded server (e.g., it could display a busy error to the user or retry with a new request). To configure the maximum number of concurrent requests, you can specify `--max_concurrent_requests`, allowing clients to handle backpressure. - -## External sources - -Adyen wrote a nice recap of how TGI streaming feature works. [LLM inference at scale with TGI](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi) From c1fe28d694757a6a90426a83006292dc76512f66 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 6 Sep 2024 17:35:49 +0200 Subject: [PATCH 31/32] Fixing more correctly the invalid drop of the batch. (#2498) --- backends/v3/src/backend.rs | 5 +- backends/v3/src/queue.rs | 101 ++++++++++++++++++++----------------- backends/v3/src/radix.rs | 2 + 3 files changed, 58 insertions(+), 50 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index a47e62dc..935f7980 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -122,7 +122,7 @@ impl Backend for BackendV3 { #[allow(clippy::too_many_arguments)] pub(crate) async fn batching_task( mut client: ShardedClient, - _waiting_served_ratio: f32, + waiting_served_ratio: f32, max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, @@ -170,8 +170,7 @@ pub(crate) async fn batching_task( // Minimum batch size // TODO: temporarily disable to avoid incorrect deallocation + // reallocation when using prefix caching. - // Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - None + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 2a8c4c53..978a495c 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -252,17 +252,14 @@ impl State { let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); next_batch_span.follows_from(Span::current()); - let mut batch_requests = Vec::with_capacity(self.entries.len()); - let mut batch_entries = - IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - + let mut batch = Vec::with_capacity(self.entries.len()); let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; let mut max_blocks = 0; // Pop entries starting from the front of the queue - 'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() { + 'entry_loop: while let Some((id, entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { @@ -276,7 +273,7 @@ impl State { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length; + prefill_tokens = (batch.len() + 1) as u32 * max_input_length; decode_tokens += entry.request.stopping_parameters.max_new_tokens; let total_tokens = prefill_tokens + decode_tokens + self.speculate; @@ -290,7 +287,7 @@ impl State { } None } - Some(block_allocator) => { + Some(_block_allocator) => { prefill_tokens += entry.request.input_length; let max_new_tokens = match self.window_size { None => entry.request.stopping_parameters.max_new_tokens, @@ -324,23 +321,59 @@ impl State { entry.request.input_ids.clone() }; - match block_allocator.allocate(tokens, input_ids).await { - None => { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: not enough free blocks"); - self.entries.push_front((id, entry)); - break 'entry_loop; - } - Some(block_allocation) => { - tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) - } - } + Some((tokens, input_ids)) } }; + batch.push((id, entry, block_allocation)); + if Some(batch.len()) == max_size { + break; + } + } + // Empty batch + if batch.is_empty() { + tracing::debug!("Filterered out all entries"); + return None; + } + + // XXX We haven't allocated yet, so we're allowed to ditch the results. + // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch.len() < min_size { + // Add back entries to the queue in the correct order + for (id, entry, _) in batch.into_iter().rev() { + self.entries.push_front((id, entry)); + } + return None; + } + } + + let mut batch_requests = Vec::with_capacity(self.entries.len()); + let mut batch_entries = + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + + for (id, mut entry, block_allocation) in batch { + let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = + (block_allocation, &self.block_allocator) + { + match block_allocator.allocate(tokens, input_ids).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break; + } + Some(block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + Some(block_allocation) + } + } + } else { + None + }; tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); @@ -400,32 +433,6 @@ impl State { entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); - - // Check if max_size - if Some(batch_requests.len()) == max_size { - break; - } - } - - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - - // Check if our batch is big enough - if let Some(min_size) = min_size { - // Batch is too small - if batch_requests.len() < min_size { - // Add back entries to the queue in the correct order - for r in batch_requests.into_iter().rev() { - let id = r.id; - let entry = batch_entries.remove(&id).unwrap(); - self.entries.push_front((id, entry)); - } - - return None; - } } // Final batch size diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index bb6582b0..1f3bef15 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -89,6 +89,8 @@ impl Allocator for RadixAllocator { let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { From eabbbbda2340d6cab12040ff54481f3b7d633ead Mon Sep 17 00:00:00 2001 From: Vallepu Vamsi Krishna Date: Sat, 7 Sep 2024 16:49:43 +0530 Subject: [PATCH 32/32] Add Directory Check to Prevent Redundant Cloning in Build Process (#2486) Update Makefile-fbgemm Added Directory check for FBGEMM repository cloning. --- server/Makefile-fbgemm | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm index 5f3c0eaa..3b8061a1 100644 --- a/server/Makefile-fbgemm +++ b/server/Makefile-fbgemm @@ -1,7 +1,9 @@ fbgemm_commit := v0.8.0 build-fbgemm: - git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ + @if [ ! -d "fbgemm" ]; then \ + git clone https://github.com/pytorch/FBGEMM.git fbgemm; \ + fi cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \ git submodule update --init --recursive && \ cd fbgemm_gpu && \