diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 1a4ad551..2c6b8da6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -21,8 +21,6 @@
 import torch
 import torch.distributed
 
-from torch.nn import functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
@@ -32,7 +30,6 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda
 
-from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
     FastLinear,
     TensorParallelRowLinear,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 7a301c1f..9bded805 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -1,8 +1,6 @@
 import torch
 import torch.distributed
 
-import torch.nn.functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 3386bc7d..7605639d 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -128,6 +128,7 @@ class TensorParallelEmbedding(nn.Embedding):
         num_embeddings,
         embedding_dim,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         padding_idx=None,
         max_norm=None,
         norm_type=2.0,
@@ -137,6 +138,7 @@ class TensorParallelEmbedding(nn.Embedding):
         device=None,
         dtype=None,
     ):
+        self.reduce = reduce
         self.process_group = process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -179,7 +181,8 @@ class TensorParallelEmbedding(nn.Embedding):
             input - self.min_id,
         )
         out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
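
Note on the new `reduce` flag: `TensorParallelEmbedding` shards the vocabulary across ranks, so each rank only embeds the token ids it owns and the partial outputs are normally summed with an `all_reduce`. Passing `reduce=False` lets the caller skip that collective, for example when the partial output will be combined with other partial results and reduced once later. Below is a minimal standalone sketch of the same pattern, not the library code; the `ShardedEmbedding` name, the `rank`/`world_size` constructor arguments, and the `dist.is_initialized()` guard are illustrative assumptions.

```python
import torch
import torch.distributed as dist
from torch import nn


class ShardedEmbedding(nn.Embedding):
    """Illustrative stand-in for TensorParallelEmbedding: each rank owns a
    contiguous slice of the vocabulary; ids outside the slice map to an
    all-zero "null" row, so summing the per-rank outputs reconstructs the
    full embedding."""

    def __init__(self, num_embeddings, embedding_dim, rank, world_size, reduce=True):
        self.reduce = reduce
        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = self.min_id + block_size
        self.null_idx = block_size
        # One extra row serves as the null id; padding_idx keeps it at zeros.
        super().__init__(block_size + 1, embedding_dim, padding_idx=block_size)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Map out-of-shard ids to the null row, in-shard ids to local offsets.
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = super().forward(input)
        # reduce=True: every rank leaves with the full embedding.
        # reduce=False: the caller combines/reduces the partial outputs later.
        if self.reduce and dist.is_initialized():
            dist.all_reduce(out)
        return out


if __name__ == "__main__":
    # Single-process smoke test (no process group, so no all_reduce happens).
    emb = ShardedEmbedding(16, 4, rank=0, world_size=2, reduce=False)
    print(emb(torch.tensor([0, 3, 9])).shape)  # ids >= 8 hit the null row
```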