diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 69ef8c87..281f9503 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -18,11 +18,20 @@ from text_generation_server.utils import (
 )
 from text_generation_server.utils.layers import (
     FastLinear,
+)
+from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
 )
 
+HAS_BITS_AND_BYTES = True
+try:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Int8Params
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class T5Sharded(Seq2SeqLM):
     def __init__(
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 4c89e54e..3383bf4b 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,7 +1,6 @@
 import torch
 from torch import nn
 
-import dropout_layer_norm
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -182,40 +181,46 @@ class TensorParallelEmbedding(nn.Embedding):
         return out
 
 
-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
+try:
+    import dropout_layer_norm
 
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if hidden_states.shape[-1] > 8192:
+                if residual is not None:
+                    hidden_states += residual
                 residual = hidden_states
 
-            return normed_hidden_states, residual
+                return super(FastLayerNorm, self).forward(hidden_states), residual
+            else:
+                (
+                    normed_hidden_states,
+                    residual,
+                    *rest,
+                ) = dropout_layer_norm.dropout_add_ln_fwd(
+                    hidden_states,
+                    residual,
+                    self.weight,
+                    self.bias,
+                    None,
+                    None,
+                    None,
+                    None,
+                    0.0,
+                    self.eps,
+                    1.0,
+                    0,
+                    None,
+                    False,
+                    False,
+                )
+                if residual is None:
+                    residual = hidden_states
+
+                return normed_hidden_states, residual
+
+except ImportError:
+    pass
 
 
 try:
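
Not part of the diff: the change applies the same optional-import pattern twice, once for bitsandbytes in t5.py and once for the dropout_layer_norm CUDA extension in layers.py. Below is a minimal sketch of how downstream code might consume those availability flags; the require_bitsandbytes and build_layer_norm helpers and the quantize parameter are hypothetical illustrations, not code from this repository.

# Illustrative sketch only (not part of the diff). It reuses the flag and
# module names from the change; the helper functions are hypothetical.
from torch import nn

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb  # noqa: F401  # optional int8 quantization backend
except ImportError:
    HAS_BITS_AND_BYTES = False

HAS_FUSED_LAYER_NORM = True
try:
    import dropout_layer_norm  # noqa: F401  # optional fused CUDA kernel
except ImportError:
    HAS_FUSED_LAYER_NORM = False


def require_bitsandbytes(quantize: bool) -> None:
    # Fail only when quantization is actually requested, so the server can
    # still start on machines where bitsandbytes is not installed.
    if quantize and not HAS_BITS_AND_BYTES:
        raise ImportError("quantize=True requires bitsandbytes to be installed")


def build_layer_norm(hidden_size: int, eps: float = 1e-5) -> nn.Module:
    # Prefer the fused FastLayerNorm when the kernel extension is present,
    # otherwise fall back to the stock PyTorch LayerNorm.
    if HAS_FUSED_LAYER_NORM:
        from text_generation_server.utils.layers import FastLayerNorm

        return FastLayerNorm(hidden_size, eps=eps)
    return nn.LayerNorm(hidden_size, eps=eps)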