mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Fixing layer imports (for isinstance compat).
This commit is contained in:
parent
edc9ce9beb
commit
42d8efcb04
@ -18,11 +18,20 @@ from text_generation_server.utils import (
|
||||
)
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
)
|
||||
from transformers.models.t5.parallel_layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except ImportError as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class T5Sharded(Seq2SeqLM):
|
||||
def __init__(
|
||||
|
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
import dropout_layer_norm
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
@ -182,7 +181,10 @@ class TensorParallelEmbedding(nn.Embedding):
|
||||
return out
|
||||
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
try:
|
||||
import dropout_layer_norm
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
@ -217,6 +219,9 @@ class FastLayerNorm(nn.LayerNorm):
|
||||
|
||||
return normed_hidden_states, residual
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
|
Loading…
Reference in New Issue
Block a user