mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Merge branch 'main' into feat/falcon-11b
This commit is contained in:
commit
e45ede2cde
@ -24,16 +24,19 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
SpeculativeHead,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
TensorParallelRowLinear,
|
SpeculativeHead,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.utils import flash_attn, paged_attention
|
from text_generation_server.layers.layernorm import (
|
||||||
|
FastRMSNorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights):
|
||||||
|
@ -6,14 +6,13 @@ from torch import nn
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from text_generation_server.utils import flash_attn, paged_attention
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.flash_attn import attention
|
||||||
FastLayerNorm,
|
from text_generation_server.layers import (
|
||||||
PositionRotaryEmbedding,
|
TensorParallelRowLinear,
|
||||||
SpeculativeHead,
|
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
TensorParallelRowLinear,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
|
Loading…
Reference in New Issue
Block a user