From 382bf59f4f3940763c2d6ceb955a8ee22d3e1b9a Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 24 Jul 2024 14:51:28 +0000 Subject: [PATCH] fix: lint and refactor import check and avoid model enum as global names --- server/text_generation_server/layers/bnb.py | 1 - server/text_generation_server/layers/fp8.py | 12 +++++++++--- server/text_generation_server/layers/rotary.py | 1 - server/text_generation_server/models/__init__.py | 7 +++---- .../models/custom_modeling/flash_llama_modeling.py | 1 - server/text_generation_server/utils/quantization.py | 1 - 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index aae2bd1a..791d9b6d 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from functools import lru_cache import bitsandbytes as bnb import torch diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 54550ee3..9c745647 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -12,17 +12,23 @@ from text_generation_server.utils.weights import ( Weights, ) from text_generation_server.utils.log import log_master, log_once +import importlib.util + FBGEMM_MM_AVAILABLE = False FBGEMM_DYN_AVAILABLE = False -try: - import fbgemm_gpu.experimental.gen_ai + +def is_fbgemm_gpu_available(): + return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None + + +if is_fbgemm_gpu_available(): if SYSTEM == "cuda": major, _ = torch.cuda.get_device_capability() FBGEMM_MM_AVAILABLE = major == 9 FBGEMM_DYN_AVAILABLE = major >= 8 -except (ImportError, ModuleNotFoundError): +else: log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 8a1e9261..8221068b 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -4,7 +4,6 @@ import torch from torch import nn # Inverse dim formula to find dim based on number of rotations -import math from text_generation_server.utils.import_utils import SYSTEM diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 6b222b76..6593229e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -298,10 +298,8 @@ class ModelType(enum.Enum): "multimodal": True, } - -__GLOBALS = locals() -for data in ModelType: - __GLOBALS[data.name] = data.value["type"] + def __str__(self): + return self.value["type"] def get_model( @@ -716,6 +714,7 @@ def get_model( or model_type == ModelType.BAICHUAN or model_type == ModelType.PHI3 ): + print(f">>> model_type: {model_type}") if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 437fbcab..b55ddc23 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import ( - UnquantizedWeight, Weights, ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index 98dae079..cfb8b1db 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -6,7 +6,6 @@ from typing import Optional from huggingface_hub import hf_hub_download from text_generation_server.utils.weights import ( DefaultWeightsLoader, - UnquantizedWeight, WeightsLoader, )