From a9c7d2e3b6b73c49955e7d18ad09ff5b99341b9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 9 Jan 2025 16:25:00 +0100 Subject: [PATCH] Basic flashinfer 0.2 support (#2862) * Basic flashinfer 0.2 support This change does not use any of the new features yet, but makes some small compatibility changes. * Update to flashinfer 0.2.0.post1 * flashinfer: remove `contiguous` calls * Fix flashinfer install * flashinfer: fixup kv cache dtype * Fix some annoying perturbations * More output changes --- flake.lock | 7 +- flake.nix | 2 +- ...ompressed_tensors_w8a8_int_all_params.json | 6 +- ...rs_w8a8_int_dynamic_weight_all_params.json | 36 +++--- ..._tensors_w8a8_int_dynamic_weight_load.json | 80 ++++++------ ...t_compressed_tensors_wna16_all_params.json | 2 +- .../test_flash_gemma_gptq_all_params.json | 70 +++++------ .../test_flash_starcoder2_default_params.json | 114 +++++++++--------- ...pressed_tensors_w8a8_int_dynamic_weight.py | 2 +- server/Makefile-flashinfer | 5 +- .../layers/attention/cuda.py | 6 +- .../layers/attention/flashinfer.py | 49 ++------ .../models/flash_causal_lm.py | 5 +- 13 files changed, 177 insertions(+), 207 deletions(-) diff --git a/flake.lock b/flake.lock index ec87d569..44802e18 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1732218602, - "narHash": "sha256-BElslL34KjOJCFMPkNtilOz6S/7iY7Vd72FNbRRWKDY=", + "lastModified": 1736179589, + "narHash": "sha256-/zZCSieBJncVXqOFbvbSov76g2eWAxVxEJNNA6SmQKc=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "f79638ac4e420e661321261744e745a3a747e182", + "rev": "fc7ff53b2cd5c984ad1434f20c271e3b7600d1c4", "type": "github" }, "original": { "owner": "huggingface", + "ref": "flashinfer-v0.2", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index 83cedfa6..a302db3e 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/flashinfer-v0.2"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json index 7d35e8f9..771708eb 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json @@ -32,7 +32,7 @@ }, { "id": 1101, - "logprob": -1.0947266, + "logprob": -1.0136719, "special": false, "text": " also" }, @@ -56,13 +56,13 @@ }, { "id": 4009, - "logprob": -0.15563965, + "logprob": -0.21923828, "special": false, "text": " network" }, { "id": 477, - "logprob": -1.4003906, + "logprob": -1.4824219, "special": false, "text": " or" } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json index 0db48f3e..6b3f5092 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json @@ -8,7 +8,7 @@ "tokens": [ { "id": 1939, - "logprob": -2.2675781, + "logprob": -2.2460938, "special": false, "text": "?\n\n" }, @@ -20,13 +20,13 @@ }, { "id": 20909, - "logprob": -0.37695312, + "logprob": -0.48608398, "special": false, "text": " Learning" }, { "id": 4102, - "logprob": -1.9316406, + "logprob": -2.265625, "special": false, "text": " " }, @@ -38,25 +38,13 @@ }, { "id": 458, - "logprob": -0.80859375, + "logprob": -0.6328125, "special": false, "text": " an" }, - { - "id": 3082, - "logprob": -1.4541016, - "special": false, - "text": " area" - }, - { - "id": 315, - "logprob": 0.0, - "special": false, - "text": " of" - }, { "id": 20443, - "logprob": -0.5136719, + "logprob": -0.1796875, "special": false, "text": " artificial" }, @@ -65,9 +53,21 @@ "logprob": 0.0, "special": false, "text": " intelligence" + }, + { + "id": 320, + "logprob": -0.37695312, + "special": false, + "text": " (" + }, + { + "id": 15469, + "logprob": 0.0, + "special": false, + "text": "AI" } ], "top_tokens": null }, - "generated_text": "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" + "generated_text": "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json index abcaf876..1fa4e33a 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json @@ -9,61 +9,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } @@ -82,61 +82,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } @@ -155,61 +155,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } @@ -228,61 +228,61 @@ "tokens": [ { "id": 18183, - "logprob": -1.6669922, + "logprob": -1.4912109, "special": false, "text": " Deep" }, { "id": 6832, - "logprob": -0.08959961, + "logprob": -0.075683594, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.14685059, + "logprob": -0.12408447, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.125, + "logprob": -0.12768555, "special": false, "text": " a" }, { "id": 25993, - "logprob": -0.81640625, + "logprob": -0.82128906, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.0013418198, + "logprob": -0.0012636185, "special": false, "text": " of" }, { "id": 5662, - "logprob": -0.16259766, + "logprob": -0.12878418, "special": false, "text": " machine" }, { "id": 6832, - "logprob": -0.0016393661, + "logprob": -0.0015888214, "special": false, "text": " learning" }, { "id": 429, - "logprob": -0.4477539, + "logprob": -0.49194336, "special": false, "text": " that" }, { "id": 5711, - "logprob": -1.2802734, + "logprob": -1.2626953, "special": false, "text": " uses" } diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json index 08c63e79..29709676 100644 --- a/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json @@ -44,7 +44,7 @@ }, { "id": 38397, - "logprob": -0.12695312, + "logprob": 0.0, "special": false, "text": " subset" }, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json index 6306f75e..0f54bbe8 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json @@ -14,60 +14,60 @@ }, { "id": 573, - "logprob": -0.18493652, + "logprob": -0.19030762, "special": false, "text": " the" }, { "id": 16819, - "logprob": -1.4804688, + "logprob": -1.4863281, "special": false, "text": " detection" }, { "id": 576, - "logprob": -0.7011719, + "logprob": -0.7089844, + "special": false, + "text": " of" + }, + { + "id": 573, + "logprob": -2.0410156, + "special": false, + "text": " the" + }, + { + "id": 8566, + "logprob": 0.0, + "special": false, + "text": " presence" + }, + { + "id": 689, + "logprob": -0.16491699, + "special": false, + "text": " or" + }, + { + "id": 14862, + "logprob": 0.0, + "special": false, + "text": " absence" + }, + { + "id": 576, + "logprob": -0.9970703, "special": false, "text": " of" }, { "id": 671, - "logprob": -2.1738281, + "logprob": -0.5292969, "special": false, "text": " an" - }, - { - "id": 24646, - "logprob": -3.0449219, - "special": false, - "text": " RNA" - }, - { - "id": 12369, - "logprob": -0.19299316, - "special": false, - "text": " virus" - }, - { - "id": 575, - "logprob": -0.10632324, - "special": false, - "text": " in" - }, - { - "id": 6022, - "logprob": -0.98095703, - "special": false, - "text": " patients" - }, - { - "id": 1064, - "logprob": -1.3095703, - "special": false, - "text": " who" } ], "top_tokens": null }, - "generated_text": "Test request for the detection of an RNA virus in patients who" + "generated_text": "Test request for the detection of the presence or absence of an" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json index 914e59c0..6674cf50 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -8,7 +8,7 @@ "tokens": [ { "id": 2284, - "logprob": -0.296875, + "logprob": -0.31323242, "special": false, "text": "():" }, @@ -38,13 +38,13 @@ }, { "id": 10914, - "logprob": -0.7734375, + "logprob": -0.7871094, "special": false, "text": " World" }, { "id": 16013, - "logprob": -0.61816406, + "logprob": -0.64746094, "special": false, "text": "!\")" }, @@ -62,7 +62,7 @@ }, { "id": 610, - "logprob": -0.4152832, + "logprob": -0.41064453, "special": false, "text": "def" }, @@ -92,7 +92,7 @@ }, { "id": 444, - "logprob": -0.21618652, + "logprob": -0.21655273, "special": false, "text": "name" }, @@ -139,28 +139,16 @@ "text": "Hello" }, { - "id": 925, - "logprob": -3.3476562, + "id": 332, + "logprob": -0.034698486, "special": false, - "text": " %" + "text": " \"" }, { - "id": 120, + "id": 494, "logprob": 0.0, "special": false, - "text": "s" - }, - { - "id": 11571, - "logprob": -0.08892822, - "special": false, - "text": "!\"" - }, - { - "id": 925, - "logprob": 0.0, - "special": false, - "text": " %" + "text": " +" }, { "id": 655, @@ -169,10 +157,22 @@ "text": " name" }, { - "id": 46, + "id": 494, + "logprob": -0.20141602, + "special": false, + "text": " +" + }, + { + "id": 332, "logprob": 0.0, "special": false, - "text": ")" + "text": " \"" + }, + { + "id": 16013, + "logprob": 0.0, + "special": false, + "text": "!\")" }, { "id": 222, @@ -230,7 +230,7 @@ }, { "id": 400, - "logprob": -0.074279785, + "logprob": 0.0, "special": false, "text": "age" }, @@ -289,22 +289,34 @@ "text": "Hello" }, { - "id": 925, + "id": 332, "logprob": 0.0, "special": false, - "text": " %" + "text": " \"" }, { - "id": 120, + "id": 494, "logprob": 0.0, "special": false, - "text": "s" + "text": " +" }, { - "id": 49, - "logprob": -0.07891846, + "id": 655, + "logprob": 0.0, "special": false, - "text": "," + "text": " name" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 3021, + "logprob": -0.5761719, + "special": false, + "text": " \"," }, { "id": 863, @@ -319,55 +331,43 @@ "text": " are" }, { - "id": 925, + "id": 332, "logprob": 0.0, "special": false, - "text": " %" + "text": " \"" }, { - "id": 105, + "id": 494, "logprob": 0.0, "special": false, - "text": "d" + "text": " +" }, { - "id": 11339, + "id": 615, "logprob": 0.0, "special": false, - "text": " years" + "text": " str" }, { - "id": 3627, + "id": 45, "logprob": 0.0, "special": false, - "text": " old" + "text": "(" }, { - "id": 11571, + "id": 400, "logprob": 0.0, "special": false, - "text": "!\"" + "text": "age" }, { - "id": 925, + "id": 46, "logprob": 0.0, "special": false, - "text": " %" - }, - { - "id": 327, - "logprob": 0.0, - "special": false, - "text": " (" - }, - { - "id": 444, - "logprob": 0.0, - "special": false, - "text": "name" + "text": ")" } ], "top_tokens": null }, - "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello %s!\" % name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello %s, you are %d years old!\" % (name" + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" } diff --git a/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py index 7cc82a4e..a0b0416b 100644 --- a/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py +++ b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py @@ -64,7 +64,7 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params( assert response.details.generated_tokens == 10 assert ( response.generated_text - == "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" + == "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" ) assert response == response_snapshot diff --git a/server/Makefile-flashinfer b/server/Makefile-flashinfer index f0a27622..d5f684ba 100644 --- a/server/Makefile-flashinfer +++ b/server/Makefile-flashinfer @@ -1,2 +1,5 @@ install-flashinfer: - pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 + # We need fsspec as an additional dependency, but + # `pip install flashinfer` cannot resolve it. + pip install fsspec + pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4 diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 3038602e..7b5af3c4 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -60,8 +60,7 @@ def paged_attention( from text_generation_server.layers.attention.flashinfer import decode_state return decode_state.get().forward( - # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. - query.contiguous(), + query, paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, @@ -231,8 +230,7 @@ def attention( softcap = 0.0 return prefill_with_paged_kv_state.get().forward( - # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. - query.contiguous(), + query, causal=causal, paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 26a72d9b..909eea27 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state( num_kv_heads: int, head_size: int, page_size: int, - dtype: torch.dtype, + kv_dtype: torch.dtype, + q_dtype: torch.dtype, window_left: int, ): """ @@ -91,9 +92,10 @@ def use_prefill_with_paged_kv_state( num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_size, - q_data_type=dtype, + kv_data_type=kv_dtype, + q_data_type=q_dtype, page_size=page_size, - window_left=window_left, + window_left=-1 if window_left is None else window_left, ) yield finally: @@ -113,41 +115,6 @@ def create_prefill_state( ) -@contextmanager -def use_prefill_state( - *, - state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, - cu_seqlens: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - token = prefill_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - kv_indptr=cu_seqlens, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_state.reset(token) - - def create_decode_state( *, device: torch.device, @@ -205,7 +172,7 @@ def use_decode_state( head_size: int, page_size: int, kv_cache_dtype: torch.dtype, - dtype: torch.dtype, + q_dtype: torch.dtype, window_left: int, ): """ @@ -242,8 +209,8 @@ def use_decode_state( head_dim=head_size, page_size=page_size, data_type=kv_cache_dtype, - q_data_type=dtype, - window_left=window_left, + q_data_type=q_dtype, + window_left=-1 if window_left is None else window_left, ) yield finally: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5d376990..c63ca1db 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -2480,7 +2480,8 @@ class FlashCausalLM(Model): num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, - dtype=self.dtype, + kv_dtype=self.kv_cache_dtype, + q_dtype=self.dtype, window_left=self.sliding_window, ) else: @@ -2494,6 +2495,6 @@ class FlashCausalLM(Model): head_size=self.head_size, page_size=BLOCK_SIZE, kv_cache_dtype=self.kv_cache_dtype, - dtype=self.dtype, + q_dtype=self.dtype, window_left=self.sliding_window, )