https://github.com/huggingface/text-generation-inference
fix flash models

parent: a0abfa278e
commit: 391b80c0f4
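Summary of the change, as read from the hunks below: two apparently unused imports (torch.nn.functional as F and flash_attn's RotaryEmbedding) are dropped from a flash model file, the same F import is dropped from the layers module, and TensorParallelEmbedding gains a reduce flag (default True) that makes the cross-shard all_reduce after the embedding lookup optional.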
In one of the flash model files (the file path is not preserved in this view):

@@ -21,8 +21,6 @@
 import torch
 import torch.distributed
 
-from torch.nn import functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
@@ -32,7 +30,6 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda
 
-from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
     FastLinear,
     TensorParallelRowLinear,
In the layers module (presumably server/text_generation_server/utils/layers.py, given the import path above and the TensorParallelEmbedding hunks below):

@@ -1,8 +1,6 @@
 import torch
 import torch.distributed
 
-import torch.nn.functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional
@@ -128,6 +128,7 @@ class TensorParallelEmbedding(nn.Embedding):
         num_embeddings,
         embedding_dim,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         padding_idx=None,
         max_norm=None,
         norm_type=2.0,
@@ -137,6 +138,7 @@ class TensorParallelEmbedding(nn.Embedding):
         device=None,
         dtype=None,
     ):
+        self.reduce = reduce
         self.process_group = process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -179,7 +181,8 @@ class TensorParallelEmbedding(nn.Embedding):
             input - self.min_id,
         )
         out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
 
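A minimal sketch of what the new reduce flag enables, assuming the TensorParallelEmbedding semantics visible in the hunks above: each rank holds a contiguous slice of the vocabulary, out-of-shard ids are masked to zero rows, and the all_reduce sums the partial lookups across ranks. ShardedEmbedding, the masking details, and the single-process gloo group below are illustrative stand-ins, not the repository's actual code.

# sketch_reduce_flag.py -- illustrative only; names and masking details
# are assumptions, not the repository's implementation.
import os

import torch
import torch.distributed as dist
from torch import nn


class ShardedEmbedding(nn.Embedding):
    # Hypothetical stand-in for TensorParallelEmbedding: each rank owns
    # a contiguous block of the vocabulary and stores only that block.
    def __init__(self, num_embeddings, embedding_dim, process_group, reduce=True):
        self.process_group = process_group
        self.reduce = reduce
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        block_size = num_embeddings // self.tp_world_size
        self.min_id = self.tp_rank * block_size
        self.max_id = (self.tp_rank + 1) * block_size
        super().__init__(block_size, embedding_dim)

    def forward(self, input):
        # Ids outside this rank's block look up row 0 and are then
        # zeroed, so each rank's output is a partial sum of the full
        # lookup.
        oob = (input < self.min_id) | (input >= self.max_id)
        out = super().forward((input - self.min_id).masked_fill(oob, 0))
        out[oob] = 0.0
        if self.reduce:
            # Default behaviour: sum the partial lookups across ranks.
            # With reduce=False the caller combines the shards later.
            dist.all_reduce(out, group=self.process_group)
        return out


if __name__ == "__main__":
    # Single-process "world" so the sketch runs without torchrun.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    emb = ShardedEmbedding(8, 4, dist.group.WORLD, reduce=False)
    ids = torch.tensor([[0, 3, 7]])
    print(emb(ids).shape)  # torch.Size([1, 3, 4]); no all_reduce issued
    dist.destroy_process_group()

The flag plausibly exists so a caller can fuse or defer the reduction rather than paying for an all_reduce on every lookup; the diff itself only shows the mechanism, not the call site that uses reduce=False.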