Merge branch 'main' into feat/falcon-11b

This commit is contained in:
Nilabhra 2024-05-14 11:40:31 +04:00
commit e45ede2cde
2 changed files with 12 additions and 10 deletions

View File

@@ -24,16 +24,19 @@ import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
SpeculativeHead,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils import flash_attn, paged_attention
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
def load_attention(config, prefix, weights):

View File

@@ -6,14 +6,13 @@ from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.utils import flash_attn, paged_attention
from text_generation_server.utils.layers import (
FastLayerNorm,
PositionRotaryEmbedding,
SpeculativeHead,
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.layernorm import (