diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 5b32b22e..2af5834b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -33,12 +33,12 @@ import flash_attn_cuda
 import dropout_layer_norm
 
 from flash_attn.layers.rotary import RotaryEmbedding
-
-HAS_BITS_AND_BYTES = True
-try:
-    from bitsandbytes.nn import Linear8bitLt
-except ImportError as e:
-    HAS_BITS_AND_BYTES = False
+from text_generation_server.utils.layers import (
+    FastLinear,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+)
 
 
 class LlamaRMSNorm(nn.Module):
@@ -91,178 +91,6 @@ class LlamaRMSNorm(nn.Module):
 
         return normed_hidden_states, res
 
 
-class FastLinear(nn.Linear):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.quantized = False
-        self.bnb_linear = None
-
-    def prepare_weights(self, quantize: bool = False):
-        if quantize == "bitsandbytes":
-            if not HAS_BITS_AND_BYTES:
-                raise ImportError(
-                    "bitsandbytes is not available on your machine either because it is not installed "
-                    "or you don't have a GPU.\n"
-                    "You can install it with `pip install bitsandbytes`."
-                )
-
-            self.quantized = True
-            self.bnb_linear = Linear8bitLt(
-                self.in_features,
-                self.out_features,
-                has_fp16_weights=False,
-                threshold=6.0,
-                bias=False,
-            )
-            # Copy data to bnb_linear
-            self.bnb_linear.weight.data = self.weight.data
-            if self.bias is not None:
-                self.bnb_linear.bias = nn.Parameter(self.bias)
-
-            # Delete reference to data
-            self.weight = None
-            self.bias = None
-        elif quantize == "gptq":
-            raise NotImplementedError("`gptq` is not implemented for now")
-        elif quantize is None:
-            self.weight = nn.Parameter(self.weight.T)
-        else:
-            raise ValueError(f"Unexpected quantize `{quantize}`")
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.quantized:
-            return self.bnb_linear(input)
-        else:
-            if self.bias is not None:
-                return torch.addmm(self.bias, input, self.weight)
-            return torch.matmul(input, self.weight)
-
-
-class TensorParallelColumnLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        assert out_features % self.tp_world_size == 0
-        out_features = out_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-
-class TensorParallelRowLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        reduce=True,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        self.reduce = reduce
-        assert in_features % self.tp_world_size == 0
-        in_features = in_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        out = super(TensorParallelRowLinear, self).forward(input)
-        if self.reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-
-class TensorParallelEmbedding(nn.Embedding):
-    def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_rank = process_group.rank()
-        self.tp_world_size = process_group.size()
-
-        self.original_num_embeddings = num_embeddings
-
-        assert num_embeddings % self.tp_world_size == 0
-        block_size = num_embeddings // self.tp_world_size
-        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
-        self.min_id = self.tp_rank * block_size
-        self.max_id = (self.tp_rank + 1) * block_size
-
-        # Additional entry that will map to zero
-        # Used for masking
-        self.null_idx = block_size
-
-        super().__init__(
-            block_size,
-            embedding_dim,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse,
-            _weight=_weight,
-            device=device,
-            dtype=dtype,
-        )
-
-    def add_null_idx(self):
-        """Additional 0 entry used for masking"""
-        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
-        # translate for [0, self.max_id - self.min_id[
-        input = torch.where(
-            (self.min_id > input) | (input >= self.max_id),
-            self.null_idx,
-            input - self.min_id,
-        )
-        out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
-        return out
-
-
 class PositionRotaryEmbedding(RotaryEmbedding):
     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index e7b878c0..fdeb4084 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -35,260 +35,14 @@ import flash_attn_cuda
 import dropout_layer_norm
 
 from flash_attn.layers.rotary import RotaryEmbedding
-
-HAS_BITS_AND_BYTES = True
-try:
-    from bitsandbytes.nn import Linear8bitLt
-except ImportError as e:
-    HAS_BITS_AND_BYTES = False
-
-
-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
-                residual = hidden_states
-
-            return normed_hidden_states, residual
-
-
-class FastLinear(nn.Linear):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.quantized = False
-        self.bnb_linear = None
-
-    def prepare_weights(self, quantize: Optional[str] = None):
-        if quantize == "bitsandbytes":
-            if not HAS_BITS_AND_BYTES:
-                raise ImportError(
-                    "bitsandbytes is not available on your machine either because it is not installed "
-                    "or you don't have a GPU.\n"
-                    "You can install it with `pip install bitsandbytes`."
-                )
-
-            self.quantized = True
-            self.bnb_linear = Linear8bitLt(
-                self.in_features,
-                self.out_features,
-                has_fp16_weights=False,
-                threshold=6.0,
-                bias=False,
-            )
-            # Copy data to bnb_linear
-            self.bnb_linear.weight.data = self.weight.data
-            if self.bias is not None:
-                self.bnb_linear.bias = nn.Parameter(self.bias)
-
-            # Delete reference to data
-            self.weight = None
-            self.bias = None
-        elif quantize == "gptq":
-            raise NotImplementedError("`gptq` is not implemented for now")
-        elif quantize is None:
-            self.weight = nn.Parameter(self.weight.T)
-        else:
-            raise ValueError(f"Unexpected quantize `{quantize}`")
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.quantized:
-            return self.bnb_linear(input)
-        else:
-            if self.bias is not None:
-                return torch.addmm(self.bias, input, self.weight)
-            return torch.matmul(input, self.weight)
-
-
-class TensorParallelColumnLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        assert out_features % self.tp_world_size == 0
-        out_features = out_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-
-class TensorParallelRowLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        reduce=True,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        self.reduce = reduce
-        assert in_features % self.tp_world_size == 0
-        in_features = in_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        out = super(TensorParallelRowLinear, self).forward(input)
-        if self.reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-
-class TensorParallelEmbedding(nn.Embedding):
-    def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_rank = process_group.rank()
-        self.tp_world_size = process_group.size()
-
-        self.original_num_embeddings = num_embeddings
-
-        assert num_embeddings % self.tp_world_size == 0
-        block_size = num_embeddings // self.tp_world_size
-        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
-        self.min_id = self.tp_rank * block_size
-        self.max_id = (self.tp_rank + 1) * block_size
-
-        # Additional entry that will map to zero
-        # Used for masking
-        self.null_idx = block_size
-
-        super().__init__(
-            block_size,
-            embedding_dim,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse,
-            _weight=_weight,
-            device=device,
-            dtype=dtype,
-        )
-
-    def add_null_idx(self):
-        """Additional 0 entry used for masking"""
-        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
-        # translate for [0, self.max_id - self.min_id[
-        input = torch.where(
-            (self.min_id > input) | (input >= self.max_id),
-            self.null_idx,
-            input - self.min_id,
-        )
-        out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
-        return out
-
-
-class PositionRotaryEmbedding(RotaryEmbedding):
-    def _update_cos_sin_cache(self, dtype, device, seqlen):
-        # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-        ):
-            self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
-
-    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
-        """
-        Return cos and sin for the asked position ids
-        """
-
-        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-
-        cos = torch.index_select(self._cos_cached, 0, position_ids)
-        sin = torch.index_select(self._sin_cached, 0, position_ids)
-        return cos.unsqueeze(1), sin.unsqueeze(1)
-
-    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-        rotary_dim = cos.shape[-1]
-        q1 = qkv[:, 0, :, :rotary_dim]
-        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
-        k1 = qkv[:, 1, :, :rotary_dim]
-        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
-
-        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
-        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        return qkv
+from text_generation_server.utils.layers import (
+    FastLinear,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    FastLayerNorm,
+    PositionRotaryEmbedding,
+)
 
 
 class FlashNeoxAttention(torch.nn.Module):
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 23c3ea28..9451b01a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -9,224 +9,13 @@ from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
-import dropout_layer_norm
-
-HAS_BITS_AND_BYTES = True
-try:
-    from bitsandbytes.nn import Linear8bitLt
-except ImportError as e:
-    HAS_BITS_AND_BYTES = False
-
-
-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
-                residual = hidden_states
-
-            return normed_hidden_states, residual
-
-
-class FastLinear(nn.Linear):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.quantized = False
-        self.bnb_linear = None
-
-    def prepare_weights(self, quantize: Optional[str] = None):
-        if quantize == "bitsandbytes":
-            if not HAS_BITS_AND_BYTES:
-                raise ImportError(
-                    "bitsandbytes is not available on your machine either because it is not installed "
-                    "or you don't have a GPU.\n"
-                    "You can install it with `pip install bitsandbytes`."
-                )
-
-            self.quantized = True
-            self.bnb_linear = Linear8bitLt(
-                self.in_features,
-                self.out_features,
-                has_fp16_weights=False,
-                threshold=6.0,
-                bias=False,
-            )
-            # Copy data to bnb_linear
-            self.bnb_linear.weight.data = self.weight.data
-            if self.bias is not None:
-                self.bnb_linear.bias = nn.Parameter(self.bias)
-
-            # Delete reference to data
-            self.weight = None
-            self.bias = None
-        elif quantize == "gptq":
-            raise NotImplementedError("`gptq` is not implemented for now")
-        elif quantize is None:
-            self.weight = nn.Parameter(self.weight.T)
-        else:
-            raise ValueError(f"Unexpected quantize `{quantize}`")
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.quantized:
-            return self.bnb_linear(input)
-        else:
-            if self.bias is not None:
-                return torch.addmm(self.bias, input, self.weight)
-            return torch.matmul(input, self.weight)
-
-
-class TensorParallelColumnLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        assert out_features % self.tp_world_size == 0
-        out_features = out_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-
-class TensorParallelRowLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        reduce=True,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        self.reduce = reduce
-        assert in_features % self.tp_world_size == 0
-        in_features = in_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        out = super(TensorParallelRowLinear, self).forward(input)
-        if self.reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-
-class TensorParallelEmbedding(nn.Embedding):
-    def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        reduce=True,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_rank = process_group.rank()
-        self.tp_world_size = process_group.size()
-        self.reduce = reduce
-
-        self.original_num_embeddings = num_embeddings
-
-        assert num_embeddings % self.tp_world_size == 0
-        block_size = num_embeddings // self.tp_world_size
-        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
-        self.min_id = self.tp_rank * block_size
-        self.max_id = (self.tp_rank + 1) * block_size
-
-        # Additional entry that will map to zero
-        # Used for masking
-        self.null_idx = block_size
-
-        super().__init__(
-            block_size,
-            embedding_dim,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse,
-            _weight=_weight,
-            device=device,
-            dtype=dtype,
-        )
-
-    def add_null_idx(self):
-        """Additional 0 entry used for masking"""
-        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
-        # translate for [0, self.max_id - self.min_id[
-        input = torch.where(
-            (self.min_id > input) | (input >= self.max_id),
-            self.null_idx,
-            input - self.min_id,
-        )
-        out = super().forward(input)
-        if self.reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
-        return out
+from text_generation_server.utils.layers import (
+    FastLinear,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    FastLayerNorm,
+)
 
 
 class FlashMQAttention(torch.nn.Module):
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index b5e7710d..69ef8c87 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -10,24 +10,18 @@ from transformers import (
     AutoModelForSeq2SeqLM,
     AutoConfig,
 )
-from transformers.models.t5.parallel_layers import (
-    TensorParallelColumnLinear,
-    TensorParallelEmbedding,
-    TensorParallelRowLinear,
-)
 
 from text_generation_server.models import Seq2SeqLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
-
-HAS_BITS_AND_BYTES = True
-try:
-    import bitsandbytes as bnb
-    from bitsandbytes.nn import Int8Params
-except Exception as e:
-    HAS_BITS_AND_BYTES = False
+from text_generation_server.utils.layers import (
+    FastLinear,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+)
 
 
 class T5Sharded(Seq2SeqLM):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
new file mode 100644
index 00000000..cbaf6d00
--- /dev/null
+++ b/server/text_generation_server/utils/layers.py
@@ -0,0 +1,270 @@
+import torch
+import torch.distributed
+import dropout_layer_norm  # fused add + layer-norm CUDA kernel shipped with flash-attn, used by FastLayerNorm
+
+from torch import nn
+from torch.nn import functional as F
+from typing import Optional
+
+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
+
+class FastLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None
+
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
+            self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)
+
+
+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 8192:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+        else:
+            (
+                normed_hidden_states,
+                residual,
+                *rest,
+            ) = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.eps,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if residual is None:
+                residual = hidden_states
+
+            return normed_hidden_states, residual
+
+
+try:
+    from flash_attn.layers.rotary import RotaryEmbedding
+    import rotary_emb
+
+    class PositionRotaryEmbedding(RotaryEmbedding):
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
+        def get_cos_sin(
+            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
+        ):
+            """
+            Return cos and sin for the asked position ids
+            """
+
+            self._update_cos_sin_cache(dtype, position_ids.device, max_s)
+
+            cos = torch.index_select(self._cos_cached, 0, position_ids)
+            sin = torch.index_select(self._sin_cached, 0, position_ids)
+            return cos.unsqueeze(1), sin.unsqueeze(1)
+
+        def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+            rotary_dim = cos.shape[-1]
+            q1 = qkv[:, 0, :, :rotary_dim]
+            q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
+            k1 = qkv[:, 1, :, :rotary_dim]
+            k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
+
+            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+            return qkv
+
+except ImportError:
+    pass
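
Usage note (not part of the patch): the sketch below shows how model code might consume the consolidated layers module once this diff is applied. The import path comes from the new file above; the single-process gloo group, the layer sizes, and the variable names are illustrative assumptions rather than code from the repository, and real model files receive their process group from initialize_torch_distributed().

import torch
import torch.distributed as dist

from text_generation_server.utils.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
)

# Single-process "world" so the sketch runs without torchrun.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0
)
process_group = dist.group.WORLD
torch.set_grad_enabled(False)  # inference only, mirroring how the server uses these layers

hidden_size, intermediate_size, vocab_size = 512, 2048, 32000

# Column-parallel up projection and row-parallel down projection, as in an MLP block.
up = TensorParallelColumnLinear(
    hidden_size, intermediate_size, process_group=process_group, bias=False
)
down = TensorParallelRowLinear(
    intermediate_size, hidden_size, process_group=process_group, bias=False
)

# prepare_weights(None) transposes the weight so forward() can use matmul/addmm;
# passing "bitsandbytes" would swap in a Linear8bitLt module instead.
up.prepare_weights(quantize=None)
down.prepare_weights(quantize=None)

x = torch.randn(4, hidden_size)
y = down(up(x))  # the row-parallel output is all-reduced across the group
print(y.shape)  # torch.Size([4, 512])

# Vocabulary-sharded embedding: each rank owns a contiguous slice of the vocab.
emb = TensorParallelEmbedding(vocab_size, hidden_size, process_group=process_group)
emb.add_null_idx()  # appends the zero row used to mask out-of-shard token ids
print(emb(torch.tensor([[1, 5, 31999]])).shape)  # torch.Size([1, 3, 512])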
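
A second, purely local sketch (also not part of the patch) replays the index arithmetic of TensorParallelEmbedding.forward for a hypothetical world of 4 ranks, since the vocabulary sharding is the least obvious part of the moved code: ids outside this rank's [min_id, max_id) slice are redirected to the extra zero row appended by add_null_idx(), so the later all_reduce sums exactly one non-zero contribution per token. All numbers here are made up for illustration.

import torch
import torch.nn.functional as F

# Hypothetical numbers: a vocab of 8 split over 4 ranks, seen from rank 1.
vocab_size, tp_world_size, tp_rank = 8, 4, 1
block_size = vocab_size // tp_world_size                            # 2 rows stored on each rank
min_id, max_id = tp_rank * block_size, (tp_rank + 1) * block_size   # this rank owns ids [2, 4)
null_idx = block_size                                               # index of the extra zero row

token_ids = torch.tensor([0, 2, 3, 7])
local_ids = torch.where(
    (min_id > token_ids) | (token_ids >= max_id),  # ids not owned by this rank...
    null_idx,                                      # ...are sent to the zero row
    token_ids - min_id,                            # owned ids are shifted into [0, block_size)
)
print(local_ids)  # tensor([2, 0, 1, 2])

# A toy local weight with the appended zero row (what add_null_idx() does).
weight = F.pad(torch.ones(block_size, 3), (0, 0, 0, 1))
partial = F.embedding(local_ids, weight)
print(partial)  # rows of ones for ids 2 and 3, rows of zeros for ids 0 and 7 on this rank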