From 71402ed4c7e4a630b611620e151dc3b3bd9fed26 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 28 Mar 2023 13:49:51 +0200 Subject: [PATCH] wip --- router/Cargo.toml | 3 +- server/Makefile | 2 +- server/poetry.lock | 49 +- server/pyproject.toml | 1 + .../text_generation_server/models/__init__.py | 8 + .../custom_modeling/flash_llama_modeling.py | 605 ++++++++++++++++++ .../models/flash_llama.py | 290 +++++++++ 7 files changed, 955 insertions(+), 3 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/flash_llama_modeling.py create mode 100644 server/text_generation_server/models/flash_llama.py diff --git a/router/Cargo.toml b/router/Cargo.toml index 77577eea..801c647b 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -30,7 +30,8 @@ reqwest = { version = "0.11.14", features = [] } serde = "1.0.152" serde_json = "1.0.93" thiserror = "1.0.38" -tokenizers = "0.13.2" +#tokenizers = "0.13.2" +tokenizers = { git = "https://github.com/huggingface/tokenizers.git" } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.11" tower-http = { version = "0.3.5", features = ["cors"] } diff --git a/server/Makefile b/server/Makefile index d2a8cf7a..9a2b6f03 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,4 @@ -transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef +transformers_commit := 88bfd061f44f4d33e687c20a79856314d877b86d flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1 gen-server: diff --git a/server/poetry.lock b/server/poetry.lock index 89ad31e9..e04ae75c 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -517,6 +517,14 @@ tensorflow = ["tensorflow"] testing = ["h5py", "huggingface-hub", "numpy", "pytest", "pytest-benchmark", "setuptools-rust"] torch = ["torch"] +[[package]] +name = "sentencepiece" +version = "0.1.97" +description = "SentencePiece python wrapper" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "setuptools" version = "67.4.0" @@ -630,7 +638,7 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "521dc9f3c283dc56f7d2e2f96759919ff27ab49ffd3ae7cd26317b209e7fa98d" +content-hash = "1963b706a875d5c9090fee0db7dc38be6770c6b037dc225b62c1d06537a5a69a" [metadata.files] accelerate = [ @@ -1116,6 +1124,45 @@ safetensors = [ {file = "safetensors-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:ba3dc236a2344b7feadc9868307f42ba5e4804c9d68a80a35aac831349b31f6f"}, {file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"}, ] +sentencepiece = [ + {file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6f249c8f1852893be86eae66b19d522c5fb30bbad4fe2d1b07f06fdc86e1907e"}, + {file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09e1bc53178de70c557a9ba4fece07364b4416ce3d36570726b3372b68aea135"}, + {file = "sentencepiece-0.1.97-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:667193c57fb48b238be7e3d7636cfc8da56cb5bac5559d8f0b647334e1175be8"}, + {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2780531985af79c6163f63d4f200fec8a28b70b6768d2c19f70d01568a4524e8"}, + {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:205050670c53ef9015e2a98cce3934bfbcf0aafaa14caa0c618dd5667bc217ee"}, + {file = 
"sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28b183dadef8e8b6b4645c1c20692d7be0a13ecc3ec1a07b3885c8905516675f"}, + {file = "sentencepiece-0.1.97-cp310-cp310-win32.whl", hash = "sha256:ee3c9dbd558d8d85bb1617087b86df6ea2b856a528669630ce6cedeb4353b823"}, + {file = "sentencepiece-0.1.97-cp310-cp310-win_amd64.whl", hash = "sha256:f7dc55379e2f7dee86537180283db2e5f8418c6825fdd2fe436c724eb5604c05"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ba1b4154f9144c5a7528b00aff5cffaa1a896a1c6ca53ca78b6e74cd2dae5244"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac3d90aee5581e55d029d124ac11b6ae2fbae0817863b664b2f2302e966ababb"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c27400f1ac46518a01c87cb7703650e4e48728649feb115d2e3f1102a946a42"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6e12a166eba75994ca749aadc4a5056b91b31405f805d6de6e8914cc9741c60"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-win32.whl", hash = "sha256:ed85dff5c0a9b3dd1a414c7e1119f2a19e863fc3f81da525bf7f885ebc883de0"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-win_amd64.whl", hash = "sha256:91a19ab6f40ffbae6d6127119953d2c6a85e93d734953dbc8629fde0d21ace66"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bae580e4a35a9314ff49561ac7c06574fe6afc71b821ed6bb00534e571458156"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad7262e7530c683b186672b5dd0082f82719a50a500a8cfbc4bbd7cde5bff8c"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:620cee35279720016735a7c7103cddbd9b84fe5e2f098bd5e673834d69fee2b8"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93b921b59914c0ec6697e8c6d5e6b44d99d1298fb1a0af56980a79ade0540c19"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-win32.whl", hash = "sha256:9b9a4c44a31d5f47616e9568dcf31e029b0bfa776e0a252c0b59247881598b09"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-win_amd64.whl", hash = "sha256:f31533cdacced56219e239d3459a003ece35116920dd64b2309d4ad047b77644"}, + {file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:7d643c01d1cad13b9206a276bbe5bc1a468e3d7cf6a26bde7783f945277f859d"}, + {file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:542f1985b1ee279a92bef7740ec0781452372028ce01e15aa88df3228b197ba3"}, + {file = "sentencepiece-0.1.97-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93701da21fea906dd244bf88cdbe640385a89c45d3c1812b76dbadf8782cdbcd"}, + {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51514047b964047b7fadb480d88a5e0f72c02f6ca1ba96258fbbc6e79274a94"}, + {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3ae2e9b7a5b6f2aa64ec9240b0c185dabe597d0e787dc4344acfbaef1ffe0b2"}, + {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:923ee4af16dbae1f2ab358ed09f8a0eb89e40a8198a8b343bf54181482342721"}, + {file = "sentencepiece-0.1.97-cp38-cp38-win32.whl", hash = "sha256:fa6f2b88850b5fae3a05053658824cf9f147c8e3c3b40eb64539a976c83d8a24"}, + {file = "sentencepiece-0.1.97-cp38-cp38-win_amd64.whl", hash = 
"sha256:5137ff0d0b1cc574751d178650ef800ff8d90bf21eb9f71e9567d4a0548940a5"}, + {file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f92876271a10494671431ad955bff2d6f8ea59baaf957f5ae5946aff56dfcb90"}, + {file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:35c227b6d55e473033db7e0ecc51b1e99e6ed7607cc08602fb5768132543c81d"}, + {file = "sentencepiece-0.1.97-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1706a8a8188f7b3d4b7922db9bb00c64c4e16ee68ab4caaae79f55b3e18748c7"}, + {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce61efc1862ccb18856c4aabbd930e13d5bfbb4b09b4f111081ac53a9dc62275"}, + {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a78c03800ef9f02d320e0159f5768b15357f3e9ebea545c9c4ba7928ba8ba254"}, + {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753b8088fd685ee787d9f54c84275ab347de558c7c4ebc6accb4c35bf7776f20"}, + {file = "sentencepiece-0.1.97-cp39-cp39-win32.whl", hash = "sha256:24306fd86031c17a1a6ae92671e76a350390a3140a65620bc2843dad7db24e2a"}, + {file = "sentencepiece-0.1.97-cp39-cp39-win_amd64.whl", hash = "sha256:c6641d0b7acec61fde5881ea6ebe098c169557ac9aa3bdabdf124eab5a5592bb"}, + {file = "sentencepiece-0.1.97.tar.gz", hash = "sha256:c901305e0a710bbcd296f66d79e96f744e6e175b29812bd5178318437d4e1f6c"}, +] setuptools = [ {file = "setuptools-67.4.0-py3-none-any.whl", hash = "sha256:f106dee1b506dee5102cc3f3e9e68137bbad6d47b616be7991714b0c62204251"}, {file = "setuptools-67.4.0.tar.gz", hash = "sha256:e5fd0a713141a4a105412233c63dc4e17ba0090c8e8334594ac790ec97792330"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index e9dc624c..14459be0 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -23,6 +23,7 @@ opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" hf-transfer = "^0.1.2" +sentencepiece = "^0.1.97" [tool.poetry.extras] bnb = ["bitsandbytes"] diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index bcaf6ec1..577f94b8 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -19,6 +19,7 @@ from text_generation_server.models.t5 import T5Sharded try: from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded from text_generation_server.models.flash_santacoder import FlashSantacoder + from text_generation_server.models.flash_llama import FlashLlama FLASH_ATTENTION = ( torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1 @@ -92,6 +93,13 @@ def get_model( neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM return neox_cls(model_id, revision, quantize=quantize) + if model_type == "llama": + if sharded: + raise NotImplementedError + else: + llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM + return llama_cls(model_id, revision, quantize=quantize) + if model_type == "t5": if sharded: return T5Sharded(model_id, revision, quantize=quantize) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py new file mode 100644 index 00000000..38d9fa24 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -0,0 +1,605 @@ +import torch +import torch.distributed + +from 
torch.nn import functional as F + +from torch import nn +from transformers.activations import ACT2FN + +# Flash attention imports +import rotary_emb +import flash_attn_cuda +import dropout_layer_norm + +from flash_attn.layers.rotary import RotaryEmbedding + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + + def transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) + + +class TensorParallelColumnLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert out_features % self.tp_world_size == 0 + out_features = out_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + +class TensorParallelRowLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + reduce=True, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + self.reduce = reduce + assert in_features % self.tp_world_size == 0 + in_features = in_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super(TensorParallelRowLinear, self).forward(input) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + # Additional entry that will map to zero + # Used for masking + self.null_idx = block_size + + super().__init__( + block_size, + embedding_dim, + 
padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def add_null_idx(self): + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = super().forward(input) + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class PositionRotaryEmbedding(RotaryEmbedding): + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + q1 = qkv[:, 0, :, :rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] + k1 = qkv[:, 1, :, :rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return qkv + + +class FlashLlamaAttention(torch.nn.Module): + def __init__( + self, + num_heads, + hidden_size, + process_group=None, + ): + super().__init__() + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.softmax_scale = self.head_size ** (-0.5) + + if process_group is None: + self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False) + self.o_proj = FastLinear(hidden_size, hidden_size, bias=False) + else: + self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear( + hidden_size, + 3 * hidden_size, + bias=False, + process_group=process_group, + ) + self.o_proj = TensorParallelRowLinear( + hidden_size, + hidden_size, + bias=False, + process_group=process_group, + ) + + def shuffle_qkv_dims(self): + """Swap dims to avoid an additional permute""" + self.query_key_value.weight = torch.nn.Parameter( + self.query_key_value.weight.view( + self.num_heads, 3, self.head_size, self.hidden_size + ) + .permute(1, 0, 2, 3) + .reshape(-1, self.hidden_size) + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + qkv = 
self.query_key_value(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + qkv_rot = self.rotary_emb(qkv, cos, sin) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(qkv[:, 0]) + # flash attention + flash_attn_cuda.fwd( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + query = qkv_rot[:, 0] + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + layer_past[:, 0], + layer_past[:, 1], + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class LlamaMLP(nn.Module): + def __init__(self, act, hidden_size, intermediate_size, process_group=None): + super().__init__() + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + self.intermediate_size = intermediate_size + + if process_group is None: + # Fuse gate and up proj + self.gate_up_proj = FastLinear( + hidden_size, 2 * intermediate_size, bias=False + ) + self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False) + else: + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + process_group=process_group, + ) + self.down_proj = TensorParallelRowLinear( + intermediate_size, + hidden_size, + bias=False, + process_group=process_group, + reduce=True, + ) + self.process_group = process_group + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashLlamaLayer(nn.Module): + def __init__( + self, + num_heads, + act, + hidden_size, + intermediate_size, + rms_norm_eps, + process_group=None, + ): + super().__init__() + + self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group) + self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group) + + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + # faster input rms norm + hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.input_layernorm.weight, + None, + None, + None, + None, + None, + 0.0, + self.input_layernorm.variance_epsilon, + 1.0, + 0, + None, + False, + True, + ) + + hidden_states = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + # faster post attention rms norm + hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.post_attention_layernorm.weight, + None, + None, + None, + None, + None, + 0.0, + self.post_attention_layernorm.variance_epsilon, + 1.0, + 0, + None, + False, + True, + ) + + 
mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashLlamaModel(torch.nn.Module): + def __init__(self, config, process_group=None): + super(FlashLlamaModel, self).__init__() + self.config = config + + self.tp_embeddings = False + if process_group is not None: + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + if config.vocab_size % self.tp_world_size == 0: + self.tp_embeddings = True + + if self.tp_embeddings: + self.embed_tokens = TensorParallelEmbedding( + config.vocab_size, config.hidden_size, process_group=process_group + ) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layers = nn.ModuleList( + [ + FlashLlamaLayer( + config.num_attention_heads, + config.hidden_act, + config.hidden_size, + config.intermediate_size, + config.rms_norm_eps, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def post_load_weights(self): + if isinstance(self.embed_tokens, TensorParallelEmbedding): + self.embed_tokens.add_null_idx() + for layer in self.layers: + layer: FlashLlamaLayer + layer.self_attn.shuffle_qkv_dims() + layer.self_attn.query_key_value.transpose_weight() + layer.self_attn.o_proj.transpose_weight() + layer.mlp.gate_up_proj.transpose_weight() + layer.mlp.down_proj.transpose_weight() + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states = self.embed_tokens(input_ids) + + # Prefill + if past_key_values is None: + # Create past tensor + past_key_values = hidden_states.new_empty( + ( + len(self.layers), + len(hidden_states), + 2, + self.num_heads, + self.head_size, + ) + ) + layer_past_present_indices = None + cu_seqlens_q = None + # Decode + else: + # Create indices from cumulative sequence lengths + layer_past_present_indices = cu_seqlens[1:] - 1 + cu_seqlens_q = torch.arange( + cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device + ) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + past_key_values[i], + layer_past_present_indices, + cu_seqlens_q, + ) + + # Faster final layer norm + hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.norm.weight, + None, + None, + None, + None, + None, + 0.0, + self.norm.variance_epsilon, + 1.0, + 0, + None, + False, + True, + ) + + return hidden_states, past_key_values + + +class FlashLlamaForCausalLM(torch.nn.Module): + def __init__(self, config, process_group=None): + super().__init__() + + self.model = FlashLlamaModel(config, process_group) + + if self.model.tp_embeddings: + self.lm_head = FastLinear( + config.hidden_size, + config.vocab_size // process_group.size(), + bias=False, + ) + else: + self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) + + def post_load_weights(self): + self.model.post_load_weights() + self.lm_head.transpose_weight() + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + 
hidden_states, present = self.model( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) + return self.lm_head(hidden_states), present diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py new file mode 100644 index 00000000..0403c9f6 --- /dev/null +++ b/server/text_generation_server/models/flash_llama.py @@ -0,0 +1,290 @@ +import torch +import torch.distributed + +from accelerate import init_empty_weights +from opentelemetry import trace +from pathlib import Path +from safetensors import safe_open +from transformers import AutoTokenizer, AutoConfig +from typing import Optional, Tuple, List + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + download_weights, + weight_hub_files, + LocalEntryNotFoundError, +) + +tracer = trace.get_tracer(__name__) + + +class FlashLlama(FlashCausalLM): + def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashCausalLM is only available on GPU") + + if quantize: + raise NotImplementedError("FlashCausalLM does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + + config = AutoConfig.from_pretrained( + model_id, revision=revision, tp_parallel=True + ) + + try: + filenames = weight_files(model_id, revision, ".bin") + # Local files not found + except LocalEntryNotFoundError: + hub_files = weight_hub_files(model_id, revision, ".bin") + filenames = download_weights(hub_files, model_id, revision) + + with init_empty_weights(): + model = FlashLlamaForCausalLM(config) + + self.load_weights( + model, + filenames, + ) + self.model = model.eval().to(device).to(dtype) + + super(FlashCausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[Path], + ): + final_state_dict = {} + for filename in filenames: + state_dict = torch.load(filename, map_location="cpu") + for key, value in state_dict.items(): + layer_name = ".".join(key.split(".")[:4]) + if "q_proj" in key: + final_key = layer_name + ".query_key_value.weight" + if final_key not in final_state_dict: + final_state_dict[final_key] = value.new_empty( + (value.shape[0] * 3, value.shape[1]) + ) + final_state_dict[final_key][: value.shape[0]] = value + elif "k_proj" in key: + final_key = layer_name + ".query_key_value.weight" + if final_key not in final_state_dict: + final_state_dict[final_key] = value.new_empty( + (value.shape[0] * 3, value.shape[1]) + ) + final_state_dict[final_key][ + value.shape[0] : value.shape[0] * 2 + ] = value + elif "v_proj" in key: + final_key = layer_name + ".query_key_value.weight" + if final_key not in final_state_dict: + final_state_dict[final_key] = value.new_empty( + (value.shape[0] * 3, value.shape[1]) + ) + final_state_dict[final_key][value.shape[0] * 2 :] = value + elif "gate_proj" in key: + final_key = layer_name + ".gate_up_proj.weight" + if final_key not in final_state_dict: + final_state_dict[final_key] = value.new_empty( + (value.shape[0] * 2, 
value.shape[1]) + ) + final_state_dict[final_key][: value.shape[0]] = value + elif "up_proj" in key: + final_key = layer_name + ".gate_up_proj.weight" + if final_key not in final_state_dict: + final_state_dict[final_key] = value.new_empty( + (value.shape[0] * 2, value.shape[1]) + ) + final_state_dict[final_key][value.shape[0] :] = value + else: + final_state_dict[key] = value + del state_dict + + parameters = dict(model.named_parameters()) + for key, value in final_state_dict.items(): + current_parameter_tensor = parameters.get(key, None) + module_name, param_name = key.rsplit(".", 1) + module = model.get_submodule(module_name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != value.shape + ): + raise ValueError( + f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}" + ) + + value = value.contiguous() + + if current_parameter_tensor is not None: + module._parameters[param_name] = value + else: + module._buffers[param_name] = value + + model.post_load_weights() + + +class FlashLlamaSharded(FlashLlama): + def __init__( + self, model_id: str, revision: Optional[str] = None, quantize: bool = False + ): + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + device = torch.device(f"cuda:{self.rank}") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + if quantize: + raise NotImplementedError("FlashLlama does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + + config = AutoConfig.from_pretrained( + model_id, revision=revision, tp_parallel=True + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = FlashLlamaForCausalLM(config) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + rank=self.rank, + world_size=self.world_size, + ) + model.post_load_weights() + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + super(FlashCausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in f.keys(): + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + + current_parameter_tensor = parameters.get(name, None) + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once.
+ if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "lm_head.weight" and model.model.tp_embeddings: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != tensor.shape + ): + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + + if current_parameter_tensor is not None: + module._parameters[param_name] = tensor + else: + module._buffers[param_name] = tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: int, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.model.model.tp_embeddings: + logits, present = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_s=max_s, + past_key_values=past_key_values, + ) + + # Logits are sharded, so we need to gather them + world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] + torch.distributed.all_gather(world_logits, logits, group=self.process_group) + world_logits = torch.cat(world_logits, dim=1) + + return world_logits, present + # While the model itself is sharded, the embeddings might not be, + # as the vocabulary size might not be divisible by the number of shards + else: + return super(FlashLlamaSharded, self).forward( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + )
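
For reference, FlashLlama.load_weights above fuses the separate q_proj/k_proj/v_proj checkpoint matrices row-wise into a single query_key_value tensor (and gate_proj/up_proj into gate_up_proj) before handing them to the model. Below is a minimal standalone sketch of that stacking convention; the helper name fuse_qkv and the toy shapes are illustrative only, not part of the patch.

import torch

def fuse_qkv(q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor) -> torch.Tensor:
    # Illustrative helper (not in the patch): stack q/k/v row-wise,
    # mirroring how load_weights fills the fused query_key_value weight.
    fused = q_proj.new_empty((3 * q_proj.shape[0], q_proj.shape[1]))
    fused[: q_proj.shape[0]] = q_proj
    fused[q_proj.shape[0] : q_proj.shape[0] * 2] = k_proj
    fused[q_proj.shape[0] * 2 :] = v_proj
    return fused

# Toy shapes for illustration only (hidden_size = 8, not a real checkpoint).
hidden_size = 8
q = torch.randn(hidden_size, hidden_size)
k = torch.randn(hidden_size, hidden_size)
v = torch.randn(hidden_size, hidden_size)
print(fuse_qkv(q, k, v).shape)  # torch.Size([24, 8])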