From 011887f15c8b06447df3449b1c1421c9d43bb102 Mon Sep 17 00:00:00 2001
From: Nilabhra
Date: Tue, 14 May 2024 11:00:45 +0400
Subject: [PATCH] chore: removed unused import.

---
 .../custom_modeling/flash_llama_modeling.py      | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index c52ae9ef..afcaac4b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -18,25 +18,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Optional, Tuple
+
 import torch
 import torch.distributed
-
 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List, Tuple
 
-from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
-    TensorParallelRowLinear,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
-    get_linear,
+    TensorParallelRowLinear,
 )
+from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
-from text_generation_server.layers.layernorm import (
-    FastRMSNorm,
-)
+from text_generation_server.utils import flash_attn, paged_attention
 
 
 def load_attention(config, prefix, weights):