Basic flashinfer 0.2 support (#2862)

* Basic flashinfer 0.2 support

This change does not use any of the new features yet, but makes
some small compatibility changes.

* Update to flashinfer 0.2.0.post1

* flashinfer: remove `contiguous` calls

* Fix flashinfer install

* flashinfer: fixup kv cache dtype

* Fix some annoying perturbations

* More output changes
This commit is contained in:
Daniël de Kok 2025-01-09 16:25:00 +01:00 committed by GitHub
parent afb6c728d8
commit a9c7d2e3b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 177 additions and 207 deletions

View File

@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1732218602, "lastModified": 1736179589,
"narHash": "sha256-BElslL34KjOJCFMPkNtilOz6S/7iY7Vd72FNbRRWKDY=", "narHash": "sha256-/zZCSieBJncVXqOFbvbSov76g2eWAxVxEJNNA6SmQKc=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "f79638ac4e420e661321261744e745a3a747e182", "rev": "fc7ff53b2cd5c984ad1434f20c271e3b7600d1c4",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "huggingface", "owner": "huggingface",
"ref": "flashinfer-v0.2",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"type": "github" "type": "github"
} }

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
}; };
nix-filter.url = "github:numtide/nix-filter"; nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix"; tgi-nix.url = "github:huggingface/text-generation-inference-nix/flashinfer-v0.2";
nixpkgs.follows = "tgi-nix/nixpkgs"; nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
rust-overlay = { rust-overlay = {

View File

@ -32,7 +32,7 @@
}, },
{ {
"id": 1101, "id": 1101,
"logprob": -1.0947266, "logprob": -1.0136719,
"special": false, "special": false,
"text": " also" "text": " also"
}, },
@ -56,13 +56,13 @@
}, },
{ {
"id": 4009, "id": 4009,
"logprob": -0.15563965, "logprob": -0.21923828,
"special": false, "special": false,
"text": " network" "text": " network"
}, },
{ {
"id": 477, "id": 477,
"logprob": -1.4003906, "logprob": -1.4824219,
"special": false, "special": false,
"text": " or" "text": " or"
} }

View File

@ -8,7 +8,7 @@
"tokens": [ "tokens": [
{ {
"id": 1939, "id": 1939,
"logprob": -2.2675781, "logprob": -2.2460938,
"special": false, "special": false,
"text": "?\n\n" "text": "?\n\n"
}, },
@ -20,13 +20,13 @@
}, },
{ {
"id": 20909, "id": 20909,
"logprob": -0.37695312, "logprob": -0.48608398,
"special": false, "special": false,
"text": " Learning" "text": " Learning"
}, },
{ {
"id": 4102, "id": 4102,
"logprob": -1.9316406, "logprob": -2.265625,
"special": false, "special": false,
"text": " " "text": " "
}, },
@ -38,25 +38,13 @@
}, },
{ {
"id": 458, "id": 458,
"logprob": -0.80859375, "logprob": -0.6328125,
"special": false, "special": false,
"text": " an" "text": " an"
}, },
{
"id": 3082,
"logprob": -1.4541016,
"special": false,
"text": " area"
},
{
"id": 315,
"logprob": 0.0,
"special": false,
"text": " of"
},
{ {
"id": 20443, "id": 20443,
"logprob": -0.5136719, "logprob": -0.1796875,
"special": false, "special": false,
"text": " artificial" "text": " artificial"
}, },
@ -65,9 +53,21 @@
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " intelligence" "text": " intelligence"
},
{
"id": 320,
"logprob": -0.37695312,
"special": false,
"text": " ("
},
{
"id": 15469,
"logprob": 0.0,
"special": false,
"text": "AI"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" "generated_text": "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI"
} }

View File

@ -9,61 +9,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -82,61 +82,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -155,61 +155,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -228,61 +228,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.4912109,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.075683594,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.12408447,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.12768555,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.82128906,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0012636185,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16259766, "logprob": -0.12878418,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0015888214,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.49194336,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2626953,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }

View File

