diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json
index 8829f9fe6..e322446bb 100644
--- a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json
+++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json
@@ -24,13 +24,13 @@
     "tokens": [
       {
         "id": 1736,
-        "logprob": -2.03125,
+        "logprob": -2.109375,
         "special": false,
         "text": " form"
       },
       {
         "id": 109,
-        "logprob": -1.8671875,
+        "logprob": -1.90625,
         "special": false,
         "text": "\n\n"
       },
@@ -42,48 +42,48 @@
       },
       {
         "id": 2121,
-        "logprob": -1.8125,
+        "logprob": -1.796875,
         "special": false,
         "text": " test"
       },
       {
         "id": 3853,
-        "logprob": -0.24121094,
+        "logprob": -0.24511719,
         "special": false,
         "text": " request"
       },
       {
         "id": 1736,
-        "logprob": -0.100097656,
+        "logprob": -0.09326172,
         "special": false,
         "text": " form"
       },
       {
         "id": 603,
-        "logprob": -0.9453125,
+        "logprob": -0.95703125,
         "special": false,
         "text": " is"
       },
       {
-        "id": 476,
-        "logprob": -1.703125,
+        "id": 1671,
+        "logprob": -1.5859375,
         "special": false,
-        "text": " a"
+        "text": " used"
       },
       {
-        "id": 4551,
-        "logprob": -2.453125,
+        "id": 577,
+        "logprob": -0.39257812,
         "special": false,
-        "text": " document"
+        "text": " to"
       },
       {
-        "id": 674,
-        "logprob": -0.796875,
+        "id": 3853,
+        "logprob": -1.25,
         "special": false,
-        "text": " that"
+        "text": " request"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": " form\n\nThe test request form is a document that"
+  "generated_text": " form\n\nThe test request form is used to request"
 }
diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json
index bc80a0f91..a7019a43a 100644
--- a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json
@@ -11,12 +11,12 @@
       },
       {
         "id": 2015,
-        "logprob": -9.640625,
+        "logprob": -9.6484375,
         "text": "Test"
       },
       {
         "id": 3853,
-        "logprob": -10.375,
+        "logprob": -10.3671875,
         "text": " request"
       }
     ],
@@ -24,19 +24,19 @@
     "tokens": [
       {
         "id": 604,
-        "logprob": -0.2824707,
+        "logprob": -0.28271484,
        "special": false,
         "text": " for"
       },
       {
         "id": 573,
-        "logprob": -0.19030762,
+        "logprob": -0.18493652,
         "special": false,
         "text": " the"
       },
       {
         "id": 16819,
-        "logprob": -1.4892578,
+        "logprob": -1.4804688,
         "special": false,
         "text": " detection"
       },
@@ -46,44 +46,44 @@
         "special": false,
         "text": " of"
       },
-      {
-        "id": 573,
-        "logprob": -2.0195312,
-        "special": false,
-        "text": " the"
-      },
-      {
-        "id": 8566,
-        "logprob": 0.0,
-        "special": false,
-        "text": " presence"
-      },
-      {
-        "id": 689,
-        "logprob": -0.16491699,
-        "special": false,
-        "text": " or"
-      },
-      {
-        "id": 14862,
-        "logprob": 0.0,
-        "special": false,
-        "text": " absence"
-      },
-      {
-        "id": 576,
-        "logprob": -0.9946289,
-        "special": false,
-        "text": " of"
-      },
       {
         "id": 671,
-        "logprob": -0.5263672,
+        "logprob": -2.1738281,
         "special": false,
         "text": " an"
+      },
+      {
+        "id": 24646,
+        "logprob": -3.0449219,
+        "special": false,
+        "text": " RNA"
+      },
+      {
+        "id": 12369,
+        "logprob": -0.19299316,
+        "special": false,
+        "text": " virus"
+      },
+      {
+        "id": 575,
+        "logprob": -0.10632324,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 6022,
+        "logprob": -0.98095703,
+        "special": false,
+        "text": " patients"
+      },
+      {
+        "id": 1064,
+        "logprob": -1.3095703,
+        "special": false,
+        "text": " who"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "Test request for the detection of the presence or absence of an"
+  "generated_text": "Test request for the detection of an RNA virus in patients who"
 }
diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py
index e1ef62c5b..5f8954ea9 100644
--- a/server/text_generation_server/layers/attention/flashinfer.py
+++ b/server/text_generation_server/layers/attention/flashinfer.py
@@ -152,11 +152,13 @@ def create_decode_state(
 ):
     """Create a decode state."""
     workspace_buffer = get_workspace(device)
+    num_groups = num_heads // num_kv_heads
     return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout="NHD",
         use_cuda_graph=False,
-        use_tensor_cores=num_heads // num_kv_heads > 4,
+        # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
+        use_tensor_cores=num_groups not in [1, 2, 4, 8],
     )


@@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
     therefore stored as part of the state.
     """
     workspace_buffer = get_workspace(device)
+    num_groups = num_heads // num_kv_heads
     return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout="NHD",
@@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
         paged_kv_indices_buffer=block_tables,
         paged_kv_indptr_buffer=block_tables_ptr,
         paged_kv_last_page_len_buffer=last_page_len,
-        use_tensor_cores=num_heads // num_kv_heads > 4,
+        # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
+        use_tensor_cores=num_groups not in [1, 2, 4, 8],
     )