From eab07f746c425ab441b68cd0ecc980ca6e981577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 24 Oct 2024 16:36:18 +0200 Subject: [PATCH] Add support for FP8 KV cache scales (#2628) * Add support for FP8 KV cache scales Since FP8 only has limited dynamic range, we can scale keys/values before storing them into the cache (and unscale them in attention). To avoid rescaling the cache as the absmax values change, good scales are usually determined per layer using calibration data and stored in the checkpoint. This change adds support for using key-value scales and loading them from checkpoints in the two most common formats: - Separate per-layer `k_scale` and `v_scale` scalars. - Per-layer `kv_scale` scalar (older format). Currently, scales are only used with a `float8_e4m3fn` cache. Besides adding support for key/value scales, the `fp8_quantize` function is also extended to support quantization with a kernel vendored from vLLM. This is slightly faster than the PyTorch implementation, but also scales in FP32, potentially improving accuracy. 
* Update FP8 KV cache test to use checkpoint with scales * `can_scale`: check that the attention is flashinfer --- flake.lock | 7 +- flake.nix | 2 +- .../test_flash_llama_fp8_kv_cache.json | 36 ++--- ...t_flash_llama_fp8_kv_cache_all_params.json | 70 +++++++-- .../test_flash_llama_fp8_kv_cache_load.json | 144 +++++++++--------- .../models/test_flash_llama_fp8_kv_cache.py | 8 +- server/poetry.lock | 24 +-- server/pyproject.toml | 8 +- .../layers/attention/__init__.py | 3 +- .../layers/attention/cuda.py | 14 +- .../layers/attention/flashinfer.py | 3 +- .../layers/attention/ipex.py | 5 +- .../layers/attention/kv_cache.py | 97 +++++++++++- .../layers/attention/rocm.py | 5 +- server/text_generation_server/layers/fp8.py | 17 +++ .../custom_modeling/flash_cohere_modeling.py | 11 +- .../custom_modeling/flash_dbrx_modeling.py | 11 +- .../flash_deepseek_v2_modeling.py | 14 +- .../custom_modeling/flash_gemma2_modeling.py | 11 +- .../custom_modeling/flash_gemma_modeling.py | 11 +- .../custom_modeling/flash_gpt2_modeling.py | 11 +- .../custom_modeling/flash_gptj_modeling.py | 11 +- .../custom_modeling/flash_llama_modeling.py | 16 +- .../custom_modeling/flash_mistral_modeling.py | 11 +- .../custom_modeling/flash_mixtral_modeling.py | 11 +- .../custom_modeling/flash_neox_modeling.py | 11 +- .../custom_modeling/flash_phi_modeling.py | 11 +- .../custom_modeling/flash_qwen2_modeling.py | 12 +- .../custom_modeling/flash_rw_modeling.py | 19 ++- .../flash_santacoder_modeling.py | 11 +- .../flash_starcoder2_modeling.py | 11 +- .../models/flash_causal_lm.py | 1 + .../text_generation_server/utils/weights.py | 4 +- 33 files changed, 486 insertions(+), 155 deletions(-) diff --git a/flake.lock b/flake.lock index aacdd30e4..76b4ca2fe 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1728381423, - "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=", + "lastModified": 1729531056, + "narHash": 
"sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e", + "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1", "type": "github" }, "original": { "owner": "huggingface", + "ref": "marlin-kernels-0.3.0", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index f26a983ed..5c05bfae7 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json index c55dd593a..b82882c00 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json @@ -11,27 +11,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.1875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.93359375, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.875, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1796875, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" 
} ], @@ -39,66 +39,66 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.109375, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.079956055, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.2763672, + "logprob": -0.028808594, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37548828, + "logprob": -0.013671875, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4628906, + "logprob": -0.69921875, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02885437, + "logprob": -0.0005874634, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.2565918, + "logprob": -0.026855469, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0063438416, + "logprob": -0.00020885468, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3056641, + "logprob": -0.17773438, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.6035156, + "id": 18065, + "logprob": -0.703125, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json index d06d6e566..8bce3e108 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "eos_token", - "generated_tokens": 3, + "finish_reason": "length", + "generated_tokens": 10, "prefill": 
[ { "id": 128000, @@ -11,22 +11,22 @@ }, { "id": 374, - "logprob": -22.96875, + "logprob": -18.0, "text": " is" }, { "id": 5655, - "logprob": -10.71875, + "logprob": -11.75, "text": " deep" }, { "id": 6975, - "logprob": -2.6992188, + "logprob": -2.0625, "text": " learning" }, { "id": 30, - "logprob": -4.8398438, + "logprob": -6.0, "text": "?" } ], @@ -34,24 +34,66 @@ "tokens": [ { "id": 720, - "logprob": -0.4411621, + "logprob": 0.0, "special": false, "text": " \n" }, { - "id": 220, - "logprob": -0.35864258, + "id": 34564, + "logprob": -0.11279297, "special": false, - "text": " " + "text": "Deep" }, { - "id": 128001, + "id": 6975, + "logprob": -0.16015625, + "special": false, + "text": " learning" + }, + { + "id": 320, + "logprob": -0.25195312, + "special": false, + "text": " (" + }, + { + "id": 16931, + "logprob": -1.703125, + "special": false, + "text": "DL" + }, + { + "id": 8, "logprob": 0.0, - "special": true, - "text": "<|end_of_text|>" + "special": false, + "text": ")" + }, + { + "id": 374, + "logprob": -1.140625, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 1207, + "logprob": -1.3125, + "special": false, + "text": " sub" + }, + { + "id": 2630, + "logprob": 0.0, + "special": false, + "text": "field" } ], "top_tokens": null }, - "generated_text": "What is deep learning? \n " + "generated_text": "What is deep learning? 
\nDeep learning (DL) is a subfield" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json index 46670819f..c7acee467 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json @@ -12,27 +12,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.1875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.93359375, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.875, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1796875, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" } ], @@ -40,68 +40,68 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.109375, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.0047912598, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.025512695, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.012145996, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.72265625, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0005760193, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02722168, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00023651123, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.17285156, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": 
-0.703125, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" }, { "details": { @@ -116,27 +116,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.21875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.95703125, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.9375, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1328125, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" } ], @@ -144,68 +144,68 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.1796875, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.02758789, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.013366699, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.6953125, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0004863739, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02709961, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00022506714, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.19726562, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": -0.77734375, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" }, { 
"details": { @@ -220,27 +220,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.21875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.95703125, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.9375, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1328125, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" } ], @@ -248,68 +248,68 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.1796875, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.02758789, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.013366699, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.6953125, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0004863739, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02709961, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00022506714, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.19726562, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": -0.77734375, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" }, { "details": { @@ -324,27 +324,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.21875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.95703125, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.9375, 
"text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1328125, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" } ], @@ -352,67 +352,67 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.1796875, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.02758789, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.013366699, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.6953125, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0004863739, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02709961, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00022506714, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.19726562, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": -0.77734375, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" } ] diff --git a/integration-tests/models/test_flash_llama_fp8_kv_cache.py b/integration-tests/models/test_flash_llama_fp8_kv_cache.py index 05e9f0dd9..ccd7f78fe 100644 --- a/integration-tests/models/test_flash_llama_fp8_kv_cache.py +++ b/integration-tests/models/test_flash_llama_fp8_kv_cache.py @@ -4,7 +4,9 @@ import pytest @pytest.fixture(scope="module") def flash_llama_fp8_kv_cache_handle(launcher): with launcher( - "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2" + 
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + num_shard=2, + kv_cache_dtype="fp8_e4m3fn", ) as handle: yield handle @@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps assert ( response.generated_text - == " Deep learning is a subset of machine learning that is" + == " Deep learning is a subset of machine learning that involves" ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load( assert len(responses) == 4 assert ( responses[0].generated_text - == " Deep learning is a subset of machine learning that is" + == " Deep learning is a subset of machine learning that involves" ) assert all( [r.generated_text == responses[0].generated_text for r in responses] diff --git a/server/poetry.lock b/server/poetry.lock index 80fe72ba2..1293e8836 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1215,12 +1215,12 @@ files = [ [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"}, ] [package.dependencies] @@ -1228,16 +1228,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = 
">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"}, ] [package.dependencies] @@ -1245,16 +1245,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"}, ] [package.dependencies] @@ -1262,16 +1262,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = 
"sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"}, ] [package.dependencies] @@ -1279,7 +1279,7 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mdurl" diff --git a/server/pyproject.toml b/server/pyproject.toml index 6ea4718d7..d08d0b8f4 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0" numpy = "^1.26" marlin-kernels = [ - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = 
"https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] moe-kernels = [ { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index b1d7b864a..ebe32042c 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -28,10 +28,11 @@ else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") # KVCache needs `reshape_and_cache`, so ensure that it is defined already. -from .kv_cache import KVCache +from .kv_cache import KVCache, get_kv_scales __all__ = [ "attention", + "get_kv_scales", "paged_attention", "SUPPORTS_WINDOWING", "KVCache", diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 08326c827..d705afb0b 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,5 +1,5 @@ import torch -from text_generation_server.layers.attention.kv_cache import KVCache +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ( ATTENTION, @@ -8,6 +8,7 @@ from text_generation_server.models.globals import ( from text_generation_server.layers.attention import Seqlen from typing import Optional + major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 _PARTITION_SIZE = 512 @@ -21,6 +22,8 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + *, + kv_scales: KVScales, softcap: 
Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -46,6 +49,8 @@ def paged_attention( num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + can_scale = kv_cache.can_scale(kv_scales) + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -60,6 +65,8 @@ def paged_attention( paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, + k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, + v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, ) elif ATTENTION == "flashdecoding": max_q = 1 @@ -205,6 +212,7 @@ def attention( key: torch.Tensor, value: torch.Tensor, kv_cache: KVCache, + kv_scales: KVScales, seqlen: Seqlen, block_tables: torch.Tensor, softmax_scale: float, @@ -212,6 +220,8 @@ def attention( causal: bool = True, softcap: Optional[float] = None, ): + can_scale = kv_cache.can_scale(kv_scales) + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flashinfer import ( prefill_with_paged_kv_state, @@ -228,6 +238,8 @@ def attention( logits_soft_cap=softcap, sm_scale=softmax_scale, window_left=window_size_left, + k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, + v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, ) # If we are using flashdecoding or paged, we always use flash-attn for diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index d603c6f5f..26a72d9be 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -204,6 +204,7 @@ def use_decode_state( num_kv_heads: int, head_size: int, 
page_size: int, + kv_cache_dtype: torch.dtype, dtype: torch.dtype, window_left: int, ): @@ -240,7 +241,7 @@ def use_decode_state( num_kv_heads=num_kv_heads, head_dim=head_size, page_size=page_size, - data_type=dtype, + data_type=kv_cache_dtype, q_data_type=dtype, window_left=window_left, ) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index e76bb1f42..677f3f564 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,6 +1,6 @@ import intel_extension_for_pytorch as ipex import torch -from text_generation_server.layers.attention.kv_cache import KVCache +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -14,6 +14,7 @@ def attention( key: torch.Tensor, value: torch.Tensor, kv_cache: KVCache, + kv_scales: KVScales, seqlen: Seqlen, block_tables: torch.Tensor, softmax_scale: float, @@ -55,6 +56,8 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + *, + kv_scales: KVScales, softcap: Optional[float] = None, ): if softcap is not None: diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index d64302c65..9d739da5e 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -1,8 +1,38 @@ from typing import Tuple +from dataclasses import dataclass, field +from loguru import logger import torch + +from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from 
text_generation_server.utils.weights import Weights + + +@dataclass +class KVScales: + """ + Key-value scales for FP8 KV cache. + + This data class stores key and value scales both as a GPU tensor and + as a CPU float. This inconvenience is necessary because some functions + (e.g. scaling kernels) take scales as a GPU tensor, whereas others + (e.g. flashinfer) take scales as a CPU scalar. + """ + + key_scale: torch.Tensor + value_scale: torch.Tensor + key_scale_cpu: float = field(init=False) + value_scale_cpu: float = field(init=False) + + def __post_init__(self): + if self.key_scale.numel() != 1 or self.value_scale.numel() != 1: + raise ValueError("Key and value scales must be scalar tensors.") + + self.key_scale_cpu = self.key_scale.item() + self.value_scale_cpu = self.value_scale.item() class KVCache: @@ -76,6 +106,33 @@ class KVCache: ), ) + def can_scale(self, kv_scales: KVScales) -> bool: + """Check if the cache can be scaled by the given scales.""" + if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0: + return False + elif ( + self.dtype == torch.float8_e4m3fn + and ATTENTION == "flashinfer" + and SYSTEM == "cuda" + ): + log_once( + logger.info, + "Using FP8 KV cache scales", + ) + return True + else: + # We have scales, but not the correct FP8 cache type, so warn once. 
+ log_once( + logger.info, + "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported", + ) + return False + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache[0].dtype + @property def key(self): """Get the key cache.""" @@ -94,17 +151,33 @@ class KVCache: key: torch.Tensor, value: torch.Tensor, slots: torch.Tensor, + kv_scales: KVScales, ): """Store the key and value at the given slots.""" key_cache = self.kv_cache[0] value_cache = self.kv_cache[1] + if self.can_scale(kv_scales): + if kv_scales.key_scale_cpu != 1.0: + key = fp8_quantize( + key.float(), + scale=kv_scales.key_scale, + qdtype=self.dtype, + scalar=True, + )[0] + if kv_scales.value_scale_cpu != 1.0: + value = fp8_quantize( + value.float(), + scale=kv_scales.value_scale, + qdtype=self.dtype, + scalar=True, + )[0] + if ATTENTION in {"flashdecoding", "flashinfer"}: - # TODO: add scale key = key.to(key_cache.dtype) value = value.to(value_cache.dtype) - if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: + if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so # put as raw data instead. 
key_cache = key_cache.view(torch.uint8) @@ -151,5 +224,23 @@ def paged_reshape_and_cache( ) else: raise NotImplementedError( - f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention" + f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported" ) + + +def get_kv_scales(weights: Weights, prefix: str) -> KVScales: + """Load KV cache scales.""" + + key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device) + value_scale = key_scale + if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor( + f"{prefix}.v_scale" + ): + key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float() + value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float() + elif weights.has_tensor(f"{prefix}.kv_scale"): + # Fall back to older more coarse-grained scale when available. + key_scale = weights.get_tensor(f"{prefix}.kv_scale").float() + value_scale = key_scale + + return KVScales(key_scale=key_scale, value_scale=value_scale) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 47bf5539c..ea11c2c26 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,7 +1,7 @@ import os from typing import Optional import torch -from text_generation_server.layers.attention.kv_cache import KVCache +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master @@ -36,6 +36,8 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + *, + kv_scales: KVScales, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py 
@@ -210,6 +212,7 @@ def attention( key: torch.Tensor, value: torch.Tensor, kv_cache: KVCache, + kv_scales: KVScales, seqlen: Seqlen, block_tables: torch.Tensor, softmax_scale: float, diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 18a40afa3..a58c7f7b2 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -26,6 +26,12 @@ def is_fbgemm_gpu_available(): return False +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + if is_fbgemm_gpu_available(): if SYSTEM == "cuda": major, _ = torch.cuda.get_device_capability() @@ -94,6 +100,17 @@ def fp8_quantize( ) return qweight, scale + if marlin_kernels is not None: + shape = weight.shape + qweight, scale = marlin_kernels.scaled_fp8_quant( + weight.reshape(-1, shape[-1]), + dtype=qdtype, + scale=scale, + scale_ub=scale_upper_bound, + ) + + return qweight.reshape(shape), scale + # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 4eee5c208..68719106f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, Seqlen, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -227,6 +228,7 @@ class FlashCohereAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: @@ -289,7 +291,12 @@ class 
FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -299,6 +306,7 @@ class FlashCohereAttention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -313,6 +321,7 @@ class FlashCohereAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 4ee677417..f70bff4f8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,6 +20,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "ipex": @@ -288,6 +289,7 @@ class DbrxAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -328,7 +330,12 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -338,6 +345,7 @@ class DbrxAttention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + 
kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -352,6 +360,7 @@ class DbrxAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 97b3ea967..906a83a41 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -34,6 +34,7 @@ from text_generation_server.layers.attention import ( attention, paged_attention, ) +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale @@ -230,6 +231,8 @@ class DeepseekV2Attention(torch.nn.Module): ), ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.kv_a_layernorm = FastRMSNorm.load( prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps ) @@ -258,7 +261,7 @@ class DeepseekV2Attention(torch.nn.Module): cos: torch.Tensor, sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_cache: KVCache, block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, @@ -319,7 +322,12 @@ class DeepseekV2Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -329,6 +337,7 @@ class 
DeepseekV2Attention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -343,6 +352,7 @@ class DeepseekV2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) # Remove padding. diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index c962a2aff..ebf1b80eb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -39,6 +39,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -206,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -251,7 +253,12 @@ class FlashGemma2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -261,6 +268,7 @@ class FlashGemma2Attention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -278,6 +286,7 @@ class FlashGemma2Attention(torch.nn.Module): seqlen, max_s, softcap=self.softcap, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git 
a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index b127f2843..ad3be80e5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -37,6 +37,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -185,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -222,7 +224,12 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -232,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -247,6 +255,7 @@ class FlashGemmaAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 2d005734b..906b34c12 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -36,6 +36,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module): head_size=self.head_size, num_heads=self.num_heads, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 2eef1dedc..692f8ca31 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -24,6 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, 
@@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module): prefix=prefix, weights=weights, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 20841aeb7..b26dd4849 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,7 +27,10 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from text_generation_server.layers.attention import KVCache +from text_generation_server.layers.attention import ( + KVCache, + get_kv_scales, +) from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( @@ -179,6 +182,8 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights, index) self.index = index + self.kv_scales = get_kv_scales(weights, f"{prefix}") + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ 
-224,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -233,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module): query=query, key=kv[:, 0], value=kv[:, 1], + kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, block_tables=block_tables, @@ -248,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 7bad429c3..c66c732f2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -26,6 +26,7 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, @@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -218,6 +225,7 @@ class 
MistralAttention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 712b7bc46..a45dd1e61 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -38,6 +38,7 @@ from text_generation_server.layers.attention import ( attention, paged_attention, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding @@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module): block_tables, seqlen, 
max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 2ce69d8ea..2301b63cf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -38,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module): head_size=self.head_size, hidden_size=self.hidden_size, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) @@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots) + kv_cache.store( + key=qkv[:, 1], + value=qkv[:, 2], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module): key=qkv[:, 1], value=qkv[:, 2], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 62d524c9b..7382a7cb9 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -18,6 +18,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") # in llama the dense layer is called "o_proj" and has bias=False self.dense = TensorParallelRowLinear.load( @@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module): query=query, key=kv[:, 0], value=kv[:, 1], + kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, block_tables=block_tables, @@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 905dd98fc..ab2a177db 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -16,6 +16,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import 
PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 8085ff892..2dcd1bf30 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,6 +12,7 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( @@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( 
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) kv_cache.store( - key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots + key=kv[:, :, 0].contiguous(), + value=kv[:, :, 1].contiguous(), + slots=slots, + kv_scales=self.kv_scales, ) # Prefill @@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module): key=kv[:, :, 0], value=kv[:, :, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 52119b64c..ed053eb66 
100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,6 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, @@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.kv_head_mapping = torch.zeros( self.num_heads, dtype=torch.int32, device=weights.device ) @@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots) + kv_cache.store( + key=key_value[:, 0], + value=key_value[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module): key=key_value[:, 0], value=key_value[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -306,6 +314,7 @@ class FlashMQAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index fe339aee7..c793982d8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -38,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm, @@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b1270b449..b931671cc 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -2283,6 +2283,7 @@ class FlashCausalLM(Model): num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, + kv_cache_dtype=self.kv_cache_dtype, dtype=self.dtype, window_left=self.sliding_window, ) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 548591e57..aae64acf3 100644 --- 
a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -207,7 +207,9 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): + def get_tensor( + self, tensor_name: str, to_device: bool = True, to_dtype: bool = True + ) -> torch.Tensor: filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name)