@ -44,7 +44,7 @@
}, },
{ {
"id": 38397, "id": 38397,
"logprob": -0.12695312, "logprob": 0.0,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },

View File

@ -14,60 +14,60 @@
}, },
{ {
"id": 573, "id": 573,
"logprob": -0.18493652, "logprob": -0.19030762,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 16819, "id": 16819,
"logprob": -1.4804688, "logprob": -1.4863281,
"special": false, "special": false,
"text": " detection" "text": " detection"
}, },
{ {
"id": 576, "id": 576,
"logprob": -0.7011719, "logprob": -0.7089844,
"special": false,
"text": " of"
},
{
"id": 573,
"logprob": -2.0410156,
"special": false,
"text": " the"
},
{
"id": 8566,
"logprob": 0.0,
"special": false,
"text": " presence"
},
{
"id": 689,
"logprob": -0.16491699,
"special": false,
"text": " or"
},
{
"id": 14862,
"logprob": 0.0,
"special": false,
"text": " absence"
},
{
"id": 576,
"logprob": -0.9970703,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 671, "id": 671,
"logprob": -2.1738281, "logprob": -0.5292969,
"special": false, "special": false,
"text": " an" "text": " an"
},
{
"id": 24646,
"logprob": -3.0449219,
"special": false,
"text": " RNA"
},
{
"id": 12369,
"logprob": -0.19299316,
"special": false,
"text": " virus"
},
{
"id": 575,
"logprob": -0.10632324,
"special": false,
"text": " in"
},
{
"id": 6022,
"logprob": -0.98095703,
"special": false,
"text": " patients"
},
{
"id": 1064,
"logprob": -1.3095703,
"special": false,
"text": " who"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request for the detection of an RNA virus in patients who" "generated_text": "Test request for the detection of the presence or absence of an"
} }

View File

@ -8,7 +8,7 @@
"tokens": [ "tokens": [
{ {
"id": 2284, "id": 2284,
"logprob": -0.296875, "logprob": -0.31323242,
"special": false, "special": false,
"text": "():" "text": "():"
}, },
@ -38,13 +38,13 @@
}, },
{ {
"id": 10914, "id": 10914,
"logprob": -0.7734375, "logprob": -0.7871094,
"special": false, "special": false,
"text": " World" "text": " World"
}, },
{ {
"id": 16013, "id": 16013,
"logprob": -0.61816406, "logprob": -0.64746094,
"special": false, "special": false,
"text": "!\")" "text": "!\")"
}, },
@ -62,7 +62,7 @@
}, },
{ {
"id": 610, "id": 610,
"logprob": -0.4152832, "logprob": -0.41064453,
"special": false, "special": false,
"text": "def" "text": "def"
}, },
@ -92,7 +92,7 @@
}, },
{ {
"id": 444, "id": 444,
"logprob": -0.21618652, "logprob": -0.21655273,
"special": false, "special": false,
"text": "name" "text": "name"
}, },
@ -139,28 +139,16 @@
"text": "Hello" "text": "Hello"
}, },
{ {
"id": 925, "id": 332,
"logprob": -3.3476562, "logprob": -0.034698486,
"special": false, "special": false,
"text": " %" "text": " \""
}, },
{ {
"id": 120, "id": 494,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "s" "text": " +"
},
{
"id": 11571,
"logprob": -0.08892822,
"special": false,
"text": "!\""
},
{
"id": 925,
"logprob": 0.0,
"special": false,
"text": " %"
}, },
{ {
"id": 655, "id": 655,
@ -169,10 +157,22 @@
"text": " name" "text": " name"
}, },
{ {
"id": 46, "id": 494,
"logprob": -0.20141602,
"special": false,
"text": " +"
},
{
"id": 332,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": ")" "text": " \""
},
{
"id": 16013,
"logprob": 0.0,
"special": false,
"text": "!\")"
}, },
{ {
"id": 222, "id": 222,
@ -230,7 +230,7 @@
}, },
{ {
"id": 400, "id": 400,
"logprob": -0.074279785, "logprob": 0.0,
"special": false, "special": false,
"text": "age" "text": "age"
}, },
@ -289,22 +289,34 @@
"text": "Hello" "text": "Hello"
}, },
{ {
"id": 925, "id": 332,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " %" "text": " \""
}, },
{ {
"id": 120, "id": 494,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "s" "text": " +"
}, },
{ {
"id": 49, "id": 655,
"logprob": -0.07891846, "logprob": 0.0,
"special": false, "special": false,
"text": "," "text": " name"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 3021,
"logprob": -0.5761719,
"special": false,
"text": " \","
}, },
{ {
"id": 863, "id": 863,
@ -319,55 +331,43 @@
"text": " are" "text": " are"
}, },
{ {
"id": 925, "id": 332,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " %" "text": " \""
}, },
{ {
"id": 105, "id": 494,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "d" "text": " +"
}, },
{ {
"id": 11339, "id": 615,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " years" "text": " str"
}, },
{ {
"id": 3627, "id": 45,
"logprob": 0.0,
"special": false,
"text": " old"
},
{
"id": 11571,
"logprob": 0.0,
"special": false,
"text": "!\""
},
{
"id": 925,
"logprob": 0.0,
"special": false,
"text": " %"
},
{
"id": 327,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "(" "text": "("
}, },
{ {
"id": 444, "id": 400,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "name" "text": "age"
},
{
"id": 46,
"logprob": 0.0,
"special": false,
"text": ")"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello %s!\" % name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello %s, you are %d years old!\" % (name" "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
} }

