Merge branch 'main' into feat/falcon-11b

Commit e45ede2cde by Nilabhra, 2024-05-14 11:40:31 +04:00
2 changed files with 12 additions and 10 deletions

File 1:

@@ -24,16 +24,19 @@ import torch
 import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
 
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
-    SpeculativeHead,
+    TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelRowLinear,
+    SpeculativeHead,
 )
-from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
-from text_generation_server.layers.layernorm import (
+from text_generation_server.layers.layernorm import (
+    FastRMSNorm,
+)
 
 def load_attention(config, prefix, weights):
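Note: load_attention is the helper these imports feed into. For reference, a minimal sketch of what such a loader typically does in text-generation-inference model code, assuming llama-style q_proj/k_proj/v_proj weight names; the actual body in this file is not shown in the hunk:

# Sketch under stated assumptions, not this PR's exact code.
from text_generation_server.layers import TensorParallelColumnLinear

def load_attention(config, prefix, weights):
    # Fuse the query/key/value projections into one column-parallel linear,
    # so the attention input projection runs as a single sharded matmul
    # per tensor-parallel rank.
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
        weights=weights,
        bias=False,
    )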

File 2:

@@ -6,14 +6,13 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
-from text_generation_server.utils import flash_attn, paged_attention
-from text_generation_server.utils.layers import (
-    FastLayerNorm,
-    PositionRotaryEmbedding,
-    SpeculativeHead,
+from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.flash_attn import attention
+from text_generation_server.layers import (
+    TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelRowLinear,
+    SpeculativeHead,
     get_linear,
 )
 from text_generation_server.layers.layernorm import (
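Both files now source their norm layers from text_generation_server.layers.layernorm (file 2's import list is truncated by the hunk). For reference, a minimal usage sketch of the FastRMSNorm imported in file 1; the checkpoint prefix is hypothetical, the eps attribute name is assumed, and config/weights/hidden_states/residual are assumed to be in scope:

# Sketch only: FastRMSNorm.load reads the norm weight from the checkpoint;
# the forward takes (hidden_states, residual) and returns the normalized
# states together with the updated residual, fusing the residual add.
from text_generation_server.layers.layernorm import FastRMSNorm

norm = FastRMSNorm.load(
    prefix="transformer.h.0.input_layernorm",  # hypothetical checkpoint key
    weights=weights,
    eps=config.rms_norm_eps,  # eps attribute name assumed for this config
)
normed_states, residual = norm(hidden_states, residual)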