Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
Basic flashinfer 0.2 support (#2862)
* Basic flashinfer 0.2 support. This change does not use any of the new features yet, but makes some small compatibility changes.
* Update to flashinfer 0.2.0.post1
* flashinfer: remove `contiguous` calls
* Fix flashinfer install
* flashinfer: fixup kv cache dtype
* Fix some annoying perturbations
* More output changes
This commit is contained in: parent afb6c728d8, commit a9c7d2e3b6
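The compatibility changes below center on how the flashinfer prefill/decode wrappers are planned: the query dtype and the KV-cache dtype are now passed separately, and `window_left` is normalized to `-1` when no sliding window is configured. The following is a minimal sketch of that keyword pattern, not code from this commit; the helper name and the example values are illustrative, and the keyword names mirror the hunks below rather than the full flashinfer API.

from typing import Optional

import torch


def flashinfer_plan_kwargs(
    *,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    q_dtype: torch.dtype,
    kv_dtype: torch.dtype,
    window_left: Optional[int],
) -> dict:
    """Collect the keyword arguments that the begin_forward calls in this
    commit pass to the flashinfer 0.2 prefill/decode wrappers (sketch)."""
    return {
        "num_qo_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_size,
        # The query dtype and KV-cache dtype are passed separately, so for
        # example bf16 queries can be planned over a differently-typed KV cache.
        "q_data_type": q_dtype,
        "kv_data_type": kv_dtype,
        "page_size": page_size,
        # The diff normalizes a missing sliding window to -1 rather than None.
        "window_left": -1 if window_left is None else window_left,
    }


# Illustrative values only: bf16 queries over an fp8 KV cache, no sliding window.
example = flashinfer_plan_kwargs(
    num_heads=32,
    num_kv_heads=8,
    head_size=128,
    page_size=16,
    q_dtype=torch.bfloat16,
    kv_dtype=torch.float8_e4m3fn,
    window_left=None,
)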
@@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
- "lastModified": 1732218602,
- "narHash": "sha256-BElslL34KjOJCFMPkNtilOz6S/7iY7Vd72FNbRRWKDY=",
+ "lastModified": 1736179589,
+ "narHash": "sha256-/zZCSieBJncVXqOFbvbSov76g2eWAxVxEJNNA6SmQKc=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
- "rev": "f79638ac4e420e661321261744e745a3a747e182",
+ "rev": "fc7ff53b2cd5c984ad1434f20c271e3b7600d1c4",
"type": "github"
},
"original": {
"owner": "huggingface",
+ "ref": "flashinfer-v0.2",
"repo": "text-generation-inference-nix",
"type": "github"
}
@@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
- tgi-nix.url = "github:huggingface/text-generation-inference-nix";
+ tgi-nix.url = "github:huggingface/text-generation-inference-nix/flashinfer-v0.2";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
@@ -32,7 +32,7 @@
},
{
"id": 1101,
- "logprob": -1.0947266,
+ "logprob": -1.0136719,
"special": false,
"text": " also"
},
@@ -56,13 +56,13 @@
},
{
"id": 4009,
- "logprob": -0.15563965,
+ "logprob": -0.21923828,
"special": false,
"text": " network"
},
{
"id": 477,
- "logprob": -1.4003906,
+ "logprob": -1.4824219,
"special": false,
"text": " or"
}
@@ -8,7 +8,7 @@
"tokens": [
{
"id": 1939,
- "logprob": -2.2675781,
+ "logprob": -2.2460938,
"special": false,
"text": "?\n\n"
},
@@ -20,13 +20,13 @@
},
{
"id": 20909,
- "logprob": -0.37695312,
+ "logprob": -0.48608398,
"special": false,
"text": " Learning"
},
{
"id": 4102,
- "logprob": -1.9316406,
+ "logprob": -2.265625,
"special": false,
"text": " "
},
@@ -38,25 +38,13 @@
},
{
"id": 458,
- "logprob": -0.80859375,
+ "logprob": -0.6328125,
"special": false,
"text": " an"
},
{
- "id": 3082,
- "logprob": -1.4541016,
- "special": false,
- "text": " area"
- },
- {
- "id": 315,
- "logprob": 0.0,
- "special": false,
- "text": " of"
- },
- {
"id": 20443,
- "logprob": -0.5136719,
+ "logprob": -0.1796875,
"special": false,
"text": " artificial"
},
@@ -65,9 +53,21 @@
"logprob": 0.0,
"special": false,
"text": " intelligence"
+ },
+ {
+ "id": 320,
+ "logprob": -0.37695312,
+ "special": false,
+ "text": " ("
+ },
+ {
+ "id": 15469,
+ "logprob": 0.0,
+ "special": false,
+ "text": "AI"
}
],
"top_tokens": null
},
- "generated_text": "What is deep learning?\n\nDeep Learning is an area of artificial intelligence"
+ "generated_text": "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI"
}
@@ -9,61 +9,61 @@
"tokens": [
{
"id": 18183,
- "logprob": -1.6669922,
+ "logprob": -1.4912109,
"special": false,
"text": " Deep"
},
{
"id": 6832,
- "logprob": -0.08959961,
+ "logprob": -0.075683594,
"special": false,
"text": " learning"
},
{
"id": 374,
- "logprob": -0.14685059,
+ "logprob": -0.12408447,
"special": false,
"text": " is"
},
{
"id": 264,
- "logprob": -0.125,
+ "logprob": -0.12768555,
"special": false,
"text": " a"
},
{
"id": 25993,
- "logprob": -0.81640625,
+ "logprob": -0.82128906,
"special": false,
"text": " subset"
},
{
"id": 315,
- "logprob": -0.0013418198,
+ "logprob": -0.0012636185,
"special": false,
"text": " of"
},
{
"id": 5662,
- "logprob": -0.16259766,
+ "logprob": -0.12878418,
"special": false,
"text": " machine"
},
{
"id": 6832,
- "logprob": -0.0016393661,
+ "logprob": -0.0015888214,
"special": false,
"text": " learning"
},
{
"id": 429,
- "logprob": -0.4477539,
+ "logprob": -0.49194336,
"special": false,
"text": " that"
},
{
"id": 5711,
- "logprob": -1.2802734,
+ "logprob": -1.2626953,
"special": false,
"text": " uses"
}
@@ -82,61 +82,61 @@
"tokens": [
{
"id": 18183,
- "logprob": -1.6669922,
+ "logprob": -1.4912109,
"special": false,
"text": " Deep"
},
{
"id": 6832,
- "logprob": -0.08959961,
+ "logprob": -0.075683594,
"special": false,
"text": " learning"
},
{
"id": 374,
- "logprob": -0.14685059,
+ "logprob": -0.12408447,
"special": false,
"text": " is"
},
{
"id": 264,
- "logprob": -0.125,
+ "logprob": -0.12768555,
"special": false,
"text": " a"
},
{
"id": 25993,
- "logprob": -0.81640625,
+ "logprob": -0.82128906,
"special": false,
"text": " subset"
},
{
"id": 315,
- "logprob": -0.0013418198,
+ "logprob": -0.0012636185,
"special": false,
"text": " of"
},
{
"id": 5662,
- "logprob": -0.16259766,
+ "logprob": -0.12878418,
"special": false,
"text": " machine"
},
{
"id": 6832,
- "logprob": -0.0016393661,
+ "logprob": -0.0015888214,
"special": false,
"text": " learning"
},
{
"id": 429,
- "logprob": -0.4477539,
+ "logprob": -0.49194336,
"special": false,
"text": " that"
},
{
"id": 5711,
- "logprob": -1.2802734,
+ "logprob": -1.2626953,
"special": false,
"text": " uses"
}
@@ -155,61 +155,61 @@
"tokens": [
{
"id": 18183,
- "logprob": -1.6669922,
+ "logprob": -1.4912109,
"special": false,
"text": " Deep"
},
{
"id": 6832,
- "logprob": -0.08959961,
+ "logprob": -0.075683594,
"special": false,
"text": " learning"
},
{
"id": 374,
- "logprob": -0.14685059,
+ "logprob": -0.12408447,
"special": false,
"text": " is"
},
{
"id": 264,
- "logprob": -0.125,
+ "logprob": -0.12768555,
"special": false,
"text": " a"
},
{
"id": 25993,
- "logprob": -0.81640625,
+ "logprob": -0.82128906,
"special": false,
"text": " subset"
},
{
"id": 315,
- "logprob": -0.0013418198,
+ "logprob": -0.0012636185,
"special": false,
"text": " of"
},
{
"id": 5662,
- "logprob": -0.16259766,
+ "logprob": -0.12878418,
"special": false,
"text": " machine"
},
{
"id": 6832,
- "logprob": -0.0016393661,
+ "logprob": -0.0015888214,
"special": false,
"text": " learning"
},
{
"id": 429,
- "logprob": -0.4477539,
+ "logprob": -0.49194336,
"special": false,
"text": " that"
},
{
"id": 5711,
- "logprob": -1.2802734,
+ "logprob": -1.2626953,
"special": false,
"text": " uses"
}
@@ -228,61 +228,61 @@
"tokens": [
{
"id": 18183,
- "logprob": -1.6669922,
+ "logprob": -1.4912109,
"special": false,
"text": " Deep"
},
{
"id": 6832,
- "logprob": -0.08959961,
+ "logprob": -0.075683594,
"special": false,
"text": " learning"
},
{
"id": 374,
- "logprob": -0.14685059,
+ "logprob": -0.12408447,
"special": false,
"text": " is"
},
{
"id": 264,
- "logprob": -0.125,
+ "logprob": -0.12768555,
"special": false,
"text": " a"
},
{
"id": 25993,
- "logprob": -0.81640625,
+ "logprob": -0.82128906,
"special": false,
"text": " subset"
},
{
"id": 315,
- "logprob": -0.0013418198,
+ "logprob": -0.0012636185,
"special": false,
"text": " of"
},
{
"id": 5662,
- "logprob": -0.16259766,
+ "logprob": -0.12878418,
"special": false,
"text": " machine"
},
{
"id": 6832,
- "logprob": -0.0016393661,
+ "logprob": -0.0015888214,
"special": false,
"text": " learning"
},
{
"id": 429,
- "logprob": -0.4477539,
+ "logprob": -0.49194336,
"special": false,
"text": " that"
},
{
"id": 5711,
- "logprob": -1.2802734,
+ "logprob": -1.2626953,
"special": false,
"text": " uses"
}
@@ -44,7 +44,7 @@
},
{
"id": 38397,
- "logprob": -0.12695312,
+ "logprob": 0.0,
"special": false,
"text": " subset"
},
@@ -14,60 +14,60 @@
},
{
"id": 573,
- "logprob": -0.18493652,
+ "logprob": -0.19030762,
"special": false,
"text": " the"
},
{
"id": 16819,
- "logprob": -1.4804688,
+ "logprob": -1.4863281,
"special": false,
"text": " detection"
},
{
"id": 576,
- "logprob": -0.7011719,
+ "logprob": -0.7089844,
"special": false,
"text": " of"
},
{
+ "id": 573,
+ "logprob": -2.0410156,
+ "special": false,
+ "text": " the"
+ },
+ {
+ "id": 8566,
+ "logprob": 0.0,
+ "special": false,
+ "text": " presence"
+ },
+ {
+ "id": 689,
+ "logprob": -0.16491699,
+ "special": false,
+ "text": " or"
+ },
+ {
+ "id": 14862,
+ "logprob": 0.0,
+ "special": false,
+ "text": " absence"
+ },
+ {
+ "id": 576,
+ "logprob": -0.9970703,
+ "special": false,
+ "text": " of"
+ },
+ {
"id": 671,
- "logprob": -2.1738281,
+ "logprob": -0.5292969,
"special": false,
"text": " an"
- },
- {
- "id": 24646,
- "logprob": -3.0449219,
- "special": false,
- "text": " RNA"
- },
- {
- "id": 12369,
- "logprob": -0.19299316,
- "special": false,
- "text": " virus"
- },
- {
- "id": 575,
- "logprob": -0.10632324,
- "special": false,
- "text": " in"
- },
- {
- "id": 6022,
- "logprob": -0.98095703,
- "special": false,
- "text": " patients"
- },
- {
- "id": 1064,
- "logprob": -1.3095703,
- "special": false,
- "text": " who"
}
],
"top_tokens": null
},
- "generated_text": "Test request for the detection of an RNA virus in patients who"
+ "generated_text": "Test request for the detection of the presence or absence of an"
}
@@ -8,7 +8,7 @@
"tokens": [
{
"id": 2284,
- "logprob": -0.296875,
+ "logprob": -0.31323242,
"special": false,
"text": "():"
},
@@ -38,13 +38,13 @@
},
{
"id": 10914,
- "logprob": -0.7734375,
+ "logprob": -0.7871094,
"special": false,
"text": " World"
},
{
"id": 16013,
- "logprob": -0.61816406,
+ "logprob": -0.64746094,
"special": false,
"text": "!\")"
},
@@ -62,7 +62,7 @@
},
{
"id": 610,
- "logprob": -0.4152832,
+ "logprob": -0.41064453,
"special": false,
"text": "def"
},
@@ -92,7 +92,7 @@
},
{
"id": 444,
- "logprob": -0.21618652,
+ "logprob": -0.21655273,
"special": false,
"text": "name"
},
@@ -139,28 +139,16 @@
"text": "Hello"
},
{
- "id": 925,
- "logprob": -3.3476562,
+ "id": 332,
+ "logprob": -0.034698486,
"special": false,
- "text": " %"
+ "text": " \""
},
{
- "id": 120,
+ "id": 494,
"logprob": 0.0,
"special": false,
- "text": "s"
- },
- {
- "id": 11571,
- "logprob": -0.08892822,
- "special": false,
- "text": "!\""
- },
- {
- "id": 925,
- "logprob": 0.0,
- "special": false,
- "text": " %"
+ "text": " +"
},
{
"id": 655,
@@ -169,10 +157,22 @@
"text": " name"
},
{
- "id": 46,
+ "id": 494,
"logprob": -0.20141602,
"special": false,
+ "text": " +"
+ },
+ {
+ "id": 332,
+ "logprob": 0.0,
+ "special": false,
- "text": ")"
+ "text": " \""
},
{
+ "id": 16013,
+ "logprob": 0.0,
+ "special": false,
+ "text": "!\")"
+ },
+ {
"id": 222,
@@ -230,7 +230,7 @@
},
{
"id": 400,
- "logprob": -0.074279785,
+ "logprob": 0.0,
"special": false,
"text": "age"
},
@@ -289,22 +289,34 @@
"text": "Hello"
},
{
- "id": 925,
+ "id": 332,
"logprob": 0.0,
"special": false,
- "text": " %"
+ "text": " \""
},
{
- "id": 120,
+ "id": 494,
"logprob": 0.0,
"special": false,
- "text": "s"
+ "text": " +"
},
{
- "id": 49,
- "logprob": -0.07891846,
+ "id": 655,
+ "logprob": 0.0,
"special": false,
- "text": ","
+ "text": " name"
},
{
+ "id": 494,
+ "logprob": 0.0,
+ "special": false,
+ "text": " +"
+ },
+ {
+ "id": 3021,
+ "logprob": -0.5761719,
+ "special": false,
+ "text": " \","
+ },
+ {
"id": 863,
@@ -319,55 +319,43 @@
"text": " are"
},
{
- "id": 925,
+ "id": 332,
"logprob": 0.0,
"special": false,
- "text": " %"
+ "text": " \""
},
{
- "id": 105,
+ "id": 494,
"logprob": 0.0,
"special": false,
- "text": "d"
+ "text": " +"
},
{
- "id": 11339,
+ "id": 615,
"logprob": 0.0,
"special": false,
- "text": " years"
+ "text": " str"
},
{
- "id": 3627,
+ "id": 45,
"logprob": 0.0,
"special": false,
- "text": " old"
+ "text": "("
},
{
- "id": 11571,
+ "id": 400,
"logprob": 0.0,
"special": false,
- "text": "!\""
+ "text": "age"
},
{
- "id": 925,
+ "id": 46,
"logprob": 0.0,
"special": false,
- "text": " %"
- },
- {
- "id": 327,
- "logprob": 0.0,
- "special": false,
- "text": " ("
- },
- {
- "id": 444,
- "logprob": 0.0,
- "special": false,
- "text": "name"
+ "text": ")"
}
],
"top_tokens": null
},
- "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello %s!\" % name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello %s, you are %d years old!\" % (name"
+ "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
}
@@ -64,7 +64,7 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params(
assert response.details.generated_tokens == 10
assert (
response.generated_text
- == "What is deep learning?\n\nDeep Learning is an area of artificial intelligence"
+ == "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI"
)
assert response == response_snapshot
@@ -1,2 +1,5 @@
install-flashinfer:
- pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4
+ # We need fsspec as an additional dependency, but
+ # `pip install flashinfer` cannot resolve it.
+ pip install fsspec
+ pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4
@@ -60,8 +60,7 @@ def paged_attention(
from text_generation_server.layers.attention.flashinfer import decode_state

return decode_state.get().forward(
- # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
- query.contiguous(),
+ query,
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
@@ -231,8 +230,7 @@ def attention(
softcap = 0.0

return prefill_with_paged_kv_state.get().forward(
- # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
- query.contiguous(),
+ query,
causal=causal,
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
@@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
num_kv_heads: int,
head_size: int,
page_size: int,
- dtype: torch.dtype,
+ kv_dtype: torch.dtype,
+ q_dtype: torch.dtype,
window_left: int,
):
"""
@@ -91,9 +92,10 @@ def use_prefill_with_paged_kv_state(
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
- q_data_type=dtype,
+ kv_data_type=kv_dtype,
+ q_data_type=q_dtype,
page_size=page_size,
- window_left=window_left,
+ window_left=-1 if window_left is None else window_left,
)
yield
finally:
@@ -113,41 +115,6 @@ def create_prefill_state(
)


- @contextmanager
- def use_prefill_state(
- *,
- state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
- cu_seqlens: torch.Tensor,
- num_heads: int,
- num_kv_heads: int,
- head_size: int,
- dtype: torch.dtype,
- window_left: int,
- ):
- """
- Context manager to set the active flashinfer prefill state to the given
- `state` and parameters. This state will be used by all calls to the
- `attention` function while the context manager is active.
- """
-
- token = prefill_state.set(state)
- try:
- state.begin_forward(
- qo_indptr=cu_seqlens,
- kv_indptr=cu_seqlens,
- num_qo_heads=num_heads,
- num_kv_heads=num_kv_heads,
- head_dim=head_size,
- q_data_type=dtype,
- window_left=window_left,
- )
- yield
- finally:
- state.end_forward()
- if token is not None:
- prefill_state.reset(token)
-
-
def create_decode_state(
*,
device: torch.device,
@@ -205,7 +172,7 @@ def use_decode_state(
head_size: int,
page_size: int,
kv_cache_dtype: torch.dtype,
- dtype: torch.dtype,
+ q_dtype: torch.dtype,
window_left: int,
):
"""
@@ -242,8 +209,8 @@ def use_decode_state(
head_dim=head_size,
page_size=page_size,
data_type=kv_cache_dtype,
- q_data_type=dtype,
- window_left=window_left,
+ q_data_type=q_dtype,
+ window_left=-1 if window_left is None else window_left,
)
yield
finally:
@@ -2480,7 +2480,8 @@ class FlashCausalLM(Model):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
- dtype=self.dtype,
+ kv_dtype=self.kv_cache_dtype,
+ q_dtype=self.dtype,
window_left=self.sliding_window,
)
else:
@@ -2494,6 +2495,6 @@ class FlashCausalLM(Model):
head_size=self.head_size,
page_size=BLOCK_SIZE,
kv_cache_dtype=self.kv_cache_dtype,
- dtype=self.dtype,
+ q_dtype=self.dtype,
window_left=self.sliding_window,
)