View File

@ -64,7 +64,7 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params(
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert ( assert (
response.generated_text response.generated_text
== "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" == "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI"
) )
assert response == response_snapshot assert response == response_snapshot

View File

@ -1,2 +1,5 @@
install-flashinfer: install-flashinfer:
pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 # We need fsspec as an additional dependency, but
# `pip install flashinfer` cannot resolve it.
pip install fsspec
pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4

View File

@ -60,8 +60,7 @@ def paged_attention(
from text_generation_server.layers.attention.flashinfer import decode_state from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward( return decode_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. query,
query.contiguous(),
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
@ -231,8 +230,7 @@ def attention(
softcap = 0.0 softcap = 0.0
return prefill_with_paged_kv_state.get().forward( return prefill_with_paged_kv_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. query,
query.contiguous(),
causal=causal, causal=causal,
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,

View File

@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
page_size: int, page_size: int,
dtype: torch.dtype, kv_dtype: torch.dtype,
q_dtype: torch.dtype,
window_left: int, window_left: int,
): ):
""" """
@ -91,9 +92,10 @@ def use_prefill_with_paged_kv_state(
num_qo_heads=num_heads, num_qo_heads=num_heads,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
head_dim=head_size, head_dim=head_size,
q_data_type=dtype, kv_data_type=kv_dtype,
q_data_type=q_dtype,
page_size=page_size, page_size=page_size,
window_left=window_left, window_left=-1 if window_left is None else window_left,
) )
yield yield
finally: finally:
@ -113,41 +115,6 @@ def create_prefill_state(
) )
@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    window_left: int,
):
    """
    Activate `state` as the current flashinfer prefill state for the body of
    the `with` block.

    While the context manager is active, calls to the `attention` function
    use this state together with the parameters given here. On exit,
    `end_forward` is called on the state and the previously active prefill
    state is restored, whether or not the body raised.
    """
    # The same cumulative sequence lengths are passed as both the
    # query/output and the key/value index pointers.
    forward_kwargs = dict(
        qo_indptr=cu_seqlens,
        kv_indptr=cu_seqlens,
        num_qo_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_size,
        q_data_type=dtype,
        window_left=window_left,
    )

    # Keep the value returned by `set` so the previous state can be restored
    # during cleanup.
    reset_token = prefill_state.set(state)
    try:
        state.begin_forward(**forward_kwargs)
        yield
    finally:
        state.end_forward()
        if reset_token is not None:
            prefill_state.reset(reset_token)
def create_decode_state( def create_decode_state(
*, *,
device: torch.device, device: torch.device,
@ -205,7 +172,7 @@ def use_decode_state(
head_size: int, head_size: int,
page_size: int, page_size: int,
kv_cache_dtype: torch.dtype, kv_cache_dtype: torch.dtype,
dtype: torch.dtype, q_dtype: torch.dtype,
window_left: int, window_left: int,
): ):
""" """
@ -242,8 +209,8 @@ def use_decode_state(
head_dim=head_size, head_dim=head_size,
page_size=page_size, page_size=page_size,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
q_data_type=dtype, q_data_type=q_dtype,
window_left=window_left, window_left=-1 if window_left is None else window_left,
) )
yield yield
finally: finally:

View File

@ -2480,7 +2480,8 @@ class FlashCausalLM(Model):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
page_size=BLOCK_SIZE, page_size=BLOCK_SIZE,
dtype=self.dtype, kv_dtype=self.kv_cache_dtype,
q_dtype=self.dtype,
window_left=self.sliding_window, window_left=self.sliding_window,
) )
else: else:
@ -2494,6 +2495,6 @@ class FlashCausalLM(Model):
head_size=self.head_size, head_size=self.head_size,
page_size=BLOCK_SIZE, page_size=BLOCK_SIZE,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
dtype=self.dtype, q_dtype=self.dtype,
window_left=self.sliding_window, window_left=self.sliding_window,
) )