fix: lint and refactor import check and avoid model enum as global names

Author: drbh 2024-07-24 14:51:28 +00:00
parent 655a9d7ef3
commit 382bf59f4f

6 changed files with 12 additions and 11 deletions

View File

@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from functools import lru_cache

 import bitsandbytes as bnb
 import torch

View File

@@ -12,17 +12,23 @@ from text_generation_server.utils.weights import (
     Weights,
 )
 from text_generation_server.utils.log import log_master, log_once
+import importlib.util

 FBGEMM_MM_AVAILABLE = False
 FBGEMM_DYN_AVAILABLE = False
-try:
-    import fbgemm_gpu.experimental.gen_ai

+
+def is_fbgemm_gpu_available():
+    return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
+
+
+if is_fbgemm_gpu_available():
     if SYSTEM == "cuda":
         major, _ = torch.cuda.get_device_capability()
         FBGEMM_MM_AVAILABLE = major == 9
         FBGEMM_DYN_AVAILABLE = major >= 8
-except (ImportError, ModuleNotFoundError):
+else:
     log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")

View File

@@ -4,7 +4,6 @@ import torch
 from torch import nn

 # Inverse dim formula to find dim based on number of rotations
-import math
 from text_generation_server.utils.import_utils import SYSTEM

View File

@@ -298,10 +298,8 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }

-__GLOBALS = locals()
-for data in ModelType:
-    __GLOBALS[data.name] = data.value["type"]
+    def __str__(self):
+        return self.value["type"]


 def get_model(
@@ -716,6 +714,7 @@ def get_model(
         or model_type == ModelType.BAICHUAN
         or model_type == ModelType.PHI3
     ):
+        print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
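For context on the `ModelType` hunks above: the removed block wrote every member's type string into the module's `locals()`, turning enum members into implicit module-level globals, while the replacement `__str__` lets callers get the same string explicitly. A self-contained sketch of the new pattern (the members and their dicts here are illustrative stand-ins, not the real registry):

    import enum

    class ModelType(enum.Enum):
        # Illustrative entries; the real enum stores richer config dicts.
        LLAMA = {"type": "llama", "name": "Llama"}
        PHI3 = {"type": "phi3", "name": "Phi 3"}

        def __str__(self):
            return self.value["type"]

    # str() and f-string interpolation both go through __str__ for a
    # plain Enum, so a member renders as its type string.
    assert str(ModelType.LLAMA) == "llama"
    print(f">>> model_type: {ModelType.PHI3}")  # >>> model_type: phi3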

View File

@@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
-    UnquantizedWeight,
     Weights,
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

View File

@@ -6,7 +6,6 @@ from typing import Optional
 from huggingface_hub import hf_hub_download
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
-    UnquantizedWeight,
     WeightsLoader,
 )