fix flash models

OlivierDehaene 2023-05-15 18:12:50 +02:00
parent a0abfa278e
commit 391b80c0f4
3 changed files with 4 additions and 6 deletions

View File

@@ -21,8 +21,6 @@
 import torch
 import torch.distributed
-from torch.nn import functional as F
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
@@ -32,7 +30,6 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda
-from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
     FastLinear,
     TensorParallelRowLinear,

View File

@@ -1,8 +1,6 @@
 import torch
 import torch.distributed
-import torch.nn.functional as F
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional

View File

@@ -128,6 +128,7 @@ class TensorParallelEmbedding(nn.Embedding):
         num_embeddings,
         embedding_dim,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         padding_idx=None,
         max_norm=None,
         norm_type=2.0,
@@ -137,6 +138,7 @@ class TensorParallelEmbedding(nn.Embedding):
         device=None,
         dtype=None,
     ):
+        self.reduce = reduce
         self.process_group = process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -179,6 +181,7 @@ class TensorParallelEmbedding(nn.Embedding):
             input - self.min_id,
         )
         out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
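
For context, the behavioural change in this last file is that the all_reduce across tensor-parallel shards becomes optional. Below is a minimal sketch, not the repository's code: vocabulary sharding, id remapping and pad masking are elided, the constructor signature is simplified to **kwargs, and the single-process demo setup is purely illustrative. It only shows how the new reduce flag gates the reduction.

import os

import torch
import torch.distributed as dist
from torch import nn


class TensorParallelEmbedding(nn.Embedding):
    """Simplified sketch: vocabulary sharding and pad-id masking are elided."""

    def __init__(self, num_embeddings, embedding_dim, process_group, reduce=True, **kwargs):
        # Plain attributes may be set before nn.Module.__init__ because they
        # are neither Parameters, Modules, nor buffers.
        self.reduce = reduce
        self.process_group = process_group
        super().__init__(num_embeddings, embedding_dim, **kwargs)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        # The cross-rank sum is now optional: a caller that wants the raw
        # per-shard output, or performs its own reduction later, constructs
        # the embedding with reduce=False.
        if self.reduce:
            dist.all_reduce(out, group=self.process_group)
        return out


if __name__ == "__main__":
    # Illustrative single-process "world" so the sketch runs without a launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    emb = TensorParallelEmbedding(16, 4, dist.group.WORLD)
    print(emb(torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 4])

Since the flag defaults to True, existing callers keep the old always-reduce behaviour; reduce=False is an opt-out for a caller that wants the unreduced per-shard output.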