From e7e07342bd2315ecab1968bcd0f1d0fb298f3ddb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 1 Dec 2023 18:49:01 +0000 Subject: [PATCH] Working state except all params ?? --- .../test_flash_medusa_all_params.json | 90 ++- .../test_flash_medusa_load.json | 604 +++++++++++------- .../test_flash_medusa_simple.json | 151 +++-- integration-tests/models/test_flash_medusa.py | 10 +- .../models/flash_causal_lm.py | 4 +- server/text_generation_server/utils/tokens.py | 3 - 6 files changed, 518 insertions(+), 344 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json index e9b1c57a..ad4c6c30 100644 --- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "stop_sequence", - "generated_tokens": 5, + "finish_reason": "length", + "generated_tokens": 10, "prefill": [ { "id": 1, @@ -10,49 +10,89 @@ "text": "" }, { - "id": 4321, - "logprob": -10.0625, - "text": "Test" + "id": 338, + "logprob": -10.0078125, + "text": "is" }, { - "id": 2009, - "logprob": -12.28125, - "text": "request" + "id": 21784, + "logprob": -15.515625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -2.8847656, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -4.140625, + "text": "?" } ], "seed": 0, "tokens": [ { - "id": 5229, - "logprob": -1.7587891, + "id": 13, + "logprob": -1.1582031, "special": false, - "text": " failed" + "text": "\n" }, { - "id": 363, - "logprob": -0.5175781, - "special": false, - "text": " for" - }, - { - "id": 1404, + "id": 2772, "logprob": 0.0, "special": false, - "text": " user" + "text": "De" }, { - "id": 376, + "id": 1022, "logprob": 0.0, "special": false, - "text": " \"" + "text": "ep" }, { - "id": 1688, - "logprob": -0.20422363, + "id": 6509, + "logprob": 0.0, "special": false, - "text": "test" + "text": " learning" + }, + { + "id": 313, + "logprob": -1.0712891, + "special": false, + "text": " (" + }, + { + "id": 15189, + "logprob": -0.7578125, + "special": false, + "text": "also" + }, + { + "id": 2998, + "logprob": 0.0, + "special": false, + "text": " known" + }, + { + "id": 408, + "logprob": 0.0, + "special": false, + "text": " as" + }, + { + "id": 6483, + "logprob": 0.0, + "special": false, + "text": " deep" + }, + { + "id": 19677, + "logprob": 0.0, + "special": false, + "text": " neural" } ] }, - "generated_text": "Test request failed for user \"test" + "generated_text": "What is Deep Learning?\nDeep learning (also known as deep neural" } diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json index 80d4873a..82a7b9e1 100644 --- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json @@ -11,81 +11,108 @@ "text": "" }, { - "id": 4321, - "logprob": -10.0625, - "text": "Test" + "id": 1724, + "logprob": -10.734375, + "text": "What" }, { - "id": 2009, - "logprob": -12.28125, - "text": "request" + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2753906, + "text": 
"Learning" + }, + { + "id": 29973, + "logprob": -0.48046875, + "text": "?" } ], "seed": null, "tokens": [ - { - "id": 363, - "logprob": -2.0878906, - "special": false, - "text": " for" - }, - { - "id": 278, - "logprob": -3.4082031, - "special": false, - "text": " the" - }, - { - "id": 376, - "logprob": -3.8457031, - "special": false, - "text": " \"" - }, - { - "id": 2577, - "logprob": -3.5605469, - "special": false, - "text": "Get" - }, - { - "id": 599, - "logprob": -3.4707031, - "special": false, - "text": " all" - }, - { - "id": 4160, - "logprob": -3.2421875, - "special": false, - "text": " users" - }, - { - "id": 29908, - "logprob": -0.49072266, - "special": false, - "text": "\"" - }, - { - "id": 16248, - "logprob": -1.2353516, - "special": false, - "text": " endpoint" - }, - { - "id": 29889, - "logprob": -0.8833008, - "special": false, - "text": "." - }, { "id": 13, - "logprob": -0.42089844, + "logprob": -1.1845703, "special": false, "text": "\n" + }, + { + "id": 2772, + "logprob": -0.5727539, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.00010967255, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.04510498, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.00020992756, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.0046539307, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025844574, + "special": false, + "text": " learning" + }, + { + "id": 393, + "logprob": -0.09185791, + "special": false, + "text": " that" + }, + { + "id": 20789, + "logprob": -0.4951172, + "special": false, + "text": " involves" } ] }, - "generated_text": " for the \"Get all users\" endpoint.\n" + "generated_text": "ep learning is a subset of machine learning that involves" }, { "details": { @@ -99,81 +126,108 @@ "text": "" }, { - "id": 4321, - "logprob": -10.0625, - "text": "Test" + "id": 1724, + "logprob": -10.734375, + "text": "What" }, { - "id": 2009, - "logprob": -12.28125, - "text": "request" + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" } ], "seed": null, "tokens": [ - { - "id": 363, - "logprob": -2.0878906, - "special": false, - "text": " for" - }, - { - "id": 278, - "logprob": -3.4082031, - "special": false, - "text": " the" - }, - { - "id": 376, - "logprob": -3.8457031, - "special": false, - "text": " \"" - }, - { - "id": 2577, - "logprob": -3.5625, - "special": false, - "text": "Get" - }, - { - "id": 599, - "logprob": -3.4726562, - "special": false, - "text": " all" - }, - { - "id": 4160, - "logprob": -3.2382812, - "special": false, - "text": " users" - }, - { - "id": 29908, - "logprob": -0.49047852, - "special": false, - "text": "\"" - }, - { - "id": 16248, - "logprob": -1.2412109, - "special": false, - "text": " endpoint" - }, - { - "id": 29889, - "logprob": -0.87402344, - "special": false, - "text": "." 
- }, { "id": 13, - "logprob": -0.41723633, + "logprob": -1.1826172, "special": false, "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + }, + { + "id": 393, + "logprob": -0.091918945, + "special": false, + "text": " that" + }, + { + "id": 20789, + "logprob": -0.50097656, + "special": false, + "text": " involves" } ] }, - "generated_text": " for the \"Get all users\" endpoint.\n" + "generated_text": "ep learning is a subset of machine learning that involves" }, { "details": { @@ -187,81 +241,108 @@ "text": "" }, { - "id": 4321, - "logprob": -10.0625, - "text": "Test" + "id": 1724, + "logprob": -10.734375, + "text": "What" }, { - "id": 2009, - "logprob": -12.28125, - "text": "request" + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" } ], "seed": null, "tokens": [ - { - "id": 363, - "logprob": -2.0878906, - "special": false, - "text": " for" - }, - { - "id": 278, - "logprob": -3.4082031, - "special": false, - "text": " the" - }, - { - "id": 376, - "logprob": -3.8457031, - "special": false, - "text": " \"" - }, - { - "id": 2577, - "logprob": -3.5605469, - "special": false, - "text": "Get" - }, - { - "id": 599, - "logprob": -3.4707031, - "special": false, - "text": " all" - }, - { - "id": 4160, - "logprob": -3.2421875, - "special": false, - "text": " users" - }, - { - "id": 29908, - "logprob": -0.49072266, - "special": false, - "text": "\"" - }, - { - "id": 16248, - "logprob": -1.2353516, - "special": false, - "text": " endpoint" - }, - { - "id": 29889, - "logprob": -0.8833008, - "special": false, - "text": "." 
- }, { "id": 13, - "logprob": -0.42089844, + "logprob": -1.1826172, "special": false, "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + }, + { + "id": 393, + "logprob": -0.091918945, + "special": false, + "text": " that" + }, + { + "id": 20789, + "logprob": -0.50097656, + "special": false, + "text": " involves" } ] }, - "generated_text": " for the \"Get all users\" endpoint.\n" + "generated_text": "ep learning is a subset of machine learning that involves" }, { "details": { @@ -275,80 +356,107 @@ "text": "" }, { - "id": 4321, - "logprob": -10.0625, - "text": "Test" + "id": 1724, + "logprob": -10.734375, + "text": "What" }, { - "id": 2009, - "logprob": -12.28125, - "text": "request" + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" } ], "seed": null, "tokens": [ - { - "id": 363, - "logprob": -2.0878906, - "special": false, - "text": " for" - }, - { - "id": 278, - "logprob": -3.4082031, - "special": false, - "text": " the" - }, - { - "id": 376, - "logprob": -3.8457031, - "special": false, - "text": " \"" - }, - { - "id": 2577, - "logprob": -3.5605469, - "special": false, - "text": "Get" - }, - { - "id": 599, - "logprob": -3.4707031, - "special": false, - "text": " all" - }, - { - "id": 4160, - "logprob": -3.2421875, - "special": false, - "text": " users" - }, - { - "id": 29908, - "logprob": -0.49072266, - "special": false, - "text": "\"" - }, - { - "id": 16248, - "logprob": -1.2353516, - "special": false, - "text": " endpoint" - }, - { - "id": 29889, - "logprob": -0.8833008, - "special": false, - "text": "." 
- }, { "id": 13, - "logprob": -0.42089844, + "logprob": -1.1826172, "special": false, "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + }, + { + "id": 393, + "logprob": -0.091918945, + "special": false, + "text": " that" + }, + { + "id": 20789, + "logprob": -0.50097656, + "special": false, + "text": " involves" } ] }, - "generated_text": " for the \"Get all users\" endpoint.\n" + "generated_text": "ep learning is a subset of machine learning that involves" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json index eb449de3..0a1e3198 100644 --- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json @@ -10,79 +10,106 @@ "text": "" }, { - "id": 4321, - "logprob": -10.0625, - "text": "Test" + "id": 1724, + "logprob": -10.734375, + "text": "What" }, { - "id": 2009, - "logprob": -12.28125, - "text": "request" + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2753906, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.48046875, + "text": "?" } ], "seed": null, "tokens": [ - { - "id": 363, - "logprob": -2.0878906, - "special": false, - "text": " for" - }, - { - "id": 278, - "logprob": -3.4121094, - "special": false, - "text": " the" - }, - { - "id": 376, - "logprob": -3.8457031, - "special": false, - "text": " \"" - }, - { - "id": 2577, - "logprob": -3.5566406, - "special": false, - "text": "Get" - }, - { - "id": 599, - "logprob": -3.4746094, - "special": false, - "text": " all" - }, - { - "id": 4160, - "logprob": -3.2363281, - "special": false, - "text": " users" - }, - { - "id": 29908, - "logprob": -0.49023438, - "special": false, - "text": "\"" - }, - { - "id": 16248, - "logprob": -1.2402344, - "special": false, - "text": " endpoint" - }, - { - "id": 29889, - "logprob": -0.88134766, - "special": false, - "text": "." 
-      },
       {
         "id": 13,
-        "logprob": -0.41870117,
+        "logprob": -1.1845703,
         "special": false,
         "text": "\n"
+      },
+      {
+        "id": 2772,
+        "logprob": -0.5727539,
+        "special": false,
+        "text": "De"
+      },
+      {
+        "id": 1022,
+        "logprob": -0.000108122826,
+        "special": false,
+        "text": "ep"
+      },
+      {
+        "id": 6509,
+        "logprob": -0.1239624,
+        "special": false,
+        "text": " learning"
+      },
+      {
+        "id": 338,
+        "logprob": -0.044433594,
+        "special": false,
+        "text": " is"
+      },
+      {
+        "id": 263,
+        "logprob": -0.01852417,
+        "special": false,
+        "text": " a"
+      },
+      {
+        "id": 11306,
+        "logprob": -0.45922852,
+        "special": false,
+        "text": " subset"
+      },
+      {
+        "id": 310,
+        "logprob": -0.0002104044,
+        "special": false,
+        "text": " of"
+      },
+      {
+        "id": 4933,
+        "logprob": -0.004787445,
+        "special": false,
+        "text": " machine"
+      },
+      {
+        "id": 6509,
+        "logprob": -0.00026226044,
+        "special": false,
+        "text": " learning"
+      },
+      {
+        "id": 393,
+        "logprob": -0.09161377,
+        "special": false,
+        "text": " that"
+      },
+      {
+        "id": 20789,
+        "logprob": -0.49560547,
+        "special": false,
+        "text": " involves"
       }
     ]
   },
-  "generated_text": " for the \"Get all users\" endpoint.\n"
+  "generated_text": "ep learning is a subset of machine learning that involves"
 }
diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py
index 7cc797e4..b48914b8 100644
--- a/integration-tests/models/test_flash_medusa.py
+++ b/integration-tests/models/test_flash_medusa.py
@@ -17,7 +17,7 @@ async def flash_medusa(flash_medusa_handle):
 @pytest.mark.private
 async def test_flash_medusa_simple(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
-        "Test request", max_new_tokens=10, decoder_input_details=True
+        "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
     )
 
     assert response.details.generated_tokens == 10
@@ -28,7 +28,7 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
 @pytest.mark.private
 async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
-        "Test request",
+        "What is Deep Learning?",
         max_new_tokens=10,
         repetition_penalty=1.2,
         return_full_text=True,
@@ -43,17 +43,17 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
         seed=0,
     )
 
-    assert response.details.generated_tokens == 5
+    assert response.details.generated_tokens == 10
     assert response == response_snapshot
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
-    responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)
+    responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
-    assert responses[0].generated_text == ' for the "Get all users" endpoint.\n'
+    assert responses[0].generated_text == 'ep learning is a subset of machine learning that involves'
 
     assert responses == response_snapshot
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 82f38564..faa446d2 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -232,7 +232,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
         )
@@ -479,6 +479,7 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
         max_length = 0
         max_seqlen = 0
+        speculative_length = 0 if batches[0].speculative_ids is None else batches[0].speculative_ids.shape[1]
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
@@ -489,6 +490,7 @@ class FlashCausalLMBatch(Batch):
                 max_length,
                 max(
                     input_length
+                    + speculative_length
                     + stopping_criteria.max_new_tokens
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index a9f0374a..ffbbf40c 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -16,7 +16,6 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
-
 class NextTokenChooser:
     def __init__(
         self,
@@ -289,8 +288,6 @@ class HeterogeneousNextTokenChooser:
                         indices.append(index)
                     else:
                         break
-                # if accepted > 1:
-                #     import ipdb;ipdb.set_trace()
                 accepted_ids.append(accepted)
             accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
             next_ids = next_ids[indices]