From b70ae0969f11bae03a3c6194fc8c592a1d8a65b3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 20 Aug 2024 11:15:30 +0200 Subject: [PATCH 1/9] Prefix caching (#2402) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Prefix caching WIP * Fixing prefix attention. * Fixing flashinfer import. * Fixing black. * Fixing medusa (still wrong outputs, but functional). * Just medusa values now. * Fixing medusa without prefix caching. * Fixing prefix caching. * Medusa requires reshaping. * Removing the logs. * Remove router.nix * Fixup: - Remove logs - Disable VLMs (they do not work) - Disable prefix caching when user wants prefill logprobs. * Update flake.lock --------- Co-authored-by: Danièˆl de Kok --- backends/v3/src/queue.rs | 13 +- backends/v3/src/radix.rs | 1 + flake.lock | 6 +- flake.nix | 1 + .../layers/attention/__init__.py | 7 +- .../layers/attention/cuda.py | 27 +- .../{flash_infer.py => flashinfer.py} | 76 ++++++ .../text_generation_server/layers/medusa.py | 2 + .../custom_modeling/flash_cohere_modeling.py | 2 + .../custom_modeling/flash_dbrx_modeling.py | 2 + .../flash_deepseek_v2_modeling.py | 2 + .../custom_modeling/flash_gemma2_modeling.py | 2 + .../custom_modeling/flash_gemma_modeling.py | 2 + .../custom_modeling/flash_gpt2_modeling.py | 2 + .../custom_modeling/flash_llama_modeling.py | 2 + .../custom_modeling/flash_mistral_modeling.py | 2 + .../custom_modeling/flash_mixtral_modeling.py | 2 + .../custom_modeling/flash_neox_modeling.py | 2 + .../custom_modeling/flash_phi_modeling.py | 2 + .../custom_modeling/flash_qwen2_modeling.py | 2 + .../custom_modeling/flash_rw_modeling.py | 4 + .../flash_santacoder_modeling.py | 2 + .../flash_starcoder2_modeling.py | 2 + .../models/flash_causal_lm.py | 252 ++++++++++++++---- .../text_generation_server/models/globals.py | 6 +- .../models/vlm_causal_lm.py | 89 +++++-- 26 files changed, 405 insertions(+), 107 deletions(-) rename server/text_generation_server/layers/attention/{flash_infer.py => flashinfer.py} (65%) diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 0fb05a98..faa57c11 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -316,10 +316,15 @@ impl State { + self.speculate - 1; - match block_allocator - .allocate(tokens, entry.request.input_ids.clone()) - .await - { + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. + let input_ids = if entry.request.decoder_input_details { + None + } else { + entry.request.input_ids.clone() + }; + + match block_allocator.allocate(tokens, input_ids).await { None => { // Entry is over budget // Add it back to the front diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index ef963532..5bac1a31 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -205,6 +205,7 @@ pub struct RadixTrie { /// call that a real time lookup would require. time: u64, } + impl Default for RadixTrie { fn default() -> Self { Self::new() diff --git a/flake.lock b/flake.lock index d0c2adbc..cd5d6d2a 100644 --- a/flake.lock +++ b/flake.lock @@ -900,11 +900,11 @@ ] }, "locked": { - "lastModified": 1723515680, - "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=", + "lastModified": 1723602049, + "narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3", + "rev": "ea0bf33a11a26a62c60123c49d96011da396602c", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 6bfe74e8..299e6b3d 100644 --- a/flake.nix +++ b/flake.nix @@ -84,6 +84,7 @@ grpcio-status grpcio-tools hf-transfer + ipdb loguru mamba-ssm marlin-kernels diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index f9b1715e..56fc5319 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -6,7 +6,12 @@ from .common import Seqlen if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": - from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .cuda import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "ipex": diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 8703eb94..b3b7ea4f 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -76,7 +76,7 @@ def paged_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import decode_state + from text_generation_server.layers.attention.flashinfer import decode_state return decode_state.get().forward( query.contiguous(), @@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2 if ATTENTION == "flashinfer": def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -231,14 +233,15 @@ if ATTENTION == "flashinfer": causal=True, softcap=0.0, ): - from text_generation_server.layers.attention.flash_infer import prefill_state + assert window_size_left == -1, "Windowing is not supported with flash infer" + from text_generation_server.layers.attention.flashinfer import ( + prefill_with_paged_kv_state, + ) - return prefill_state.get().forward( - q, - k, - v, + return prefill_with_paged_kv_state.get().forward( + q.contiguous(), causal=causal, - window_left=window_size_left, + paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, ) @@ -249,6 +252,8 @@ elif V2: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -289,6 +294,8 @@ else: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flashinfer.py similarity index 65% rename from server/text_generation_server/layers/attention/flash_infer.py rename to server/text_generation_server/layers/attention/flashinfer.py index 56b53b2c..e1ef62c5 100644 --- a/server/text_generation_server/layers/attention/flash_infer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con "prefill_state" ) +prefill_with_paged_kv_state: ContextVar[ + flashinfer.BatchPrefillWithPagedKVCacheWrapper +] = ContextVar("prefill_with_paged_kv_state") + decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( "decode_state" ) @@ -24,6 +28,78 @@ def get_workspace(device): return workspace +def create_prefill_with_paged_kv_state( + *, + device: torch.device, +): + """Create a prefill state that uses the KV cache.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_with_paged_kv_state( + *, + state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, + block_tables: torch.Tensor, + cu_seqlens: torch.Tensor, + input_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + if page_size == 1: + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + else: + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = prefill_with_paged_kv_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + paged_kv_indptr=indptr, + paged_kv_indices=block_tables, + paged_kv_last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + page_size=page_size, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_with_paged_kv_state.reset(token) + + def create_prefill_state( *, device: torch.device, diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py index 7579ccdb..139c4dc2 100644 --- a/server/text_generation_server/layers/medusa.py +++ b/server/text_generation_server/layers/medusa.py @@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module): ) def forward(self, x): + if not self.heads: + return None speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) return speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e02a31d9..1eb8c6c3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d3d1d1ef..fc0dca5b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0905d3c2..b25becd5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 54d212e6..faf0f325 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 178efadb..33738a59 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a19cff8c..d30b5a0a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 9ea19a87..3253d2dc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dda53ff3..5a150267 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 85431c6c..ad426ffe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 67237d5c..b684e035 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0], qkv[:, 1], qkv[:, 2], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9e18348..efe27c13 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 865cc85d..879b8abd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 10f995a3..c72a9b90 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, @@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module): query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c2676782..109304be 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module): query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index e562eb89..200d4ef0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5e2fd20a..dd4203e0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,6 +43,7 @@ from text_generation_server.models.globals import ( ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, + PREFIX_CACHING, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor max_seqlen: int @@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch): prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] + # Prefixes + prefix_ids: List[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -213,6 +220,7 @@ class FlashCausalLMBatch(Batch): prefix_offsets = [] read_offsets = [] all_input_ids = [] + prefix_ids = [] requests_idx_mapping = {} all_prefill_logprobs = True @@ -230,7 +238,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 + cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -240,6 +248,7 @@ class FlashCausalLMBatch(Batch): block_tables = [] slots = [] + prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -255,6 +264,19 @@ class FlashCausalLMBatch(Batch): ): tokenized_input = tokenized_input[1:] + orig_input_length = len(tokenized_input) + + if PREFIX_CACHING: + prefix_len = r.prefix_len + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 + else: + prefix_len = 0 + + prefix_ids.append(tokenized_input[:prefix_len]) + tokenized_input = tokenized_input[prefix_len:] + input_length = len(tokenized_input) input_lengths.append(input_length) @@ -264,7 +286,9 @@ class FlashCausalLMBatch(Batch): all_input_ids.append(tokenized_input) # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + request_position_ids = torch.arange( + prefix_len, orig_input_length, dtype=torch.int32 + ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs @@ -288,11 +312,17 @@ class FlashCausalLMBatch(Batch): # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length - total_tokens = input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to blocks. + block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to slots. We don't need slots for the + # cached prefix (if present). + slot_tokens = input_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] @@ -303,16 +333,20 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots + request_slots = r.slots[ + prefix_len: #: orig_input_length + max_new_tokens + speculative_length + ] block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) + + slots.extend(request_slots) + prefix_lens.append(prefix_len) num_blocks += len(request_blocks) - start_slots.append(cumulative_max_length) + start_slots.append(cumulative_slot_tokens) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) @@ -348,7 +382,7 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length - cumulative_max_length += total_tokens + cumulative_slot_tokens += slot_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( @@ -425,12 +459,14 @@ class FlashCausalLMBatch(Batch): ) slots = torch.tensor(slots, dtype=torch.int64, device=device) + block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) + prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) return cls( batch_id=pb.id, @@ -445,6 +481,8 @@ class FlashCausalLMBatch(Batch): block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -455,6 +493,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -510,8 +549,10 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] + prefix_lens = [] prefix_offsets = [] read_offsets = [] @@ -533,11 +574,14 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] + prefix_len = self.prefix_lens[idx] max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) + prefix_ids.append(self.prefix_ids[idx]) input_lengths.append(request_input_length) + prefix_lens.append(prefix_len) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -582,6 +626,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] + prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( @@ -617,10 +662,13 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens=None, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -681,6 +729,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) + prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) @@ -698,7 +747,9 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] + prefix_lens = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] prefix_offsets = [] @@ -760,10 +811,14 @@ class FlashCausalLMBatch(Batch): start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) + prefix_lens.extend(batch.prefix_lens) all_input_ids.extend(batch.all_input_ids) + prefix_ids.extend(batch.prefix_ids) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) @@ -809,6 +864,8 @@ class FlashCausalLMBatch(Batch): slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, @@ -820,6 +877,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -970,19 +1028,22 @@ class FlashCausalLM(Model): self.kv_cache = [] if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_prefill_state, create_decode_state, + create_prefill_with_paged_kv_state, ) self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) - if not CUDA_GRAPHS: - self.decode_state = create_decode_state( - device=device, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) super().__init__( model_id=model_id, @@ -1074,12 +1135,23 @@ class FlashCausalLM(Model): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - block_tables = ( - torch.arange(max_bt, dtype=torch.int32, device=self.device) - .repeat(bs) - .reshape((bs, max_bt)) + input_lengths = [max_s] * bs + prefix_lengths = [0] * bs + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) + prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lengths, + ) self.cuda_graphs[bs] = { "input_ids": input_ids, @@ -1087,14 +1159,14 @@ class FlashCausalLM(Model): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "input_lengths": input_lengths, + "input_lengths": input_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths) + input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, ) @@ -1104,7 +1176,7 @@ class FlashCausalLM(Model): last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) state = create_decode_state_cuda_graphs( device=input_ids.device, - block_tables=block_tables.view(-1), + block_tables=block_tables, block_tables_ptr=block_tables_ptr, last_page_len=last_page_len, num_heads=self.num_heads, @@ -1120,7 +1192,10 @@ class FlashCausalLM(Model): block_tables=block_tables, cu_seqlen_prefill=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, state=state, + prefix_lens=prefix_lengths, + prefix_lens_tensor=prefix_lengths_tensor, ): self.model.forward( input_ids=input_ids, @@ -1138,7 +1213,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) + input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1146,7 +1221,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_tensor, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1334,6 +1409,9 @@ class FlashCausalLM(Model): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -1354,6 +1432,7 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1372,10 +1451,20 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths=input_lengths, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, ): input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( @@ -1399,20 +1488,32 @@ class FlashCausalLM(Model): # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) - state = cuda_graph.get("state") with self._forward_context( - block_tables=block_tables, + block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - input_lengths=input_lengths, - state=state, + input_lengths=batch.input_lengths, + input_lengths_tensor=cuda_graph["input_lengths"], + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + state=cuda_graph.get("state"), ): # Replay the graph cuda_graph["graph"].replay() @@ -1610,6 +1711,7 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, + batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, @@ -1627,6 +1729,7 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, + prefix_ids, do_sample, seed, top_n_tokens, @@ -1701,18 +1804,18 @@ class FlashCausalLM(Model): out_end_index = batch.prefill_cu_outlens[i + 1] # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - out_start_index : out_end_index - 1 - ] + request_prefill_logprobs = ( + [float("nan")] * (len(prefix_ids) + 1) + ) + prefill_logprobs[out_start_index : out_end_index - 1] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, + prefix_ids + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) prefill_tokens = Tokens( - prefill_token_ids, + prefix_ids + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], @@ -1794,33 +1897,68 @@ class FlashCausalLM(Model): *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths: torch.Tensor, + input_lengths: List[int], + input_lengths_tensor: torch.Tensor, + prefix_lens: List[int], + prefix_lens_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": return nullcontext() - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( use_decode_state, - use_prefill_state, + use_prefill_with_paged_kv_state, ) + # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) + if cu_seqlen_prefill is not None: - return use_prefill_state( - state=state if state is not None else self.prefill_state, + return use_prefill_with_paged_kv_state( + state=( + state if state is not None else self.prefill_with_paged_kv_state + ), + # block_tables=block_tables_to_ragged( + # block_tables=block_tables, + # input_lengths=input_lengths, + # prefix_lens=prefix_lens, + # ), + block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - ) - else: - assert input_lengths is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths, - block_tables=block_tables.view(-1), + input_lengths=input_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, ) + else: + assert input_lengths_tensor is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths_tensor, + block_tables=block_tables, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) + + +def block_tables_to_ragged( + *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(prefix_lens) + + total_len = sum(input_lengths) + sum(prefix_lens) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + offset = 0 + for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): + seq_len = prefix_len + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index abc35421..d5133f5e 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,9 +5,8 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) -log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") - +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} +log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( @@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer": else: BLOCK_SIZE = 16 - cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7de54aa4..2ed1a119 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, + block_tables_to_ragged, ) +from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen @@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM): trust_remote_code: bool, **kwargs, ): + if PREFIX_CACHING: + raise NotImplementedError("Vlm do not work with prefix caching yet") if processor_kwargs is None: processor_kwargs = {} self.processor = processor_class.from_pretrained( @@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -349,43 +357,68 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - return logits, speculative_logits + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) # Replay the graph cuda_graph["graph"].replay() From f5f11b797e70b2232632d410273c5c4418475dd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 20 Aug 2024 22:07:33 +0200 Subject: [PATCH 2/9] nix: add pure server to flake, add both pure and impure devshells (#2430) * nix: pure server and support both pure and impure devShells * nix: remove unused poetry2nix input It is not wired up and we now have a pure server. * nix: add ipdb to impure devshell --- flake.lock | 115 ------------------------------------------------- flake.nix | 64 +++++++++++---------------- nix/server.nix | 105 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 155 deletions(-) create mode 100644 nix/server.nix diff --git a/flake.lock b/flake.lock index cd5d6d2a..c20dd98a 100644 --- a/flake.lock +++ b/flake.lock @@ -492,24 +492,6 @@ "type": "github" } }, - "flake-utils_7": { - "inputs": { - "systems": "systems_7" - }, - "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, "gitignore": { "inputs": { "nixpkgs": [ @@ -594,27 +576,6 @@ "type": "github" } }, - "nix-github-actions": { - "inputs": { - "nixpkgs": [ - "poetry2nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1703863825, - "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=", - "owner": "nix-community", - "repo": "nix-github-actions", - "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "nix-github-actions", - "type": "github" - } - }, "nix-test-runner": { "flake": false, "locked": { @@ -753,31 +714,6 @@ "type": "github" } }, - "poetry2nix": { - "inputs": { - "flake-utils": "flake-utils_7", - "nix-github-actions": "nix-github-actions", - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ], - "systems": "systems_8", - "treefmt-nix": "treefmt-nix" - }, - "locked": { - "lastModified": 1723854676, - "narHash": "sha256-+BrHfNuXrqeE7PoV6xDaoh0joYiJkvTTCIV0fFR3THw=", - "owner": "nix-community", - "repo": "poetry2nix", - "rev": "d650118bce34c0238b9b54f23f7f173f9e4db867", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "poetry2nix", - "type": "github" - } - }, "pre-commit-hooks": { "inputs": { "flake-compat": [ @@ -887,7 +823,6 @@ "tgi-nix", "nixpkgs" ], - "poetry2nix": "poetry2nix", "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } @@ -1003,35 +938,6 @@ "type": "github" } }, - "systems_7": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - }, - "systems_8": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "id": "systems", - "type": "indirect" - } - }, "tgi-nix": { "inputs": { "flake-compat": "flake-compat_4", @@ -1050,27 +956,6 @@ "repo": "tgi-nix", "type": "github" } - }, - "treefmt-nix": { - "inputs": { - "nixpkgs": [ - "poetry2nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1719749022, - "narHash": "sha256-ddPKHcqaKCIFSFc/cvxS14goUhCOAwsM1PbMr0ZtHMg=", - "owner": "numtide", - "repo": "treefmt-nix", - "rev": "8df5ff62195d4e67e2264df0b7f5e8c9995fd0bd", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "treefmt-nix", - "type": "github" - } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 299e6b3d..adc70fd1 100644 --- a/flake.nix +++ b/flake.nix @@ -8,10 +8,6 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; - poetry2nix = { - url = "github:nix-community/poetry2nix"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; - }; rust-overlay = { url = "github:oxalica/rust-overlay"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; @@ -26,7 +22,6 @@ flake-utils, rust-overlay, tgi-nix, - poetry2nix, }: flake-utils.lib.eachDefaultSystem ( system: @@ -47,14 +42,28 @@ tgi-nix.overlay ]; }; - inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; - text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; + launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override { + inherit crateOverrides; + }; + router = cargoNix.workspaceMembers.text-generation-router-v3.build.override { + inherit crateOverrides; + }; + server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; }; in { - devShells.default = - with pkgs; - mkShell { + devShells = with pkgs; rec { + default = pure; + + pure = mkShell { + buildInputs = [ + launcher + router + server + ]; + }; + + impure = mkShell { buildInputs = [ openssl.dev @@ -65,43 +74,16 @@ "rust-src" ]; }) + protobuf ] ++ (with python3.pkgs; [ venvShellHook pip - - causal-conv1d - click - einops - exllamav2 - fbgemm-gpu - flashinfer - flash-attn - flash-attn-layer-norm - flash-attn-rotary - grpc-interceptor - grpcio-reflection - grpcio-status - grpcio-tools - hf-transfer ipdb - loguru - mamba-ssm - marlin-kernels - opentelemetry-api - opentelemetry-exporter-otlp - opentelemetry-instrumentation-grpc - opentelemetry-semantic-conventions - peft - tokenizers - torch - transformers - vllm - - (cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; }) - (cargoNix.workspaceMembers.text-generation-router-v3.build.override { inherit crateOverrides; }) ]); + inputsFrom = [ server ]; + venvDir = "./.venv"; postVenv = '' @@ -109,8 +91,10 @@ ''; postShellHook = '' unset SOURCE_DATE_EPOCH + export PATH=$PATH:~/.cargo/bin ''; }; + }; } ); } diff --git a/nix/server.nix b/nix/server.nix new file mode 100644 index 00000000..ff40757a --- /dev/null +++ b/nix/server.nix @@ -0,0 +1,105 @@ +{ + nix-filter, + buildPythonPackage, + poetry-core, + mypy-protobuf, + causal-conv1d, + einops, + exllamav2, + fbgemm-gpu, + flashinfer, + flash-attn, + flash-attn-layer-norm, + flash-attn-rotary, + grpc-interceptor, + grpcio-reflection, + grpcio-status, + grpcio-tools, + hf-transfer, + loguru, + mamba-ssm, + marlin-kernels, + opentelemetry-api, + opentelemetry-exporter-otlp, + opentelemetry-instrumentation-grpc, + opentelemetry-semantic-conventions, + peft, + safetensors, + tokenizers, + sentencepiece, + transformers, + typer, + vllm, +}: + +let + filter = nix-filter.lib; +in +buildPythonPackage { + name = "text-generation-server"; + + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi"))) + "server/pyproject.toml" + (and (inDirectory "proto/v3") (matchExt "proto")) + ]; + }; + + pyproject = true; + + build-system = [ poetry-core ]; + + nativeBuildInputs = [ mypy-protobuf ]; + + pythonRelaxDeps = [ + "einops" + "huggingface-hub" + "loguru" + "opentelemetry-instrumentation-grpc" + "sentencepiece" + "typer" + ]; + + pythonRemoveDeps = [ "scipy" ]; + + dependencies = [ + causal-conv1d + einops + exllamav2 + fbgemm-gpu + flashinfer + flash-attn + flash-attn-layer-norm + flash-attn-rotary + grpc-interceptor + grpcio-reflection + grpcio-status + grpcio-tools + hf-transfer + loguru + mamba-ssm + marlin-kernels + opentelemetry-api + opentelemetry-exporter-otlp + opentelemetry-instrumentation-grpc + opentelemetry-semantic-conventions + peft + safetensors + sentencepiece + tokenizers + transformers + typer + vllm + ]; + + prePatch = '' + python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \ + --grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto + find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; + touch server/text_generation_server/pb/__init__.py + cd server + ''; +} From 947441509580c56e9bd8160bf2cf23c653d19abd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 21 Aug 2024 07:48:13 +0200 Subject: [PATCH 3/9] nix: add `text-generation-benchmark` to pure devshell (#2431) nix: add text-generation-benchmark to pure devshell --- flake.nix | 4 +++ nix/crate-overrides.nix | 68 ++++++++++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/flake.nix b/flake.nix index adc70fd1..1c9c77a9 100644 --- a/flake.nix +++ b/flake.nix @@ -43,6 +43,9 @@ ]; }; crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; + benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override { + inherit crateOverrides; + }; launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; }; @@ -57,6 +60,7 @@ pure = mkShell { buildInputs = [ + benchmark launcher router server diff --git a/nix/crate-overrides.nix b/nix/crate-overrides.nix index 343b3b25..cbf366af 100644 --- a/nix/crate-overrides.nix +++ b/nix/crate-overrides.nix @@ -20,34 +20,52 @@ defaultCrateOverrides rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; grpc-metadata = attrs: { - src = - filter { - root = ../backends/grpc-metadata; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + src = filter { + root = ../backends/grpc-metadata; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; - text-generation-launcer = attrs: { - src = - filter { - root = ../launcher; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + text-generation-benchmark = attrs: { + src = filter { + root = ../benchmark; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-client = attrs: { + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "backends/client") (matchExt "rs")) + (and (inDirectory "proto") (matchExt "proto")) + ]; + }; + postPatch = "cd backends/client"; + buildInputs = [ protobuf ]; + }; + text-generation-launcher = attrs: { + src = filter { + root = ../launcher; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; text-generation-router = attrs: { - src = - filter { - root = ../router; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + src = filter { + root = ../router; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; text-generation-router-v3 = attrs: { # We need to do the src/source root dance so that the build From 310778e02a5aa36a4c72601e4e929bbbea0f1e7b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Aug 2024 09:06:33 +0200 Subject: [PATCH 4/9] Adding eetq to flake. (#2438) --- flake.lock | 12 ++++++------ nix/server.nix | 2 ++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flake.lock b/flake.lock index c20dd98a..69f9ef13 100644 --- a/flake.lock +++ b/flake.lock @@ -835,11 +835,11 @@ ] }, "locked": { - "lastModified": 1723602049, - "narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=", + "lastModified": 1724206841, + "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "ea0bf33a11a26a62c60123c49d96011da396602c", + "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61", "type": "github" }, "original": { @@ -944,11 +944,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1723973328, - "narHash": "sha256-q5FmW4YFQcRb6fXHnrxL0uno6xcw9dcg+pFBbVM1xeQ=", + "lastModified": 1724218652, + "narHash": "sha256-Y7Kt+AZRIdo7tr/VhKGzdwYf7stiYQ4JD7flusEpXQw=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "d2038f36589a8a179834e5771ffd081620ba94c3", + "rev": "ab2761aa7b970e737492b8cc41ca580dcb094808", "type": "github" }, "original": { diff --git a/nix/server.nix b/nix/server.nix index ff40757a..1f90e3fd 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -4,6 +4,7 @@ poetry-core, mypy-protobuf, causal-conv1d, + eetq, einops, exllamav2, fbgemm-gpu, @@ -66,6 +67,7 @@ buildPythonPackage { pythonRemoveDeps = [ "scipy" ]; dependencies = [ + eetq causal-conv1d einops exllamav2 From 358ceb67dd343f4022537a1f28bc3fc9baec9102 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 21 Aug 2024 22:20:03 +0200 Subject: [PATCH 5/9] nix: add awq-inference-engine as server dependency (#2442) --- flake.lock | 6 +++--- nix/server.nix | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 69f9ef13..b40f51b3 100644 --- a/flake.lock +++ b/flake.lock @@ -944,11 +944,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1724218652, - "narHash": "sha256-Y7Kt+AZRIdo7tr/VhKGzdwYf7stiYQ4JD7flusEpXQw=", + "lastModified": 1724270760, + "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "ab2761aa7b970e737492b8cc41ca580dcb094808", + "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c", "type": "github" }, "original": { diff --git a/nix/server.nix b/nix/server.nix index 1f90e3fd..4e0fdaa1 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -3,6 +3,7 @@ buildPythonPackage, poetry-core, mypy-protobuf, + awq-inference-engine, causal-conv1d, eetq, einops, @@ -67,6 +68,7 @@ buildPythonPackage { pythonRemoveDeps = [ "scipy" ]; dependencies = [ + awq-inference-engine eetq causal-conv1d einops From f3c5d7d92f005c3cd6a801a33334fb9ba32f55f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 23 Aug 2024 22:06:22 +0200 Subject: [PATCH 6/9] nix: add default package (#2453) The default package wraps the launcher and puts the server/router in the path. As a result, TGI can be started using something like: ``` nix run .# -- \ --model-id hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 \ --port 8080 ``` --- flake.nix | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/flake.nix b/flake.nix index 1c9c77a9..83feb26a 100644 --- a/flake.nix +++ b/flake.nix @@ -99,6 +99,17 @@ ''; }; }; + + packages.default = pkgs.writeShellApplication { + name = "text-generation-inference"; + runtimeInputs = [ + server + router + ]; + text = '' + ${launcher}/bin/text-generation-launcher "$@" + ''; + }; } ); } From 30be188400d27b6fedd88cb3dfd88de45639703c Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 26 Aug 2024 17:04:46 -0400 Subject: [PATCH 7/9] Fix: don't apply post layernorm in SiglipVisionTransformer (#2459) * Fix: don't apply post layernorm in SiglipVisionTransformer This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before layernorm (see original transformers implementation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813). This also makes Siglip consistent with the existing Clip implementation: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613 * fix: adjust pali gemma for post layer norm and small refactors --------- Co-authored-by: Travis Addair --- .../custom_modeling/flash_pali_gemma_modeling.py | 10 +++++++++- .../models/custom_modeling/siglip.py | 13 +------------ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index d10efb41..e08a2aad 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module): config=config.vision_config, weights=weights, ) + self.post_vision_tower_layernorm = nn.LayerNorm.load( + prefix="vision_tower.vision_model.post_layernorm", + weights=weights, + eps=config.vision_config.layer_norm_eps, + ) self.multi_modal_projector = TensorParallelColumnLinear.load( config, @@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module): if pixel_values is not None: pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) image_outputs = self.vision_tower(pixel_values) - image_features = self.multi_modal_projector(image_outputs.last_hidden_state) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) # mask where image or padding tokens mask = input_ids == self.config.image_token_index diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py index 480d0f9f..95ac9ede 100644 --- a/server/text_generation_server/models/custom_modeling/siglip.py +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module): inputs_embeds, attention_mask: Optional[torch.Tensor] = None, ): - hidden_states = inputs_embeds for idx, encoder_layer in enumerate(self.layers): hidden_states, _ = encoder_layer( @@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module): self.encoder = SiglipEncoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) - self.post_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.post_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, ): - r""" - Returns: - - """ if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module): inputs_embeds=hidden_states, ) last_hidden_state = encoder_outputs - post_last_hidden_state = self.post_layernorm(last_hidden_state) return BaseModelOutputWithPooling( - last_hidden_state=post_last_hidden_state, + last_hidden_state=last_hidden_state, # pooler_output=pooled_output, # hidden_states=encoder_outputs, ) From cfa73b5c99bc009903fbc340f8b77a6d4674455d Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 26 Aug 2024 20:19:38 -0400 Subject: [PATCH 8/9] Pr 2451 ci branch (#2454) * fix[router]: Fix tools not passed in chat template Signed-off-by: GitHub * feat: improve default tool serialization and lints * feat: refactor tool logic to include notify_error in prompt and adjust typing * fix: adjust non tool template apply * fix: simplify tool grammar logic and improve schema * feat: avoid skip tool test and avoid empty tool prompts * fix: increase test client timeout for grammar compilation tests --------- Signed-off-by: GitHub Co-authored-by: Simone Rossi --- Cargo.lock | 1 + clients/python/text_generation/client.py | 7 +- docs/openapi.json | 2 +- integration-tests/conftest.py | 2 +- integration-tests/models/test_tools_llama.py | 50 +++--- router/Cargo.toml | 2 +- router/src/infer/chat_template.rs | 57 ++++--- router/src/infer/mod.rs | 6 +- router/src/infer/tool_grammar.rs | 121 +++++++------- router/src/lib.rs | 15 +- router/src/server.rs | 160 ++++++++++++++++--- 11 files changed, 268 insertions(+), 155 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d298c379..aa5cb642 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" dependencies = [ "serde", + "serde_json", ] [[package]] diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 12966747..45301b63 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -757,7 +757,12 @@ class AsyncClient: continue payload = byte_payload.decode("utf-8") if payload.startswith("data:"): - json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + payload_data = ( + payload.lstrip("data:").rstrip("\n").removeprefix(" ") + ) + if payload_data == "[DONE]": + break + json_payload = json.loads(payload_data) try: response = ChatCompletionChunk(**json_payload) yield response diff --git a/docs/openapi.json b/docs/openapi.json index df21e19d..fd64a3ab 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -924,7 +924,7 @@ "tool_prompt": { "type": "string", "description": "A prompt to be appended before the tools", - "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", "nullable": true }, "tools": { diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 15af1cad..a8a77cd2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -257,7 +257,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator): class LauncherHandle: def __init__(self, port: int): - self.client = AsyncClient(f"http://localhost:{port}") + self.client = AsyncClient(f"http://localhost:{port}", timeout=30) def _inner_health(self): raise NotImplementedError diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index f831990a..9855cfda 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -36,6 +36,7 @@ tools = [ }, }, "required": ["location", "format"], + "additionalProperties": False, }, }, }, @@ -62,13 +63,13 @@ tools = [ }, }, "required": ["location", "format", "num_days"], + "additionalProperties": False, }, }, }, ] -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): @@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna max_tokens=100, seed=1, tools=tools, - presence_penalty=-1.1, + temperature=0.0, messages=[ { "role": "system", @@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_auto( @@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="auto", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto( assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_choice( @@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice( assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( @@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 38 + assert count == 48 assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_insufficient_information( @@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information( ): responses = await flash_llama_grammar_tools.chat( max_tokens=100, - seed=8, + seed=24, tools=tools, tool_choice="auto", messages=[ { "role": "system", - "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", }, { "role": "user", @@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information( ) assert responses.choices[0].message.content is None - assert responses.choices[0].message.tool_calls == [ - { - "function": { - "arguments": { - "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." - }, - "description": None, - "name": "notify_error", - }, - "id": 0, - "type": "function", - } - ] - + assert ( + responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error" + ) assert responses == response_snapshot diff --git a/router/Cargo.toml b/router/Cargo.toml index 7773e212..45acab8e 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } -minijinja = { version = "2.0.2" } +minijinja = { version = "2.0.2", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" regex = "1.10.3" diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index a8537818..bfa9421c 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,9 +1,7 @@ use std::collections::HashSet; use crate::infer::InferError; -use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, -}; +use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -32,6 +30,7 @@ impl ChatTemplate { env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); + tracing::debug!("Loading template: {:#?}", template_str); // leaking env and template_str as read-only, static resources for performance. let template = Box::leak(env) @@ -42,6 +41,7 @@ impl ChatTemplate { let variables = template.undeclared_variables(true); // check if the `tools` variable is used in the template let use_default_tool_template = !variables.contains("tools"); + tracing::debug!("Use default tool template: {}", use_default_tool_template); Self { template, @@ -56,25 +56,36 @@ impl ChatTemplate { &self, guideline: Option<&str>, mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - }); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - // check if guideline is expected but not provided if self.variables.contains("guideline") && guideline.is_none() { return Err(InferError::MissingTemplateVariable("guideline".to_string())); } + let tools = match tools_and_prompt { + Some((tools, tool_prompt)) => { + // check if the `tools` variable is used in the template + // if not, we need to append the tools to the last message + let text = if self.use_default_tool_template { + match serde_json::to_string(&tools) { + Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt), + Err(e) => return Err(InferError::ToolError(e.to_string())), + } + } else { + // if the `tools` variable is used in the template, we just append the tool_prompt + format!("\n---\n{}", tool_prompt) + }; + if let Some(last_message) = messages.last_mut() { + last_message.content.push(MessageChunk::Text { text }); + } + Some(tools) + } + None => None, + }; + + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + self.template .render(ChatTemplateInputs { guideline, @@ -82,8 +93,7 @@ impl ChatTemplate { bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), add_generation_prompt: true, - tools: None, - tools_prompt: None, + tools, }) .map_err(InferError::TemplateError) } @@ -95,7 +105,7 @@ mod tests { use crate::infer::chat_template::raise_exception; use crate::infer::ChatTemplate; use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken, + ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, }; use minijinja::Environment; @@ -854,11 +864,12 @@ mod tests { content: MessageContent::SingleText("Just testing".to_string()), }, ]; - let tools = serde_json::json!("[]"); + let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); + let tools: Vec = serde_json::from_str(&tools_string).unwrap(); let tool_prompt = "This default prompt will be used".to_string(); - let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt); - let result = ct.apply(None, msgs, Some(grammer_with_prompt)); - let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string(); + let tools_and_prompt = Some((tools, tool_prompt)); + let result = ct.apply(None, msgs, tools_and_prompt); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index c9354d9a..81c0d38f 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -3,7 +3,7 @@ mod chat_template; pub mod tool_grammar; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; -use crate::GrammarType; +use crate::Tool; use crate::{ ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, Message, PrefillToken, Token, @@ -140,12 +140,12 @@ impl Infer { &self, guideline: Option, messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(guideline.as_deref(), messages, grammar_with_prompt) + .apply(guideline.as_deref(), messages, tools_and_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 05027f30..4fe15720 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,5 +1,8 @@ use crate::infer::InferError; -use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; +use crate::{ + FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, + ToolType, +}; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -16,17 +19,38 @@ impl ToolGrammar { } pub fn apply( - tools: Option>, + tools: Vec, tool_choice: ToolChoice, - ) -> Result, InferError> { + ) -> Result<(Vec, Option), InferError> { // if no tools are provided, we return None - let tools = match tools { - Some(tools) if !tools.is_empty() => tools, - _ => return Ok(None), - }; + if tools.is_empty() { + return Ok((tools, None)); + } let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + let mut tools = tools.clone(); + + // add the notify_error function to the tools + let notify_error = Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "notify_error".to_string(), + description: Some("Notify an error or issue".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "error": { + "type": "string", + "description": "The error or issue to notify" + } + }, + "required": ["error"] + }), + }, + }; + tools.push(notify_error); + // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { ToolType::FunctionName(name) => { @@ -35,87 +59,57 @@ impl ToolGrammar { ToolType::Function { function } => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools, - ToolType::NoTool => return Ok(None), + ToolType::OneOf => tools.clone(), + ToolType::NoTool => return Ok((tools, None)), }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - let functions: HashMap = tools_to_use .iter() .map(|tool| { let func = tool.function.clone(); - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; + let mut params = Map::new(); - // Insert the function's description at the top level, outside of properties params.insert( "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), + Value::String(func.description.unwrap_or_default()), ); - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); + let mut properties = Map::new(); + let mut required = vec![Value::String("_name".to_string())]; - // Insert the constant for the function name inside 'properties' properties.insert( "_name".to_string(), json!({ "type": "string", "const": func.name.clone(), - // "description": "The name of the function" }), ); - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); + if let Value::Object(args) = func.arguments { + if let Some(Value::Object(props)) = args.get("properties") { + properties.extend(props.clone()); + } + if let Some(Value::Array(reqs)) = args.get("required") { + required.extend(reqs.clone()); + } + params.insert( + "additionalProperties".to_string(), + Value::Bool( + args.get("additionalProperties").and_then(|v| v.as_str()) + == Some("true"), + ), + ); } + params.insert("properties".to_string(), Value::Object(properties)); + params.insert("required".to_string(), Value::Array(required)); + (func.name, Value::Object(params)) }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) .collect(); - let tools = Tools { + let tool_schema = JsonSchemaTool { functions_map: FunctionsMap { functions }, properties: Properties { function: tools_to_use @@ -123,13 +117,10 @@ impl ToolGrammar { .map(|tool| FunctionRef { ref_path: format!("#/$functions/{}", tool.function.name.clone()), }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) .collect(), }, }; - Ok(Some(tools)) + Ok((tools, Some(tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 1b2ff153..ce4f7c46 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -840,10 +840,10 @@ pub(crate) struct ChatRequest { pub tools: Option>, /// A prompt to be appended before the tools - #[serde(default = "default_tool_prompt")] + #[serde(default)] #[schema( nullable = true, - example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"" + example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables." )] pub tool_prompt: Option, @@ -865,10 +865,8 @@ pub(crate) struct ChatRequest { pub guideline: Option, } -fn default_tool_prompt() -> Option { - Some( - "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), - ) +pub fn default_tool_prompt() -> String { + "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] @@ -910,7 +908,7 @@ impl From for ToolChoice { } #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] -pub struct Tools { +pub struct JsonSchemaTool { #[serde(flatten)] functions_map: FunctionsMap, properties: Properties, @@ -968,8 +966,7 @@ pub(crate) struct ChatTemplateInputs<'a> { bos_token: Option<&'a str>, eos_token: Option<&'a str>, add_generation_prompt: bool, - tools: Option<&'a str>, - tools_prompt: Option<&'a str>, + tools: Option>, guideline: Option<&'a str>, } diff --git a/router/src/server.rs b/router/src/server.rs index 8ec7a871..8ebd1a33 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,7 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; -use crate::ChatTokenizeResponse; +use crate::{default_tool_prompt, ChatTokenizeResponse}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -23,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -146,7 +146,7 @@ async fn get_chat_tokenize( } = req; let tool_prompt = tool_prompt.unwrap_or_default(); - let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + let (inputs, _grammar, _using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -1158,14 +1158,16 @@ async fn chat_completions( let repetition_penalty = presence_penalty.map(|x| x + 2.0); let max_new_tokens = max_tokens.or(Some(100)); let logprobs = logprobs.unwrap_or(false); - let tool_prompt = tool_prompt.unwrap_or_default(); + let tool_prompt = tool_prompt + .filter(|s| !s.is_empty()) + .unwrap_or_else(default_tool_prompt); let stop = stop.unwrap_or_default(); // enable greedy only when temperature is 0 let (do_sample, temperature) = match temperature { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - let (inputs, grammar, tool_grammar) = prepare_chat_input( + let (inputs, grammar, using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -1221,7 +1223,7 @@ async fn chat_completions( }); // replace the content with the tool calls if grammar is present - let (content, tool_calls) = if tool_grammar.is_some() { + let (content, tool_calls) = if using_tools { (None, Some(vec![stream_token.token.text])) } else { let content = if !stream_token.token.special { @@ -1275,7 +1277,7 @@ async fn chat_completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - let (tool_calls, output) = if tool_grammar.is_some() { + let (tool_calls, output) = if using_tools { let gen_text_value: Value = serde_json::from_str(&generation.generated_text).map_err(|e| { InferError::ToolError(format!( @@ -2539,7 +2541,7 @@ fn create_post_processor( Ok(post_processor) } -type PreparedInput = (String, Option, Option); +type PreparedInput = (String, Option, bool); fn prepare_chat_input( infer: &Infer, @@ -2556,19 +2558,139 @@ fn prepare_chat_input( )); } + // when response_format is set, tools are not included when applying the chat template to generate inputs if let Some(format) = response_format { let inputs = infer.apply_chat_template(guideline, messages, None)?; - return Ok((inputs, Some(format), None)); + return Ok((inputs, Some(format), false)); } - // if tools are set, apply the tool grammar and then the chat template - let tool_grammar: Option = ToolGrammar::apply(tools, tool_choice)?; - let grammar = tool_grammar - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - let tools_grammar_prompt = tool_grammar - .as_ref() - .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); - let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?; - Ok((inputs, grammar, tool_grammar)) + // when no response_format is set and tools are included, apply the chat template with the tools + // to generate inputs + if let Some(tools) = tools { + let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; + + let grammar = tool_schema + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt.into())), + )?; + return Ok((inputs, grammar, tool_schema.is_some())); + } + + // if no response_format or tools are set simply apply the chat template to generate inputs + let inputs = infer.apply_chat_template(guideline, messages, None)?; + Ok((inputs, None, false)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ChatTemplateVersions; + use crate::HubTokenizerConfig; + use crate::TokenizerConfigToken; + use crate::Tool; + + use serde_json::json; + + #[test] + fn test_prepare_chat_input() { + // Mock Backend to avoid network requests + struct MockBackend; + + impl Backend for MockBackend { + fn schedule( + &self, + _request: crate::validation::ValidGenerateRequest, + ) -> Result< + tokio_stream::wrappers::UnboundedReceiverStream< + Result, + >, + InferError, + > { + unimplemented!("Never called in this test"); + } + fn health<'a, 'async_trait>( + &'a self, + _current_health: bool, + ) -> core::pin::Pin< + Box + core::marker::Send + 'async_trait>, + > + where + 'a: 'async_trait, + Self: 'async_trait, + { + unimplemented!("Never called in this test"); + } + } + + let backend = MockBackend {}; + + let mut tokenizer_config = HubTokenizerConfig::default(); + + // mock tokenizer config values + tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.chat_template = Some( + ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) + ); + + let infer = Infer::new( + backend, + Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), + 1, + tokenizer_config, + HubProcessorConfig::default(), + ); + let response_format = None; + let tools = Some(vec![Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "get_current_weather".to_string(), + description: Some("Get the current weather".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + }), + }, + }]); + let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."; + let guideline = None; + let messages = vec![Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "What is the weather like in New York?".to_string(), + ), + }]; + + let result = prepare_chat_input( + &infer, + response_format, + tools, + ToolChoice(None), + tool_prompt, + guideline, + messages, + ); + + assert!(result.is_ok()); + let (inputs, _grammar, using_tools) = result.unwrap(); + assert_eq!(using_tools, true); + assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + } } From 8398d4f436cffe142586484ab0623ead063e8803 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 19 Aug 2024 16:00:48 +0000 Subject: [PATCH 9/9] feat: add /v1/models endpoint --- router/src/lib.rs | 28 ++++++++++++++++++++++++++++ router/src/server.rs | 27 ++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index ce4f7c46..d874c38b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1240,6 +1240,34 @@ pub(crate) struct ErrorResponse { pub error_type: String, } +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelInfo { + #[schema(example = "gpt2")] + pub id: String, + #[schema(example = "model")] + pub object: String, + #[schema(example = 1686935002)] + pub created: u64, + #[schema(example = "openai")] + pub owned_by: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelsInfo { + #[schema(example = "list")] + pub object: String, + pub data: Vec, +} + +impl Default for ModelsInfo { + fn default() -> Self { + ModelsInfo { + object: "list".to_string(), + data: Vec::new(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/router/src/server.rs b/router/src/server.rs index 8ebd1a33..88a2e9ce 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -23,7 +23,8 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -116,6 +117,25 @@ async fn get_model_info(info: Extension) -> Json { Json(info.0) } +#[utoipa::path( +get, +tag = "Text Generation Inference", +path = "/v1/models", +responses((status = 200, description = "Served model info", body = ModelInfo)) +)] +#[instrument] +async fn openai_get_model_info(info: Extension) -> Json { + Json(ModelsInfo { + data: vec![ModelInfo { + id: info.0.model_id.clone(), + object: "model".to_string(), + created: 0, // TODO: determine how to get this + owned_by: info.0.model_id.clone(), + }], + ..Default::default() + }) +} + #[utoipa::path( post, tag = "Text Generation Inference", @@ -2208,7 +2228,7 @@ async fn start( // Define base and health routes let mut base_routes = Router::new() - .route("/", post(compat_generate)) + .route("/", post(openai_get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) @@ -2246,7 +2266,8 @@ async fn start( .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) - .route("/metrics", get(metrics)); + .route("/metrics", get(metrics)) + .route("/v1/models", get(openai_get_model_info)); // Conditional AWS Sagemaker route let aws_sagemaker_route = if messages_api_enabled {