Fixing layer imports (for isinstance compat).

commit 42d8efcb04
parent edc9ce9beb
Author: Nicolas Patry
Date:   2023-05-15 16:46:32 +02:00
2 changed files with 46 additions and 32 deletions
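Note on the motivation: "isinstance compat" refers to the fact that a class imported from two different module paths is only the same object if one module re-exports the other's definition. When transformers.models.t5.parallel_layers and text_generation_server.utils.layers each carry their own copy of TensorParallelRowLinear, an isinstance check against one copy fails for instances of the other. A minimal, self-contained sketch of that failure mode (the module layout below is a stand-in, not the repo's):

    import types

    def make_module(name):
        # Each call creates a fresh module object holding its own,
        # same-named class definition.
        mod = types.ModuleType(name)

        class TensorParallelRowLinear:
            pass

        mod.TensorParallelRowLinear = TensorParallelRowLinear
        return mod

    utils_layers = make_module("utils_layers")        # stand-in for utils.layers
    parallel_layers = make_module("parallel_layers")  # stand-in for the transformers copy

    layer = parallel_layers.TensorParallelRowLinear()
    print(isinstance(layer, parallel_layers.TensorParallelRowLinear))  # True
    print(isinstance(layer, utils_layers.TensorParallelRowLinear))     # False

Importing every parallel layer from text_generation_server.utils.layers, as this commit does, makes all call sites share one class object.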

--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py

@@ -18,11 +18,20 @@ from text_generation_server.utils import (
 )
 from text_generation_server.utils.layers import (
     FastLinear,
-)
-from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
 )
+
+HAS_BITS_AND_BYTES = True
+try:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Int8Params
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class T5Sharded(Seq2SeqLM):
     def __init__(
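The added HAS_BITS_AND_BYTES block mirrors the optional-dependency pattern already used in layers.py: attempt the import once at module load, record the result in a flag, and let later code branch on the flag instead of re-importing. A sketch of how such a flag is typically consumed; the make_linear helper and its quantize parameter are illustrative, not part of this commit:

    from torch import nn

    HAS_BITS_AND_BYTES = True
    try:
        import bitsandbytes as bnb
        from bitsandbytes.nn import Int8Params
    except ImportError:
        HAS_BITS_AND_BYTES = False

    def make_linear(in_features, out_features, quantize=False):
        if quantize:
            if not HAS_BITS_AND_BYTES:
                # Fail loudly at construction time rather than at first forward.
                raise ImportError("bitsandbytes is not installed")
            return bnb.nn.Linear8bitLt(
                in_features, out_features, has_fp16_weights=False
            )
        return nn.Linear(in_features, out_features)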

--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py

@@ -1,7 +1,6 @@
 import torch
 from torch import nn
-import dropout_layer_norm
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -182,40 +181,46 @@ class TensorParallelEmbedding(nn.Embedding):
         return out


-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
-                residual = hidden_states
-
-            return normed_hidden_states, residual
+try:
+    import dropout_layer_norm
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if hidden_states.shape[-1] > 8192:
+                if residual is not None:
+                    hidden_states += residual
+                residual = hidden_states
+
+                return super(FastLayerNorm, self).forward(hidden_states), residual
+            else:
+                (
+                    normed_hidden_states,
+                    residual,
+                    *rest,
+                ) = dropout_layer_norm.dropout_add_ln_fwd(
+                    hidden_states,
+                    residual,
+                    self.weight,
+                    self.bias,
+                    None,
+                    None,
+                    None,
+                    None,
+                    0.0,
+                    self.eps,
+                    1.0,
+                    0,
+                    None,
+                    False,
+                    False,
+                )
+                if residual is None:
+                    residual = hidden_states
+
+                return normed_hidden_states, residual
+
+except ImportError:
+    pass


 try:
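Wrapping the FastLayerNorm definition in try/except makes the dropout_layer_norm CUDA extension optional: on installs without the kernel the import fails, the except branch passes, and the class is simply never defined, so layers.py can still be imported on CPU-only machines. A minimal, self-contained sketch of this guarded-definition pattern, using a hypothetical extension name:

    from torch import nn

    try:
        import fused_kernels  # hypothetical optional CUDA extension

        class FastNorm(nn.LayerNorm):
            def forward(self, hidden_states, residual=None):
                # Hypothetical fused call, shown only to mirror the shape
                # of the real dropout_add_ln_fwd usage above.
                return fused_kernels.fwd(hidden_states, residual)

    except ImportError:
        pass  # FastNorm is simply not defined without the kernel

    # Callers must probe for the name rather than assume it exists:
    HAS_FAST_NORM = "FastNorm" in globals()
    print(HAS_FAST_NORM)

The trade-off of except-pass over defining a pure-PyTorch fallback class is that any isinstance check or constructor call referencing FastLayerNorm elsewhere must itself be guarded.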