From 79f9afba90652f8ecf2d2b43931a8d83e986fded Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 4 Dec 2023 14:36:21 +0000 Subject: [PATCH] Needed to regenerate params tests + fix simple tests --- .../test_flash_llama_gptq_all_params.json | 74 ++-- .../test_flash_medusa_all_params.json | 8 +- .../test_flash_medusa_load.json | 318 ++++++++---------- .../test_flash_medusa_simple.json | 14 +- .../test_flash_mistral_all_params.json | 61 ++-- .../test_flash_starcoder_default_params.json | 154 ++++----- ...t_flash_starcoder_gptq_default_params.json | 24 +- integration-tests/models/test_flash_medusa.py | 2 +- server/tests/models/test_bloom.py | 4 +- server/tests/models/test_causal_lm.py | 4 +- server/tests/models/test_seq2seq_lm.py | 4 +- 11 files changed, 300 insertions(+), 367 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json index 02713a00..8d705c59 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -11,12 +11,12 @@ }, { "id": 4321, - "logprob": -9.6015625, + "logprob": -9.5625, "text": "Test" }, { "id": 2009, - "logprob": -9.6640625, + "logprob": -9.6796875, "text": "request" } ], @@ -24,15 +24,33 @@ "tokens": [ { "id": 29899, - "logprob": -1.1640625, + "logprob": -1.1972656, "special": false, "text": "-" }, { - "id": 1454, - "logprob": -0.07543945, + "id": 29896, + "logprob": -1.1621094, "special": false, - "text": "for" + "text": "1" + }, + { + "id": 13, + "logprob": -0.12768555, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.074279785, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": 0.0, + "special": false, + "text": " request" }, { "id": 29899, @@ -40,49 +58,31 @@ "special": false, "text": "-" }, - { - "id": 9342, - "logprob": 0.0, - "special": false, - "text": "comment" - }, - { - "id": 29901, - "logprob": 0.0, - "special": false, - "text": ":" - }, - { - "id": 396, - "logprob": -0.2956543, - "special": false, - "text": " #" - }, { "id": 29906, - "logprob": -0.52734375, + "logprob": -0.20141602, "special": false, "text": "2" }, { - "id": 29900, - "logprob": -0.6899414, - "special": false, - "text": "0" - }, - { - "id": 29896, + "id": 13, "logprob": 0.0, "special": false, - "text": "1" + "text": "\n" }, { - "id": 29946, - "logprob": -1.5068359, + "id": 3057, + "logprob": -0.6611328, "special": false, - "text": "4" + "text": "Test" + }, + { + "id": 2009, + "logprob": 0.0, + "special": false, + "text": " request" } ] }, - "generated_text": "Test request-for-comment: #2014" + "generated_text": "Test request-1\nTest request-2\nTest request" } diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json index 4dd815b3..05b9a365 100644 --- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json @@ -91,14 +91,8 @@ "logprob": 0.0, "special": false, "text": " a" - }, - { - "id": 11306, - "logprob": -0.5488281, - "special": false, - "text": " subset" } ] }, - "generated_text": "What is Deep Learning?\nDeep learning can be thought of as a subset" + "generated_text": "What is Deep Learning?\nDeep learning can be thought of as a" } diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json index 6698f5f4..413af1d7 100644 --- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json @@ -17,7 +17,7 @@ }, { "id": 338, - "logprob": -1.5498047, + "logprob": -1.5488281, "text": "is" }, { @@ -27,12 +27,12 @@ }, { "id": 29257, - "logprob": -1.2734375, + "logprob": -1.2753906, "text": "Learning" }, { "id": 29973, - "logprob": -0.48217773, + "logprob": -0.48046875, "text": "?" } ], @@ -40,19 +40,122 @@ "tokens": [ { "id": 13, - "logprob": -1.1875, + "logprob": -1.1845703, "special": false, "text": "\n" }, { "id": 2772, - "logprob": -0.5708008, + "logprob": -0.5727539, "special": false, "text": "De" }, { "id": 1022, - "logprob": -0.00010931492, + "logprob": -0.00010967255, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.04510498, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.00020992756, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.0046539307, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025844574, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1826172, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, "special": false, "text": "ep" }, @@ -70,13 +173,13 @@ }, { "id": 263, - "logprob": -0.018310547, + "logprob": -0.018295288, "special": false, "text": " a" }, { "id": 11306, - "logprob": -0.46044922, + "logprob": -0.45922852, "special": false, "text": " subset" }, @@ -94,25 +197,13 @@ }, { "id": 6509, - "logprob": -0.00025820732, + "logprob": -0.00025892258, "special": false, "text": " learning" - }, - { - "id": 393, - "logprob": -0.09185791, - "special": false, - "text": " that" - }, - { - "id": 20789, - "logprob": -0.4951172, - "special": false, - "text": " involves" } ] }, - "generated_text": "\nDeep learning is a subset of machine learning that involves" + "generated_text": "\nDeep learning is a subset of machine learning" }, { "details": { @@ -127,12 +218,12 @@ }, { "id": 1724, - "logprob": -10.7421875, + "logprob": -10.734375, "text": "What" }, { "id": 338, - "logprob": -1.5498047, + "logprob": -1.5488281, "text": "is" }, { @@ -155,19 +246,19 @@ "tokens": [ { "id": 13, - "logprob": -1.1835938, + "logprob": -1.1826172, "special": false, "text": "\n" }, { "id": 2772, - "logprob": -0.57470703, + "logprob": -0.56689453, "special": false, "text": "De" }, { "id": 1022, - "logprob": -0.00010788441, + "logprob": -0.000108003616, "special": false, "text": "ep" }, @@ -179,13 +270,13 @@ }, { "id": 338, - "logprob": -0.04510498, + "logprob": -0.044433594, "special": false, "text": " is" }, { "id": 263, - "logprob": -0.018585205, + "logprob": -0.018295288, "special": false, "text": " a" }, @@ -197,37 +288,25 @@ }, { "id": 310, - "logprob": -0.00021457672, + "logprob": -0.0002104044, "special": false, "text": " of" }, { "id": 4933, - "logprob": -0.004776001, + "logprob": -0.004711151, "special": false, "text": " machine" }, { "id": 6509, - "logprob": -0.0002593994, + "logprob": -0.00025892258, "special": false, "text": " learning" - }, - { - "id": 393, - "logprob": -0.091918945, - "special": false, - "text": " that" - }, - { - "id": 20789, - "logprob": -0.50097656, - "special": false, - "text": " involves" } ] }, - "generated_text": "\nDeep learning is a subset of machine learning that involves" + "generated_text": "\nDeep learning is a subset of machine learning" }, { "details": { @@ -242,12 +321,12 @@ }, { "id": 1724, - "logprob": -10.7421875, + "logprob": -10.734375, "text": "What" }, { "id": 338, - "logprob": -1.5498047, + "logprob": -1.5488281, "text": "is" }, { @@ -270,19 +349,19 @@ "tokens": [ { "id": 13, - "logprob": -1.1835938, + "logprob": -1.1826172, "special": false, "text": "\n" }, { "id": 2772, - "logprob": -0.57470703, + "logprob": -0.56689453, "special": false, "text": "De" }, { "id": 1022, - "logprob": -0.00010788441, + "logprob": -0.000108003616, "special": false, "text": "ep" }, @@ -294,13 +373,13 @@ }, { "id": 338, - "logprob": -0.04510498, + "logprob": -0.044433594, "special": false, "text": " is" }, { "id": 263, - "logprob": -0.018585205, + "logprob": -0.018295288, "special": false, "text": " a" }, @@ -312,151 +391,24 @@ }, { "id": 310, - "logprob": -0.00021457672, + "logprob": -0.0002104044, "special": false, "text": " of" }, { "id": 4933, - "logprob": -0.004776001, + "logprob": -0.004711151, "special": false, "text": " machine" }, { "id": 6509, - "logprob": -0.0002593994, + "logprob": -0.00025892258, "special": false, "text": " learning" - }, - { - "id": 393, - "logprob": -0.091918945, - "special": false, - "text": " that" - }, - { - "id": 20789, - "logprob": -0.50097656, - "special": false, - "text": " involves" } ] }, - "generated_text": "\nDeep learning is a subset of machine learning that involves" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 1, - "logprob": null, - "text": "" - }, - { - "id": 1724, - "logprob": -10.7421875, - "text": "What" - }, - { - "id": 338, - "logprob": -1.5498047, - "text": "is" - }, - { - "id": 21784, - "logprob": -9.2890625, - "text": "Deep" - }, - { - "id": 29257, - "logprob": -1.2724609, - "text": "Learning" - }, - { - "id": 29973, - "logprob": -0.47729492, - "text": "?" - } - ], - "seed": null, - "tokens": [ - { - "id": 13, - "logprob": -1.1835938, - "special": false, - "text": "\n" - }, - { - "id": 2772, - "logprob": -0.57470703, - "special": false, - "text": "De" - }, - { - "id": 1022, - "logprob": -0.00010788441, - "special": false, - "text": "ep" - }, - { - "id": 6509, - "logprob": -0.1239624, - "special": false, - "text": " learning" - }, - { - "id": 338, - "logprob": -0.04510498, - "special": false, - "text": " is" - }, - { - "id": 263, - "logprob": -0.018585205, - "special": false, - "text": " a" - }, - { - "id": 11306, - "logprob": -0.45922852, - "special": false, - "text": " subset" - }, - { - "id": 310, - "logprob": -0.00021457672, - "special": false, - "text": " of" - }, - { - "id": 4933, - "logprob": -0.004776001, - "special": false, - "text": " machine" - }, - { - "id": 6509, - "logprob": -0.0002593994, - "special": false, - "text": " learning" - }, - { - "id": 393, - "logprob": -0.091918945, - "special": false, - "text": " that" - }, - { - "id": 20789, - "logprob": -0.50097656, - "special": false, - "text": " involves" - } - ] - }, - "generated_text": "\nDeep learning is a subset of machine learning that involves" + "generated_text": "\nDeep learning is a subset of machine learning" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json index cd3cb53a..15754b14 100644 --- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json @@ -96,20 +96,8 @@ "logprob": -0.00026226044, "special": false, "text": " learning" - }, - { - "id": 393, - "logprob": -0.09161377, - "special": false, - "text": " that" - }, - { - "id": 20789, - "logprob": -0.49560547, - "special": false, - "text": " involves" } ] }, - "generated_text": "\nDeep learning is a subset of machine learning that involves" + "generated_text": "\nDeep learning is a subset of machine learning" } diff --git a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json index c0dc6471..91ac444e 100644 --- a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json @@ -35,10 +35,10 @@ "text": " Let" }, { - "id": 332, - "logprob": -2.3359375, + "id": 268, + "logprob": -2.0234375, "special": false, - "text": " u" + "text": " s" }, { "id": 347, @@ -46,44 +46,43 @@ "special": false, "text": " be" }, - { - "id": 325, - "logprob": -1.0234375, - "special": false, - "text": " (" - }, - { - "id": 28734, - "logprob": -2.0292969, - "special": false, - "text": "0" - }, - { - "id": 648, - "logprob": -1.0439453, - "special": false, - "text": " +" - }, { "id": 28705, - "logprob": -0.24499512, + "logprob": -0.44458008, "special": false, "text": " " }, { - "id": 28770, - "logprob": -0.5073242, + "id": 28784, + "logprob": -0.94189453, "special": false, - "text": "3" + "text": "6" }, { - "id": 387, - "logprob": -1.5507812, + "id": 28748, + "logprob": 0.0, "special": false, - "text": " -" + "text": "/" + }, + { + "id": 6422, + "logprob": 0.0, + "special": false, + "text": "(-" + }, + { + "id": 28783, + "logprob": -0.2590332, + "special": false, + "text": "8" + }, + { + "id": 28731, + "logprob": -1.0039062, + "special": false, + "text": ")" } - ], - "top_tokens": null + ] }, - "generated_text": "Test request: Let u be (0 + 3 -" + "generated_text": "Test request: Let s be 6/(-8)" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json index 89e02c07..86d55a17 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -11,7 +11,7 @@ }, { "id": 1459, - "logprob": -5.6328125, + "logprob": -5.6289062, "text": " print" }, { @@ -21,7 +21,7 @@ }, { "id": 7656, - "logprob": -5.9882812, + "logprob": -5.9960938, "text": "hello" } ], @@ -59,13 +59,13 @@ }, { "id": 10896, - "logprob": -0.38549805, + "logprob": -0.3659668, "special": false, "text": " World" }, { "id": 657, - "logprob": -0.5229492, + "logprob": -0.49804688, "special": false, "text": "\")" }, @@ -113,7 +113,7 @@ }, { "id": 426, - "logprob": 0.0, + "logprob": -0.051635742, "special": false, "text": "name" }, @@ -148,10 +148,22 @@ "text": " print" }, { - "id": 440, - "logprob": -0.16027832, + "id": 26, + "logprob": -1.9101562, "special": false, - "text": "(\"" + "text": "(" + }, + { + "id": 88, + "logprob": 0.0, + "special": false, + "text": "f" + }, + { + "id": 20, + "logprob": 0.0, + "special": false, + "text": "\"" }, { "id": 8279, @@ -160,28 +172,22 @@ "text": "Hello" }, { - "id": 313, + "id": 301, "logprob": 0.0, "special": false, - "text": " \"" + "text": " {" }, { - "id": 474, + "id": 426, "logprob": 0.0, "special": false, - "text": " +" + "text": "name" }, { - "id": 636, + "id": 8474, "logprob": 0.0, "special": false, - "text": " name" - }, - { - "id": 27, - "logprob": 0.0, - "special": false, - "text": ")" + "text": "}\")" }, { "id": 203, @@ -286,10 +292,22 @@ "text": " print" }, { - "id": 440, + "id": 26, "logprob": 0.0, "special": false, - "text": "(\"" + "text": "(" + }, + { + "id": 88, + "logprob": 0.0, + "special": false, + "text": "f" + }, + { + "id": 20, + "logprob": 0.0, + "special": false, + "text": "\"" }, { "id": 8279, @@ -298,58 +316,40 @@ "text": "Hello" }, { - "id": 313, + "id": 301, "logprob": 0.0, "special": false, - "text": " \"" + "text": " {" }, { - "id": 474, + "id": 426, "logprob": 0.0, "special": false, - "text": " +" + "text": "name" }, { - "id": 636, + "id": 835, + "logprob": -0.4074707, + "special": false, + "text": "}," + }, + { + "id": 844, "logprob": 0.0, "special": false, - "text": " name" + "text": " you" }, { - "id": 474, + "id": 884, "logprob": 0.0, "special": false, - "text": " +" + "text": " are" }, { - "id": 313, - "logprob": -0.6328125, - "special": false, - "text": " \"" - }, - { - "id": 313, - "logprob": -1.7011719, - "special": false, - "text": " \"" - }, - { - "id": 474, + "id": 301, "logprob": 0.0, "special": false, - "text": " +" - }, - { - "id": 596, - "logprob": 0.0, - "special": false, - "text": " str" - }, - { - "id": 26, - "logprob": 0.0, - "special": false, - "text": "(" + "text": " {" }, { "id": 381, @@ -358,36 +358,36 @@ "text": "age" }, { - "id": 490, + "id": 111, "logprob": 0.0, "special": false, - "text": "))" + "text": "}" + }, + { + "id": 11274, + "logprob": 0.0, + "special": false, + "text": " years" + }, + { + "id": 3610, + "logprob": 0.0, + "special": false, + "text": " old" + }, + { + "id": 657, + "logprob": 0.0, + "special": false, + "text": "\")" }, { "id": 203, "logprob": 0.0, "special": false, "text": "\n" - }, - { - "id": 203, - "logprob": 0.0, - "special": false, - "text": "\n" - }, - { - "id": 589, - "logprob": 0.0, - "special": false, - "text": "def" - }, - { - "id": 1459, - "logprob": 0.0, - "special": false, - "text": " print" } ] }, - "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print" + "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(f\"Hello {name}\")\n\ndef print_hello_name_age(name, age):\n print(f\"Hello {name}, you are {age} years old\")\n" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json index 5598a2ad..3815a2ef 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -11,7 +11,7 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -9.015625, "text": " ge" }, { @@ -21,47 +21,47 @@ }, { "id": 81, - "logprob": -0.25976562, + "logprob": -0.25634766, "text": "_" }, { "id": 6009, - "logprob": -2.2148438, + "logprob": -2.203125, "text": "mean" }, { "id": 26, - "logprob": -0.3010254, + "logprob": -0.30322266, "text": "(" }, { "id": 62, - "logprob": -5.6757812, + "logprob": -5.6015625, "text": "L" }, { "id": 44, - "logprob": -3.0898438, + "logprob": -3.0878906, "text": ":" }, { "id": 1682, - "logprob": -0.6791992, + "logprob": -0.6826172, "text": " List" }, { "id": 77, - "logprob": -0.38891602, + "logprob": -0.38354492, "text": "[" }, { "id": 1808, - "logprob": -0.92041016, + "logprob": -0.9760742, "text": "float" }, { "id": 10794, - "logprob": -2.5390625, + "logprob": -2.5234375, "text": "]):" } ], @@ -81,7 +81,7 @@ }, { "id": 11665, - "logprob": -1.6005859, + "logprob": -1.0537109, "special": false, "text": " reduce" }, @@ -159,7 +159,7 @@ }, { "id": 203, - "logprob": -0.11968994, + "logprob": -0.10021973, "special": false, "text": "\n" }, diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index e9dcf6d9..003409b0 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -54,6 +54,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot) assert len(responses) == 4 assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}" - assert responses[0].generated_text == '\nDeep learning is a subset of machine learning that involves' + assert responses[0].generated_text == '\nDeep learning is a subset of machine learning' assert responses == response_snapshot diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 71013cb6..1990ef8b 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -133,8 +133,8 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 10264 for generation in generations]) - assert all([generation.token_text == "Test" for generation in generations]) + assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids]) + assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts]) assert generations[0].request_id == 0 diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 0f9dab2c..f105ce6f 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -129,8 +129,8 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 13 for generation in generations]) - assert all([generation.token_text == "." for generation in generations]) + assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids]) + assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts]) assert generations[0].request_id == 0 diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 299340f8..d553067e 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -151,8 +151,8 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 259 for generation in generations]) - assert all([generation.token_text == " " for generation in generations]) + assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids]) + assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts]) assert generations[0].request_id == 0