mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: lint and refactor import check and avoid model enum as global names
This commit is contained in:
parent
655a9d7ef3
commit
382bf59f4f
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
@ -12,17 +12,23 @@ from text_generation_server.utils.weights import (
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master, log_once
|
||||
import importlib.util
|
||||
|
||||
|
||||
FBGEMM_MM_AVAILABLE = False
|
||||
FBGEMM_DYN_AVAILABLE = False
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai
|
||||
|
||||
|
||||
def is_fbgemm_gpu_available():
    """Report whether the ``fbgemm_gpu.experimental.gen_ai`` module can be found.

    The lookup uses ``importlib.util.find_spec``, which does not import
    ``gen_ai`` itself but does import its parent packages in order to
    resolve the dotted name.

    Returns:
        bool: True when a module spec for ``fbgemm_gpu.experimental.gen_ai``
        is found; False when no spec is found or when a parent package is
        missing or fails to import.
    """
    try:
        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
    except ImportError:
        # For dotted names find_spec imports the parent packages first, so a
        # missing ``fbgemm_gpu`` raises ModuleNotFoundError (a subclass of
        # ImportError) instead of yielding None. Treat that as "not available"
        # so callers can fall back instead of crashing at import time.
        return False
|
||||
|
||||
|
||||
if is_fbgemm_gpu_available():
|
||||
if SYSTEM == "cuda":
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
FBGEMM_MM_AVAILABLE = major == 9
|
||||
FBGEMM_DYN_AVAILABLE = major >= 8
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
else:
|
||||
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
|
||||
|
||||
|
||||
|
@ -4,7 +4,6 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
import math
|
||||
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
@ -298,10 +298,8 @@ class ModelType(enum.Enum):
|
||||
"multimodal": True,
|
||||
}
|
||||
|
||||
|
||||
__GLOBALS = locals()
|
||||
for data in ModelType:
|
||||
__GLOBALS[data.name] = data.value["type"]
|
||||
def __str__(self):
    """Return the entry stored under the ``"type"`` key of this member's value."""
    return self.value["type"]
|
||||
|
||||
|
||||
def get_model(
|
||||
@ -716,6 +714,7 @@ def get_model(
|
||||
or model_type == ModelType.BAICHUAN
|
||||
or model_type == ModelType.PHI3
|
||||
):
|
||||
print(f">>> model_type: {model_type}")
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
|
@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import (
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
|
@ -6,7 +6,6 @@ from typing import Optional
|
||||
from huggingface_hub import hf_hub_download
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
WeightsLoader,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user