fix: lint and refactor import check and avoid model enum as global names

This commit is contained in:
drbh 2024-07-24 14:51:28 +00:00
parent 655a9d7ef3
commit 382bf59f4f
6 changed files with 12 additions and 11 deletions

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from functools import lru_cache
import bitsandbytes as bnb
import torch

View File

@ -12,17 +12,23 @@ from text_generation_server.utils.weights import (
Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
try:
import fbgemm_gpu.experimental.gen_ai
def is_fbgemm_gpu_available():
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
except (ImportError, ModuleNotFoundError):
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")

View File

@ -4,7 +4,6 @@ import torch
from torch import nn
# Inverse dim formula to find dim based on number of rotations
import math
from text_generation_server.utils.import_utils import SYSTEM

View File

@ -298,10 +298,8 @@ class ModelType(enum.Enum):
"multimodal": True,
}
__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
def __str__(self):
return self.value["type"]
def get_model(
@ -716,6 +714,7 @@ def get_model(
or model_type == ModelType.BAICHUAN
or model_type == ModelType.PHI3
):
print(f">>> model_type: {model_type}")
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,

View File

@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
UnquantizedWeight,
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

View File

@ -6,7 +6,6 @@ from typing import Optional
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
WeightsLoader,
)