Adding integration for neox NON flash.

This commit is contained in:
Ubuntu 2023-06-06 16:19:27 +00:00
parent 644e0a65a3
commit 877d4d4aeb
8 changed files with 1320 additions and 8 deletions

View File

@ -209,6 +209,7 @@ def launcher(event_loop):
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@ -240,6 +241,9 @@ def launcher(event_loop):
env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug"
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
with subprocess.Popen(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
) as process:
@ -260,6 +264,7 @@ def launcher(event_loop):
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
):
port = random.randint(8000, 10_000)
@ -287,6 +292,9 @@ def launcher(event_loop):
gpu_count = num_shard if num_shard is not None else 1
env = {"LOG_LEVEL": "info,text_generation_router=debug"}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if HUGGING_FACE_HUB_TOKEN is not None:
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN

View File

@ -0,0 +1,113 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1992188,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8984375,
"text": " mood"
},
{
"id": 3063,
"logprob": -4.0976562,
"text": " today"
},
{
"id": 32,
"logprob": -0.14562988,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26733398,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.86279297,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.94921875,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1835938,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.074035645,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.86376953,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.2070312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4365234,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.109375,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -0.93408203,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.8808594,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}

View File

@ -0,0 +1,454 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}
]

View File

@ -0,0 +1,654 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.4140625,
"text": " is"
},
{
"id": 247,
"logprob": -2.1621094,
"text": " a"
},
{
"id": 1167,
"logprob": -5.453125,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005393982,
"text": "e"
},
{
"id": 13,
"logprob": -7.390625,
"text": ","
},
{
"id": 285,
"logprob": -0.33691406,
"text": " and"
},
{
"id": 752,
"logprob": -2.2207031,
"text": " what"
},
{
"id": 434,
"logprob": -5.5976562,
"text": "'s"
},
{
"id": 253,
"logprob": -0.7661133,
"text": " the"
},
{
"id": 2892,
"logprob": -6.515625,
"text": " history"
},
{
"id": 3212,
"logprob": -2.3085938,
"text": " behind"
},
{
"id": 436,
"logprob": -11.3203125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1230469,
"text": " word"
},
{
"id": 32,
"logprob": -0.00856781,
"text": "?"
},
{
"id": 0,
"logprob": -2.4296875,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.1875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.64208984,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.5839844,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.04989624,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0021305084,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.180172e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00092983246,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.08496094,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.13256836,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017059326,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.4921875,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.04904175e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0009560585,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.08557129,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12084961,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.01737976,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.4025879,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
}
]

View File

@ -0,0 +1,44 @@
import pytest
@pytest.fixture(scope="module")
def neox_handle(launcher):
with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle:
yield handle
@pytest.fixture(scope="module")
async def neox(neox_handle):
await neox_handle.health(300)
return neox_handle.client
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
response = await neox.generate(
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):
responses = await generate_load(
neox,
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
)
assert responses == response_snapshot

View File

@ -0,0 +1,40 @@
import pytest
@pytest.fixture(scope="module")
def neox_sharded_handle(launcher):
with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle:
yield handle
@pytest.fixture(scope="module")
async def neox_sharded(neox_sharded_handle):
await neox_sharded_handle.health(300)
return neox_sharded_handle.client
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
response = await neox_sharded.generate(
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
responses = await generate_load(
neox_sharded,
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
n=4,
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -1,3 +1,4 @@
import os
import torch
from loguru import logger
@ -18,7 +19,7 @@ from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
try:
if torch.cuda.is_available():
if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION").lower() == "false":
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0

View File

@ -357,7 +357,7 @@ class GPTNeoXMLP(nn.Module):
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
@ -430,6 +430,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights)
self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps)
self.tp_world_size = weights.process_group.size()
def forward(
@ -508,12 +509,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
past_key_values_length=past_key_values_length,
)
if hasattr(self, "tp_rank"):
assert self.num_attention_heads % self.tp_world_size == 0
block_size = self.num_attention_heads // self.tp_world_size
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
else:
causal_mask = torch.repeat_interleave(causal_mask, self.num_attention_heads, dim=0)
assert self.num_attention_heads % self.tp_world_size == 0
block_size = self.num_attention_heads // self.tp_world_size
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head