Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 03:44:54 +00:00
Fixing layer imports (for isinstance compat).
This commit is contained in:
parent edc9ce9beb, commit 42d8efcb04
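The "isinstance compat" in the commit message refers to a Python pitfall: a class re-declared under a different module is a distinct type, so isinstance checks written against the transformers-defined parallel layer classes would not match locally re-declared copies. A minimal illustration with stand-in class names (not repository code):

# Two structurally identical classes defined in different places are still
# distinct types, so an isinstance check against one rejects the other.
class TensorParallelRowLinearLocal:  # stand-in for a copy re-declared locally
    pass


class TensorParallelRowLinearUpstream:  # stand-in for the transformers-defined class
    pass


layer = TensorParallelRowLinearLocal()
print(isinstance(layer, TensorParallelRowLinearUpstream))  # False

Importing TensorParallelRowLinear and friends from transformers.models.t5.parallel_layers, as the first hunk below does, presumably keeps both sides referring to the same class objects so such checks pass.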
@@ -18,11 +18,20 @@ from text_generation_server.utils import (
 )
 from text_generation_server.utils.layers import (
     FastLinear,
+)
+from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
 )
 
+HAS_BITS_AND_BYTES = True
+try:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Int8Params
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class T5Sharded(Seq2SeqLM):
     def __init__(
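The HAS_BITS_AND_BYTES flag added above follows the usual optional-dependency pattern: attempt the import, and fall back to a boolean that downstream code can consult before enabling int8 quantization. A hedged, self-contained sketch of how such a flag is typically consumed; check_quantize is hypothetical, not a function from this repository:

# Same guard as in the hunk above; the consumer below is illustrative only.
HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb  # noqa: F401
    from bitsandbytes.nn import Int8Params  # noqa: F401
except ImportError:
    HAS_BITS_AND_BYTES = False


def check_quantize(quantize: bool) -> None:
    # Fail early with a clear error when int8 quantization is requested
    # but the optional bitsandbytes dependency is not installed.
    if quantize and not HAS_BITS_AND_BYTES:
        raise ImportError("bitsandbytes must be installed to use int8 quantization")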
@@ -1,7 +1,6 @@
 import torch
 
 from torch import nn
-import dropout_layer_norm
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -182,40 +181,46 @@ class TensorParallelEmbedding(nn.Embedding):
         return out
 
 
-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
-                residual = hidden_states
-
-            return normed_hidden_states, residual
+try:
+    import dropout_layer_norm
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if hidden_states.shape[-1] > 8192:
+                if residual is not None:
+                    hidden_states += residual
+                residual = hidden_states
+
+                return super(FastLayerNorm, self).forward(hidden_states), residual
+            else:
+                (
+                    normed_hidden_states,
+                    residual,
+                    *rest,
+                ) = dropout_layer_norm.dropout_add_ln_fwd(
+                    hidden_states,
+                    residual,
+                    self.weight,
+                    self.bias,
+                    None,
+                    None,
+                    None,
+                    None,
+                    0.0,
+                    self.eps,
+                    1.0,
+                    0,
+                    None,
+                    False,
+                    False,
+                )
+                if residual is None:
+                    residual = hidden_states
+
+                return normed_hidden_states, residual
+
+except ImportError:
+    pass
 
 
 try:
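The relocated FastLayerNorm keeps two code paths: for hidden sizes above 8192 it adds the residual by hand and falls back to nn.LayerNorm.forward, otherwise it calls the fused dropout_add_ln_fwd kernel with dropout probability 0.0. In plain PyTorch the fused operation reduces to a residual add followed by layer norm; a hedged reference sketch, not repository code, with the kernel semantics inferred from the fallback branch:

import torch


def add_ln_reference(hidden_states, residual, weight, bias, eps):
    # Residual add followed by LayerNorm, mirroring the >8192 fallback branch;
    # the returned residual is the pre-normalization sum, as in the diff.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states
    normed = torch.nn.functional.layer_norm(
        hidden_states, hidden_states.shape[-1:], weight, bias, eps
    )
    return normed, residual

Wrapping the class definition in try/except ImportError means FastLayerNorm is simply not defined when the dropout_layer_norm extension is absent, instead of the whole module failing at import time.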