diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 82f1b719..c5f8f64e 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -209,6 +209,7 @@ def launcher(event_loop): num_shard: Optional[int] = None, quantize: Optional[str] = None, trust_remote_code: bool = False, + use_flash_attention: bool = True, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -240,6 +241,9 @@ def launcher(event_loop): env = os.environ env["LOG_LEVEL"] = "info,text_generation_router=debug" + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + with subprocess.Popen( args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env ) as process: @@ -260,6 +264,7 @@ def launcher(event_loop): num_shard: Optional[int] = None, quantize: Optional[str] = None, trust_remote_code: bool = False, + use_flash_attention: bool = True, ): port = random.randint(8000, 10_000) @@ -287,6 +292,9 @@ def launcher(event_loop): gpu_count = num_shard if num_shard is not None else 1 env = {"LOG_LEVEL": "info,text_generation_router=debug"} + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + if HUGGING_FACE_HUB_TOKEN is not None: env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox.json b/integration-tests/models/__snapshots__/test_neox/test_neox.json new file mode 100644 index 00000000..2abc27e1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox/test_neox.json @@ -0,0 +1,113 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1992188, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8984375, + "text": " mood" + }, 
+ { + "id": 3063, + "logprob": -4.0976562, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14562988, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26733398, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.86279297, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.94921875, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1835938, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.074035645, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.86376953, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.2070312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4365234, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.109375, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -0.93408203, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.8808594, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" +} diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json new file mode 100644 index 00000000..f37f0d8e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json @@ -0,0 +1,454 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": 
-0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + } +] diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json new file mode 100644 index 00000000..15637cdb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json @@ -0,0 +1,654 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.4140625, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1621094, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.453125, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005393982, + "text": "e" + }, + { + "id": 13, + "logprob": -7.390625, + "text": "," + }, + { + "id": 285, + "logprob": -0.33691406, + "text": " and" + }, + { + "id": 752, + 
"logprob": -2.2207031, + "text": " what" + }, + { + "id": 434, + "logprob": -5.5976562, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.7661133, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.515625, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.3085938, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.3203125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1230469, + "text": " word" + }, + { + "id": 32, + "logprob": -0.00856781, + "text": "?" + }, + { + "id": 0, + "logprob": -2.4296875, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.1875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.64208984, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5839844, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.04989624, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0021305084, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.180172e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00092983246, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.08496094, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.13256836, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017059326, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.4921875, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + 
"logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" + }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": 
null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" + }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.04904175e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0009560585, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.08557129, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12084961, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.4025879, + "special": 
false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + } +] diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py new file mode 100644 index 00000000..eed70f80 --- /dev/null +++ b/integration-tests/models/test_neox.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_handle(launcher): + with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox(neox_handle): + await neox_handle.health(300) + return neox_handle.client + + +@pytest.mark.asyncio +async def test_neox(neox, response_snapshot): + response = await neox.generate( + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def 
test_neox_load(neox, generate_load, response_snapshot): + responses = await generate_load( + neox, + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert all( + [text == generated_texts[0] for text in generated_texts] + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py new file mode 100644 index 00000000..6ea97d81 --- /dev/null +++ b/integration-tests/models/test_neox_sharded.py @@ -0,0 +1,40 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_sharded_handle(launcher): + with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox_sharded(neox_sharded_handle): + await neox_sharded_handle.health(300) + return neox_sharded_handle.client + + +@pytest.mark.asyncio +async def test_neox(neox_sharded, response_snapshot): + response = await neox_sharded.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_neox_load(neox_sharded, generate_load, response_snapshot): + responses = await generate_load( + neox_sharded, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f0427c20..6a0f32a1 100644 --- 
a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,3 +1,4 @@ +import os import torch from loguru import logger @@ -18,7 +19,7 @@ from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded try: - if torch.cuda.is_available(): + if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 is_sm8x = major == 8 and minor >= 0 diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 1e20a477..79fa1915 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -357,7 +357,7 @@ class GPTNeoXMLP(nn.Module): config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True ) self.dense_4h_to_h = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True ) def forward(self, hidden_states): @@ -430,6 +430,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights) self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps) + self.tp_world_size = weights.process_group.size() def forward( @@ -508,12 +509,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): past_key_values_length=past_key_values_length, ) - if hasattr(self, "tp_rank"): - assert self.num_attention_heads % self.tp_world_size == 0 - block_size = self.num_attention_heads // self.tp_world_size - causal_mask = 
torch.repeat_interleave(causal_mask, block_size, dim=0) - else: - causal_mask = torch.repeat_interleave(causal_mask, self.num_attention_heads, dim=0) + assert self.num_attention_heads % self.tp_world_size == 0 + block_size = self.num_attention_heads // self.tp_world_size + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head