diff --git a/Cargo.lock b/Cargo.lock index d298c379..aa5cb642 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" dependencies = [ "serde", + "serde_json", ] [[package]] diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 0fb05a98..faa57c11 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -316,10 +316,15 @@ impl State { + self.speculate - 1; - match block_allocator - .allocate(tokens, entry.request.input_ids.clone()) - .await - { + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. + let input_ids = if entry.request.decoder_input_details { + None + } else { + entry.request.input_ids.clone() + }; + + match block_allocator.allocate(tokens, input_ids).await { None => { // Entry is over budget // Add it back to the front diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index ef963532..5bac1a31 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -205,6 +205,7 @@ pub struct RadixTrie { /// call that a real time lookup would require. time: u64, } + impl Default for RadixTrie { fn default() -> Self { Self::new() diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 12966747..45301b63 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -757,7 +757,12 @@ class AsyncClient: continue payload = byte_payload.decode("utf-8") if payload.startswith("data:"): - json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + payload_data = ( + payload.lstrip("data:").rstrip("\n").removeprefix(" ") + ) + if payload_data == "[DONE]": + break + json_payload = json.loads(payload_data) try: response = ChatCompletionChunk(**json_payload) yield response diff --git a/docs/openapi.json b/docs/openapi.json index df21e19d..fd64a3ab 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -924,7 +924,7 @@ "tool_prompt": { "type": "string", "description": "A prompt to be appended before the tools", - "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. 
Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", "nullable": true }, "tools": { diff --git a/flake.lock b/flake.lock index d0c2adbc..b40f51b3 100644 --- a/flake.lock +++ b/flake.lock @@ -492,24 +492,6 @@ "type": "github" } }, - "flake-utils_7": { - "inputs": { - "systems": "systems_7" - }, - "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, "gitignore": { "inputs": { "nixpkgs": [ @@ -594,27 +576,6 @@ "type": "github" } }, - "nix-github-actions": { - "inputs": { - "nixpkgs": [ - "poetry2nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1703863825, - "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=", - "owner": "nix-community", - "repo": "nix-github-actions", - "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "nix-github-actions", - "type": "github" - } - }, "nix-test-runner": { "flake": false, "locked": { @@ -753,31 +714,6 @@ "type": "github" } }, - "poetry2nix": { - "inputs": { - "flake-utils": "flake-utils_7", - "nix-github-actions": "nix-github-actions", - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ], - "systems": "systems_8", - "treefmt-nix": "treefmt-nix" - }, - "locked": { - "lastModified": 1723854676, - "narHash": "sha256-+BrHfNuXrqeE7PoV6xDaoh0joYiJkvTTCIV0fFR3THw=", - "owner": "nix-community", - "repo": "poetry2nix", - "rev": "d650118bce34c0238b9b54f23f7f173f9e4db867", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "poetry2nix", - "type": "github" - } - }, "pre-commit-hooks": { "inputs": { "flake-compat": [ @@ -887,7 +823,6 @@ "tgi-nix", "nixpkgs" ], - "poetry2nix": "poetry2nix", "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } @@ -900,11 +835,11 @@ ] }, "locked": { - "lastModified": 1723515680, - "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=", + "lastModified": 1724206841, + "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3", + "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61", "type": "github" }, "original": { @@ -1003,46 +938,17 @@ "type": "github" } }, - "systems_7": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - }, - "systems_8": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "id": "systems", - "type": "indirect" - } - }, "tgi-nix": { "inputs": { "flake-compat": "flake-compat_4", "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1723973328, - "narHash": "sha256-q5FmW4YFQcRb6fXHnrxL0uno6xcw9dcg+pFBbVM1xeQ=", + "lastModified": 1724270760, + "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=", "owner": "danieldk", "repo": "tgi-nix", - "rev": 
"d2038f36589a8a179834e5771ffd081620ba94c3", + "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c", "type": "github" }, "original": { @@ -1050,27 +956,6 @@ "repo": "tgi-nix", "type": "github" } - }, - "treefmt-nix": { - "inputs": { - "nixpkgs": [ - "poetry2nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1719749022, - "narHash": "sha256-ddPKHcqaKCIFSFc/cvxS14goUhCOAwsM1PbMr0ZtHMg=", - "owner": "numtide", - "repo": "treefmt-nix", - "rev": "8df5ff62195d4e67e2264df0b7f5e8c9995fd0bd", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "treefmt-nix", - "type": "github" - } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 6bfe74e8..83feb26a 100644 --- a/flake.nix +++ b/flake.nix @@ -8,10 +8,6 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; - poetry2nix = { - url = "github:nix-community/poetry2nix"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; - }; rust-overlay = { url = "github:oxalica/rust-overlay"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; @@ -26,7 +22,6 @@ flake-utils, rust-overlay, tgi-nix, - poetry2nix, }: flake-utils.lib.eachDefaultSystem ( system: @@ -47,14 +42,32 @@ tgi-nix.overlay ]; }; - inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; - text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; + benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override { + inherit crateOverrides; + }; + launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override { + inherit crateOverrides; + }; + router = cargoNix.workspaceMembers.text-generation-router-v3.build.override { + inherit crateOverrides; + }; + server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; }; in { - devShells.default = - with pkgs; - mkShell { + devShells = with pkgs; rec { + default = pure; + + pure = mkShell { + buildInputs = [ + benchmark + launcher + router + server + ]; + }; + + impure = mkShell { buildInputs = [ openssl.dev @@ -65,42 +78,16 @@ "rust-src" ]; }) + protobuf ] ++ (with python3.pkgs; [ venvShellHook pip - - causal-conv1d - click - einops - exllamav2 - fbgemm-gpu - flashinfer - flash-attn - flash-attn-layer-norm - flash-attn-rotary - grpc-interceptor - grpcio-reflection - grpcio-status - grpcio-tools - hf-transfer - loguru - mamba-ssm - marlin-kernels - opentelemetry-api - opentelemetry-exporter-otlp - opentelemetry-instrumentation-grpc - opentelemetry-semantic-conventions - peft - tokenizers - torch - transformers - vllm - - (cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; }) - (cargoNix.workspaceMembers.text-generation-router-v3.build.override { inherit crateOverrides; }) + ipdb ]); + inputsFrom = [ server ]; + venvDir = "./.venv"; postVenv = '' @@ -108,8 +95,21 @@ ''; postShellHook = '' unset SOURCE_DATE_EPOCH + export PATH=$PATH:~/.cargo/bin ''; }; + }; + + packages.default = pkgs.writeShellApplication { + name = "text-generation-inference"; + runtimeInputs = [ + server + router + ]; + text = '' + ${launcher}/bin/text-generation-launcher "$@" + ''; + }; } ); } diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 15af1cad..a8a77cd2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -257,7 +257,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator): class LauncherHandle: def 
__init__(self, port: int): - self.client = AsyncClient(f"http://localhost:{port}") + self.client = AsyncClient(f"http://localhost:{port}", timeout=30) def _inner_health(self): raise NotImplementedError diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index f831990a..9855cfda 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -36,6 +36,7 @@ tools = [ }, }, "required": ["location", "format"], + "additionalProperties": False, }, }, }, @@ -62,13 +63,13 @@ tools = [ }, }, "required": ["location", "format", "num_days"], + "additionalProperties": False, }, }, }, ] -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): @@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna max_tokens=100, seed=1, tools=tools, - presence_penalty=-1.1, + temperature=0.0, messages=[ { "role": "system", @@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_auto( @@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="auto", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto( assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_choice( @@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice( assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { - "id": 0, + "id": "0", "type": "function", "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( @@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream( max_tokens=100, seed=1, 
tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 38 + assert count == 48 assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_insufficient_information( @@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information( ): responses = await flash_llama_grammar_tools.chat( max_tokens=100, - seed=8, + seed=24, tools=tools, tool_choice="auto", messages=[ { "role": "system", - "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", }, { "role": "user", @@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information( ) assert responses.choices[0].message.content is None - assert responses.choices[0].message.tool_calls == [ - { - "function": { - "arguments": { - "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." - }, - "description": None, - "name": "notify_error", - }, - "id": 0, - "type": "function", - } - ] - + assert ( + responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error" + ) assert responses == response_snapshot diff --git a/nix/crate-overrides.nix b/nix/crate-overrides.nix index 343b3b25..cbf366af 100644 --- a/nix/crate-overrides.nix +++ b/nix/crate-overrides.nix @@ -20,34 +20,52 @@ defaultCrateOverrides rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; grpc-metadata = attrs: { - src = - filter { - root = ../backends/grpc-metadata; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + src = filter { + root = ../backends/grpc-metadata; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; - text-generation-launcer = attrs: { - src = - filter { - root = ../launcher; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + text-generation-benchmark = attrs: { + src = filter { + root = ../benchmark; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-client = attrs: { + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "backends/client") (matchExt "rs")) + (and (inDirectory "proto") (matchExt "proto")) + ]; + }; + postPatch = "cd backends/client"; + buildInputs = [ protobuf ]; + }; + text-generation-launcher = attrs: { + src = filter { + root = ../launcher; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; text-generation-router = attrs: { - src = - filter { - root = ../router; - include = with filter; [ - isDirectory - (matchExt "rs") - ]; - }; + src = filter { + root = ../router; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; }; text-generation-router-v3 = attrs: { # We need to do the src/source root dance so that the build diff --git a/nix/server.nix b/nix/server.nix new file mode 100644 index 00000000..4e0fdaa1 --- /dev/null +++ b/nix/server.nix @@ -0,0 +1,109 @@ +{ + nix-filter, + buildPythonPackage, + poetry-core, + mypy-protobuf, + awq-inference-engine, + causal-conv1d, + eetq, + einops, + exllamav2, + fbgemm-gpu, + flashinfer, + flash-attn, + flash-attn-layer-norm, + flash-attn-rotary, + grpc-interceptor, + 
grpcio-reflection, + grpcio-status, + grpcio-tools, + hf-transfer, + loguru, + mamba-ssm, + marlin-kernels, + opentelemetry-api, + opentelemetry-exporter-otlp, + opentelemetry-instrumentation-grpc, + opentelemetry-semantic-conventions, + peft, + safetensors, + tokenizers, + sentencepiece, + transformers, + typer, + vllm, +}: + +let + filter = nix-filter.lib; +in +buildPythonPackage { + name = "text-generation-server"; + + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi"))) + "server/pyproject.toml" + (and (inDirectory "proto/v3") (matchExt "proto")) + ]; + }; + + pyproject = true; + + build-system = [ poetry-core ]; + + nativeBuildInputs = [ mypy-protobuf ]; + + pythonRelaxDeps = [ + "einops" + "huggingface-hub" + "loguru" + "opentelemetry-instrumentation-grpc" + "sentencepiece" + "typer" + ]; + + pythonRemoveDeps = [ "scipy" ]; + + dependencies = [ + awq-inference-engine + eetq + causal-conv1d + einops + exllamav2 + fbgemm-gpu + flashinfer + flash-attn + flash-attn-layer-norm + flash-attn-rotary + grpc-interceptor + grpcio-reflection + grpcio-status + grpcio-tools + hf-transfer + loguru + mamba-ssm + marlin-kernels + opentelemetry-api + opentelemetry-exporter-otlp + opentelemetry-instrumentation-grpc + opentelemetry-semantic-conventions + peft + safetensors + sentencepiece + tokenizers + transformers + typer + vllm + ]; + + prePatch = '' + python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \ + --grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto + find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; + touch server/text_generation_server/pb/__init__.py + cd server + ''; +} diff --git a/router/Cargo.toml b/router/Cargo.toml index 7773e212..45acab8e 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } -minijinja = { version = "2.0.2" } +minijinja = { version = "2.0.2", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" regex = "1.10.3" diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index a8537818..bfa9421c 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,9 +1,7 @@ use std::collections::HashSet; use crate::infer::InferError; -use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, -}; +use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -32,6 +30,7 @@ impl ChatTemplate { env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); + tracing::debug!("Loading template: {:#?}", template_str); // leaking env and template_str as read-only, static resources for performance. 
let template = Box::leak(env) @@ -42,6 +41,7 @@ impl ChatTemplate { let variables = template.undeclared_variables(true); // check if the `tools` variable is used in the template let use_default_tool_template = !variables.contains("tools"); + tracing::debug!("Use default tool template: {}", use_default_tool_template); Self { template, @@ -56,25 +56,36 @@ impl ChatTemplate { &self, guideline: Option<&str>, mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - }); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - // check if guideline is expected but not provided if self.variables.contains("guideline") && guideline.is_none() { return Err(InferError::MissingTemplateVariable("guideline".to_string())); } + let tools = match tools_and_prompt { + Some((tools, tool_prompt)) => { + // check if the `tools` variable is used in the template + // if not, we need to append the tools to the last message + let text = if self.use_default_tool_template { + match serde_json::to_string(&tools) { + Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt), + Err(e) => return Err(InferError::ToolError(e.to_string())), + } + } else { + // if the `tools` variable is used in the template, we just append the tool_prompt + format!("\n---\n{}", tool_prompt) + }; + if let Some(last_message) = messages.last_mut() { + last_message.content.push(MessageChunk::Text { text }); + } + Some(tools) + } + None => None, + }; + + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + self.template .render(ChatTemplateInputs { guideline, @@ -82,8 +93,7 @@ impl ChatTemplate { bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), add_generation_prompt: true, - tools: None, - tools_prompt: None, + tools, }) .map_err(InferError::TemplateError) } @@ -95,7 +105,7 @@ mod tests { use crate::infer::chat_template::raise_exception; use crate::infer::ChatTemplate; use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken, + ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, }; use minijinja::Environment; @@ -854,11 +864,12 @@ mod tests { content: MessageContent::SingleText("Just testing".to_string()), }, ]; - let tools = serde_json::json!("[]"); + let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); + let tools: Vec = serde_json::from_str(&tools_string).unwrap(); let tool_prompt = "This default prompt will be used".to_string(); - let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt); - let result = ct.apply(None, msgs, Some(grammer_with_prompt)); - let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? 
[INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string(); + let tools_and_prompt = Some((tools, tool_prompt)); + let result = ct.apply(None, msgs, tools_and_prompt); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index c9354d9a..81c0d38f 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -3,7 +3,7 @@ mod chat_template; pub mod tool_grammar; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; -use crate::GrammarType; +use crate::Tool; use crate::{ ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, Message, PrefillToken, Token, @@ -140,12 +140,12 @@ impl Infer { &self, guideline: Option, messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, + tools_and_prompt: Option<(Vec, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(guideline.as_deref(), messages, grammar_with_prompt) + .apply(guideline.as_deref(), messages, tools_and_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 05027f30..4fe15720 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,5 +1,8 @@ use crate::infer::InferError; -use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; +use crate::{ + FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, + ToolType, +}; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -16,17 +19,38 @@ impl ToolGrammar { } pub fn apply( - tools: Option>, + tools: Vec, tool_choice: ToolChoice, - ) -> Result, InferError> { + ) -> Result<(Vec, Option), InferError> { // if no tools are provided, we return None - let tools = match tools { - Some(tools) if !tools.is_empty() => tools, - _ => return Ok(None), - }; + if tools.is_empty() { + return Ok((tools, None)); + } let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + let mut tools = tools.clone(); + + // add the notify_error function to the tools + let notify_error = Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "notify_error".to_string(), + description: Some("Notify an error or issue".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "error": { + "type": "string", + "description": "The error or issue to notify" + } + }, + "required": ["error"] + }), + }, + }; + tools.push(notify_error); + // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { ToolType::FunctionName(name) => { @@ -35,87 +59,57 @@ impl ToolGrammar { ToolType::Function { 
function } => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools, - ToolType::NoTool => return Ok(None), + ToolType::OneOf => tools.clone(), + ToolType::NoTool => return Ok((tools, None)), }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - let functions: HashMap = tools_to_use .iter() .map(|tool| { let func = tool.function.clone(); - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; + let mut params = Map::new(); - // Insert the function's description at the top level, outside of properties params.insert( "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), + Value::String(func.description.unwrap_or_default()), ); - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); + let mut properties = Map::new(); + let mut required = vec![Value::String("_name".to_string())]; - // Insert the constant for the function name inside 'properties' properties.insert( "_name".to_string(), json!({ "type": "string", "const": func.name.clone(), - // "description": "The name of the function" }), ); - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); + if let Value::Object(args) = func.arguments { + if let Some(Value::Object(props)) = args.get("properties") { + properties.extend(props.clone()); + } + if let Some(Value::Array(reqs)) = args.get("required") { + required.extend(reqs.clone()); + } + params.insert( + "additionalProperties".to_string(), + Value::Bool( + args.get("additionalProperties").and_then(|v| v.as_str()) + == Some("true"), + ), + ); } + params.insert("properties".to_string(), Value::Object(properties)); + params.insert("required".to_string(), Value::Array(required)); + (func.name, Value::Object(params)) }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) .collect(); - let tools = Tools { + let tool_schema = JsonSchemaTool { functions_map: FunctionsMap { functions }, properties: Properties { function: tools_to_use @@ -123,13 +117,10 @@ impl ToolGrammar { .map(|tool| FunctionRef { ref_path: format!("#/$functions/{}", tool.function.name.clone()), }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) .collect(), }, }; - Ok(Some(tools)) + Ok((tools, Some(tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 6cefe5e2..d874c38b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -840,10 +840,10 @@ pub(crate) struct ChatRequest { pub tools: Option>, /// A prompt to be appended before the tools - #[serde(default = 
"default_tool_prompt")] + #[serde(default)] #[schema( nullable = true, - example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"" + example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables." )] pub tool_prompt: Option, @@ -865,10 +865,8 @@ pub(crate) struct ChatRequest { pub guideline: Option, } -fn default_tool_prompt() -> Option { - Some( - "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), - ) +pub fn default_tool_prompt() -> String { + "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] @@ -910,7 +908,7 @@ impl From for ToolChoice { } #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] -pub struct Tools { +pub struct JsonSchemaTool { #[serde(flatten)] functions_map: FunctionsMap, properties: Properties, @@ -968,8 +966,7 @@ pub(crate) struct ChatTemplateInputs<'a> { bos_token: Option<&'a str>, eos_token: Option<&'a str>, add_generation_prompt: bool, - tools: Option<&'a str>, - tools_prompt: Option<&'a str>, + tools: Option>, guideline: Option<&'a str>, } diff --git a/router/src/server.rs b/router/src/server.rs index 31242c2a..88a2e9ce 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,7 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; -use crate::ChatTokenizeResponse; +use crate::{default_tool_prompt, ChatTokenizeResponse}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -166,7 +166,7 @@ async fn get_chat_tokenize( } = req; let tool_prompt = tool_prompt.unwrap_or_default(); - let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + let (inputs, _grammar, _using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -1178,14 +1178,16 @@ async fn chat_completions( let repetition_penalty = presence_penalty.map(|x| x + 2.0); let max_new_tokens = max_tokens.or(Some(100)); let logprobs = logprobs.unwrap_or(false); - let tool_prompt = tool_prompt.unwrap_or_default(); + let tool_prompt = tool_prompt + .filter(|s| !s.is_empty()) + .unwrap_or_else(default_tool_prompt); let stop = stop.unwrap_or_default(); // enable greedy only when temperature is 0 let (do_sample, temperature) = match temperature { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - let (inputs, grammar, tool_grammar) = prepare_chat_input( + let (inputs, grammar, using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -1241,7 +1243,7 @@ async fn chat_completions( }); // replace the content with the tool calls if grammar 
is present - let (content, tool_calls) = if tool_grammar.is_some() { + let (content, tool_calls) = if using_tools { (None, Some(vec![stream_token.token.text])) } else { let content = if !stream_token.token.special { @@ -1295,7 +1297,7 @@ async fn chat_completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - let (tool_calls, output) = if tool_grammar.is_some() { + let (tool_calls, output) = if using_tools { let gen_text_value: Value = serde_json::from_str(&generation.generated_text).map_err(|e| { InferError::ToolError(format!( @@ -2560,7 +2562,7 @@ fn create_post_processor( Ok(post_processor) } -type PreparedInput = (String, Option, Option); +type PreparedInput = (String, Option, bool); fn prepare_chat_input( infer: &Infer, @@ -2577,19 +2579,139 @@ fn prepare_chat_input( )); } + // when response_format is set, tools are not included when applying the chat template to generate inputs if let Some(format) = response_format { let inputs = infer.apply_chat_template(guideline, messages, None)?; - return Ok((inputs, Some(format), None)); + return Ok((inputs, Some(format), false)); } - // if tools are set, apply the tool grammar and then the chat template - let tool_grammar: Option = ToolGrammar::apply(tools, tool_choice)?; - let grammar = tool_grammar - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - let tools_grammar_prompt = tool_grammar - .as_ref() - .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); - let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?; - Ok((inputs, grammar, tool_grammar)) + // when no response_format is set and tools are included, apply the chat template with the tools + // to generate inputs + if let Some(tools) = tools { + let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; + + let grammar = tool_schema + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt.into())), + )?; + return Ok((inputs, grammar, tool_schema.is_some())); + } + + // if no response_format or tools are set simply apply the chat template to generate inputs + let inputs = infer.apply_chat_template(guideline, messages, None)?; + Ok((inputs, None, false)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ChatTemplateVersions; + use crate::HubTokenizerConfig; + use crate::TokenizerConfigToken; + use crate::Tool; + + use serde_json::json; + + #[test] + fn test_prepare_chat_input() { + // Mock Backend to avoid network requests + struct MockBackend; + + impl Backend for MockBackend { + fn schedule( + &self, + _request: crate::validation::ValidGenerateRequest, + ) -> Result< + tokio_stream::wrappers::UnboundedReceiverStream< + Result, + >, + InferError, + > { + unimplemented!("Never called in this test"); + } + fn health<'a, 'async_trait>( + &'a self, + _current_health: bool, + ) -> core::pin::Pin< + Box + core::marker::Send + 'async_trait>, + > + where + 'a: 'async_trait, + Self: 'async_trait, + { + unimplemented!("Never called in this test"); + } + } + + let backend = MockBackend {}; + + let mut tokenizer_config = HubTokenizerConfig::default(); + + // mock tokenizer config values + tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.chat_template = Some( + ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n 
{%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) + ); + + let infer = Infer::new( + backend, + Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), + 1, + tokenizer_config, + HubProcessorConfig::default(), + ); + let 
response_format = None; + let tools = Some(vec![Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "get_current_weather".to_string(), + description: Some("Get the current weather".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + }), + }, + }]); + let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."; + let guideline = None; + let messages = vec![Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "What is the weather like in New York?".to_string(), + ), + }]; + + let result = prepare_chat_input( + &infer, + response_format, + tools, + ToolChoice(None), + tool_prompt, + guideline, + messages, + ); + + assert!(result.is_ok()); + let (inputs, _grammar, using_tools) = result.unwrap(); + assert_eq!(using_tools, true); + assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. 
Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + } } diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index f9b1715e..56fc5319 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -6,7 +6,12 @@ from .common import Seqlen if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": - from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .cuda import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "ipex": diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 8703eb94..b3b7ea4f 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -76,7 +76,7 @@ def paged_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import decode_state + from text_generation_server.layers.attention.flashinfer import decode_state return decode_state.get().forward( query.contiguous(), @@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2 if ATTENTION == "flashinfer": def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -231,14 +233,15 @@ if ATTENTION == "flashinfer": causal=True, softcap=0.0, ): - from text_generation_server.layers.attention.flash_infer import prefill_state + assert window_size_left == -1, "Windowing is not supported with flash infer" + from text_generation_server.layers.attention.flashinfer import ( + prefill_with_paged_kv_state, + ) - return prefill_state.get().forward( - q, - k, - v, + return prefill_with_paged_kv_state.get().forward( + q.contiguous(), causal=causal, - window_left=window_size_left, + paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, ) @@ -249,6 +252,8 @@ elif V2: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -289,6 +294,8 @@ else: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flashinfer.py similarity index 65% rename from server/text_generation_server/layers/attention/flash_infer.py rename to server/text_generation_server/layers/attention/flashinfer.py index 56b53b2c..e1ef62c5 100644 --- a/server/text_generation_server/layers/attention/flash_infer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con "prefill_state" ) +prefill_with_paged_kv_state: ContextVar[ + flashinfer.BatchPrefillWithPagedKVCacheWrapper +] = ContextVar("prefill_with_paged_kv_state") + decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( "decode_state" ) @@ -24,6 +28,78 @@ def 
get_workspace(device): return workspace +def create_prefill_with_paged_kv_state( + *, + device: torch.device, +): + """Create a prefill state that uses the KV cache.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_with_paged_kv_state( + *, + state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, + block_tables: torch.Tensor, + cu_seqlens: torch.Tensor, + input_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + if page_size == 1: + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + else: + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = prefill_with_paged_kv_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + paged_kv_indptr=indptr, + paged_kv_indices=block_tables, + paged_kv_last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + page_size=page_size, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_with_paged_kv_state.reset(token) + + def create_prefill_state( *, device: torch.device, diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py index 7579ccdb..139c4dc2 100644 --- a/server/text_generation_server/layers/medusa.py +++ b/server/text_generation_server/layers/medusa.py @@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module): ) def forward(self, x): + if not self.heads: + return None speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) return speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e02a31d9..1eb8c6c3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d3d1d1ef..fc0dca5b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module): query, 
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0905d3c2..b25becd5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 54d212e6..faf0f325 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 178efadb..33738a59 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a19cff8c..d30b5a0a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 9ea19a87..3253d2dc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dda53ff3..5a150267 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, 
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 85431c6c..ad426ffe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 67237d5c..b684e035 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0], qkv[:, 1], qkv[:, 2], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index d10efb41..e08a2aad 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module): config=config.vision_config, weights=weights, ) + self.post_vision_tower_layernorm = nn.LayerNorm.load( + prefix="vision_tower.vision_model.post_layernorm", + weights=weights, + eps=config.vision_config.layer_norm_eps, + ) self.multi_modal_projector = TensorParallelColumnLinear.load( config, @@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module): if pixel_values is not None: pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) image_outputs = self.vision_tower(pixel_values) - image_features = self.multi_modal_projector(image_outputs.last_hidden_state) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) # mask where image or padding tokens mask = input_ids == self.config.image_token_index diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9e18348..efe27c13 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 865cc85d..879b8abd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, 
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 10f995a3..c72a9b90 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, @@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module): query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c2676782..109304be 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module): query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index e562eb89..200d4ef0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py index 480d0f9f..95ac9ede 100644 --- a/server/text_generation_server/models/custom_modeling/siglip.py +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module): inputs_embeds, attention_mask: Optional[torch.Tensor] = None, ): - hidden_states = inputs_embeds for idx, encoder_layer in enumerate(self.layers): hidden_states, _ = encoder_layer( @@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module): self.encoder = SiglipEncoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) - self.post_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.post_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, ): - r""" - Returns: - - """ if pixel_values is None: raise ValueError("You have to specify pixel_values") @@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module): inputs_embeds=hidden_states, ) last_hidden_state = encoder_outputs - post_last_hidden_state = self.post_layernorm(last_hidden_state) return BaseModelOutputWithPooling( - last_hidden_state=post_last_hidden_state, + last_hidden_state=last_hidden_state, # pooler_output=pooled_output, # hidden_states=encoder_outputs, ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5e2fd20a..dd4203e0 100644 --- 
a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,6 +43,7 @@ from text_generation_server.models.globals import ( ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, + PREFIX_CACHING, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor max_seqlen: int @@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch): prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] + # Prefixes + prefix_ids: List[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -213,6 +220,7 @@ class FlashCausalLMBatch(Batch): prefix_offsets = [] read_offsets = [] all_input_ids = [] + prefix_ids = [] requests_idx_mapping = {} all_prefill_logprobs = True @@ -230,7 +238,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 + cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -240,6 +248,7 @@ class FlashCausalLMBatch(Batch): block_tables = [] slots = [] + prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -255,6 +264,19 @@ class FlashCausalLMBatch(Batch): ): tokenized_input = tokenized_input[1:] + orig_input_length = len(tokenized_input) + + if PREFIX_CACHING: + prefix_len = r.prefix_len + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 + else: + prefix_len = 0 + + prefix_ids.append(tokenized_input[:prefix_len]) + tokenized_input = tokenized_input[prefix_len:] + input_length = len(tokenized_input) input_lengths.append(input_length) @@ -264,7 +286,9 @@ class FlashCausalLMBatch(Batch): all_input_ids.append(tokenized_input) # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + request_position_ids = torch.arange( + prefix_len, orig_input_length, dtype=torch.int32 + ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs @@ -288,11 +312,17 @@ class FlashCausalLMBatch(Batch): # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length - total_tokens = input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to blocks. + block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to slots. We don't need slots for the + # cached prefix (if present). 
+ slot_tokens = input_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] @@ -303,16 +333,20 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots + request_slots = r.slots[ + prefix_len: #: orig_input_length + max_new_tokens + speculative_length + ] block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) + + slots.extend(request_slots) + prefix_lens.append(prefix_len) num_blocks += len(request_blocks) - start_slots.append(cumulative_max_length) + start_slots.append(cumulative_slot_tokens) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) @@ -348,7 +382,7 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length - cumulative_max_length += total_tokens + cumulative_slot_tokens += slot_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( @@ -425,12 +459,14 @@ class FlashCausalLMBatch(Batch): ) slots = torch.tensor(slots, dtype=torch.int64, device=device) + block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) + prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) return cls( batch_id=pb.id, @@ -445,6 +481,8 @@ class FlashCausalLMBatch(Batch): block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -455,6 +493,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -510,8 +549,10 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] + prefix_lens = [] prefix_offsets = [] read_offsets = [] @@ -533,11 +574,14 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] + prefix_len = self.prefix_lens[idx] max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) + prefix_ids.append(self.prefix_ids[idx]) input_lengths.append(request_input_length) + prefix_lens.append(prefix_len) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -582,6 +626,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] + prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( @@ -617,10 +662,13 @@ class 
FlashCausalLMBatch(Batch): prefill_cu_outlens=None, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -681,6 +729,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) + prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) @@ -698,7 +747,9 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] + prefix_lens = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] prefix_offsets = [] @@ -760,10 +811,14 @@ class FlashCausalLMBatch(Batch): start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) + prefix_lens.extend(batch.prefix_lens) all_input_ids.extend(batch.all_input_ids) + prefix_ids.extend(batch.prefix_ids) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) @@ -809,6 +864,8 @@ class FlashCausalLMBatch(Batch): slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, @@ -820,6 +877,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -970,19 +1028,22 @@ class FlashCausalLM(Model): self.kv_cache = [] if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_prefill_state, create_decode_state, + create_prefill_with_paged_kv_state, ) self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) - if not CUDA_GRAPHS: - self.decode_state = create_decode_state( - device=device, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) super().__init__( model_id=model_id, @@ -1074,12 +1135,23 @@ class FlashCausalLM(Model): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - block_tables = ( - torch.arange(max_bt, dtype=torch.int32, device=self.device) - .repeat(bs) - .reshape((bs, max_bt)) + input_lengths = [max_s] * bs + prefix_lengths = [0] * bs + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) + prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + block_tables = 
torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lengths, + ) self.cuda_graphs[bs] = { "input_ids": input_ids, @@ -1087,14 +1159,14 @@ class FlashCausalLM(Model): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "input_lengths": input_lengths, + "input_lengths": input_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths) + input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, ) @@ -1104,7 +1176,7 @@ class FlashCausalLM(Model): last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) state = create_decode_state_cuda_graphs( device=input_ids.device, - block_tables=block_tables.view(-1), + block_tables=block_tables, block_tables_ptr=block_tables_ptr, last_page_len=last_page_len, num_heads=self.num_heads, @@ -1120,7 +1192,10 @@ class FlashCausalLM(Model): block_tables=block_tables, cu_seqlen_prefill=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, state=state, + prefix_lens=prefix_lengths, + prefix_lens_tensor=prefix_lengths_tensor, ): self.model.forward( input_ids=input_ids, @@ -1138,7 +1213,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) + input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1146,7 +1221,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_tensor, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1334,6 +1409,9 @@ class FlashCausalLM(Model): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -1354,6 +1432,7 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1372,10 +1451,20 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths=input_lengths, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, ): input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( @@ -1399,20 +1488,32 @@ class FlashCausalLM(Model): # Static 
inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) - state = cuda_graph.get("state") with self._forward_context( - block_tables=block_tables, + block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - input_lengths=input_lengths, - state=state, + input_lengths=batch.input_lengths, + input_lengths_tensor=cuda_graph["input_lengths"], + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + state=cuda_graph.get("state"), ): # Replay the graph cuda_graph["graph"].replay() @@ -1610,6 +1711,7 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, + batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, @@ -1627,6 +1729,7 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, + prefix_ids, do_sample, seed, top_n_tokens, @@ -1701,18 +1804,18 @@ class FlashCausalLM(Model): out_end_index = batch.prefill_cu_outlens[i + 1] # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - out_start_index : out_end_index - 1 - ] + request_prefill_logprobs = ( + [float("nan")] * (len(prefix_ids) + 1) + ) + prefill_logprobs[out_start_index : out_end_index - 1] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, + prefix_ids + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) prefill_tokens = Tokens( - prefill_token_ids, + prefix_ids + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], @@ -1794,33 +1897,68 @@ class FlashCausalLM(Model): *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths: torch.Tensor, + input_lengths: List[int], + input_lengths_tensor: torch.Tensor, + prefix_lens: List[int], + prefix_lens_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": return nullcontext() - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( use_decode_state, - use_prefill_state, + use_prefill_with_paged_kv_state, ) + # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) + if cu_seqlen_prefill is not None: - return use_prefill_state( - state=state if state is not None else self.prefill_state, + return use_prefill_with_paged_kv_state( + state=( + state if state is not None else self.prefill_with_paged_kv_state + ), + # block_tables=block_tables_to_ragged( + # block_tables=block_tables, + # input_lengths=input_lengths, + # prefix_lens=prefix_lens, + # ), + 
block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - ) - else: - assert input_lengths is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths, - block_tables=block_tables.view(-1), + input_lengths=input_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, ) + else: + assert input_lengths_tensor is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths_tensor, + block_tables=block_tables, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) + + +def block_tables_to_ragged( + *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(prefix_lens) + + total_len = sum(input_lengths) + sum(prefix_lens) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + offset = 0 + for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): + seq_len = prefix_len + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index abc35421..d5133f5e 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,9 +5,8 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) -log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") - +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} +log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( @@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer": else: BLOCK_SIZE = 16 - cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7de54aa4..2ed1a119 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, + block_tables_to_ragged, ) +from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen @@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM): trust_remote_code: bool, **kwargs, ): + if PREFIX_CACHING: + raise NotImplementedError("Vlm do not work with prefix caching yet") if processor_kwargs is None: processor_kwargs = {} self.processor = processor_class.from_pretrained( @@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + 
prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -349,43 +357,68 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - return logits, speculative_logits + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: 
input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) # Replay the graph cuda_graph["graph"].replay()
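A distilled view of the per-request bookkeeping added to FlashCausalLMBatch.from_pb above. Names mirror the diff (prefix_len, block_tokens, slot_tokens); prefix_len comes from r.prefix_len only when PREFIX_CACHING is enabled and is forced to 0 otherwise, and BLOCK_SIZE is fixed to 16 here purely for illustration:

    import math
    import torch

    BLOCK_SIZE = 16  # illustrative; the server picks it based on the attention backend

    def plan_request(tokenized_input, prefix_len, max_new_tokens, speculative_length=0):
        orig_input_length = len(tokenized_input)
        # Never cache the whole prompt: at least one token must be prefilled to produce a logit.
        if prefix_len == orig_input_length:
            assert prefix_len > 0
            prefix_len -= 1

        prefix_ids = tokenized_input[:prefix_len]   # served from the KV cache
        suffix_ids = tokenized_input[prefix_len:]   # actually prefilled
        input_length = len(suffix_ids)

        # Positions continue after the cached prefix.
        position_ids = torch.arange(prefix_len, orig_input_length, dtype=torch.int32)

        # Blocks must cover the full sequence (cached prefix included)...
        block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
        needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
        # ...but slots are only needed for tokens that are not already cached.
        slot_tokens = input_length + max_new_tokens - 1 + speculative_length
        return prefix_ids, suffix_ids, position_ids, needed_blocks, slot_tokens

    print(plan_request(list(range(10)), prefix_len=4, max_new_tokens=8))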
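block_tables_to_ragged, defined in the flash_causal_lm.py hunk above, flattens the padded [batch, max_blocks] block table into the single ragged list flashinfer expects, with each sequence contributing prefix_len + input_length entries. The function body below is restated from the diff so the example runs on its own; the toy input treats one block per token, which is an assumption made only to keep the numbers small:

    import torch

    def block_tables_to_ragged(*, block_tables, input_lengths, prefix_lens):
        """Convert block table to ragged format compatible with FlashInfer."""
        assert len(input_lengths) == len(prefix_lens)
        total_len = sum(input_lengths) + sum(prefix_lens)
        block_tables_ragged = torch.empty(
            total_len, dtype=torch.int32, device=block_tables.device
        )
        offset = 0
        for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
            seq_len = prefix_len + input_length
            block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
            offset += seq_len
        return block_tables_ragged

    # Two sequences, padded to 4 entries each.
    block_tables = torch.tensor([[10, 11, 12, 0], [20, 21, 22, 23]], dtype=torch.int32)
    ragged = block_tables_to_ragged(
        block_tables=block_tables, input_lengths=[2, 3], prefix_lens=[1, 1]
    )
    print(ragged)  # tensor([10, 11, 12, 20, 21, 22, 23], dtype=torch.int32)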
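The globals.py change above is a real bug fix, not a cosmetic one: os.getenv returns a string, so the old os.getenv("USE_PREFIX_CACHING", False) was truthy for any non-empty value, including "0" or "false". Parsing against an explicit allow-list gives the intended behaviour:

    import os

    os.environ["USE_PREFIX_CACHING"] = "0"

    old_style = os.getenv("USE_PREFIX_CACHING", False)                        # -> "0", which is truthy
    new_style = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}  # -> False

    print(bool(old_style), new_style)  # True False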
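Finally, the prefill-details hunk above keeps the returned prefill tokens and logprobs aligned when part of the prompt was served from the cache: the model only produced logprobs for the uncached suffix, so the response is rebuilt as prefix + suffix and the cached positions (plus the first prompt token, which never has a logprob) are padded with NaN. A toy alignment check with made-up token ids and logprobs:

    prefix_ids = [101, 102, 103]          # tokens reused from the cache
    suffix_ids = [104, 105, 106, 107]     # tokens actually prefilled
    suffix_logprobs = [-0.7, -1.2, -0.3]  # one logprob per suffix token after the first

    prefill_token_ids = prefix_ids + suffix_ids
    prefill_logprobs = [float("nan")] * (len(prefix_ids) + 1) + suffix_logprobs

    assert len(prefill_token_ids) == len(prefill_logprobs)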