Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix: lint and refactor import check and avoid model enum as global names
parent 655a9d7ef3
commit 382bf59f4f
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from functools import lru_cache
 
 import bitsandbytes as bnb
 import torch
@@ -12,17 +12,23 @@ from text_generation_server.utils.weights import (
     Weights,
 )
 from text_generation_server.utils.log import log_master, log_once
+import importlib.util
+
 
 FBGEMM_MM_AVAILABLE = False
 FBGEMM_DYN_AVAILABLE = False
-try:
-    import fbgemm_gpu.experimental.gen_ai
 
-    if SYSTEM == "cuda":
+
+def is_fbgemm_gpu_available():
+    return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
+
+
+if is_fbgemm_gpu_available():
+    if SYSTEM == "cuda":
         major, _ = torch.cuda.get_device_capability()
         FBGEMM_MM_AVAILABLE = major == 9
         FBGEMM_DYN_AVAILABLE = major >= 8
-except (ImportError, ModuleNotFoundError):
+else:
     log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
 
 
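Note: the hunk above swaps a try/except around the import for a find_spec probe, so the availability flags can be computed without actually importing fbgemm_gpu just to test for its presence. A minimal standalone sketch of the same probe pattern follows; the helper name and prints are illustrative, not part of the commit. One subtlety of find_spec is hedged in the comments: for a dotted name it imports the parent package, so a completely missing parent surfaces as ModuleNotFoundError rather than None, which the try/except below guards against.

    import importlib.util

    def is_available(module_name: str) -> bool:
        # find_spec returns None when the module cannot be located, and it
        # does not execute the module's code the way a real import would.
        # For a dotted name the parent package is imported first, so a
        # missing parent raises ModuleNotFoundError instead of returning None.
        try:
            return importlib.util.find_spec(module_name) is not None
        except ModuleNotFoundError:
            return False

    if is_available("fbgemm_gpu.experimental.gen_ai"):
        print("FBGEMM fp8 kernels available")
    else:
        print("FBGEMM fp8 kernels are not installed.")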
@@ -4,7 +4,6 @@ import torch
 from torch import nn
 
 # Inverse dim formula to find dim based on number of rotations
-import math
 
 
 from text_generation_server.utils.import_utils import SYSTEM
@@ -298,10 +298,8 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }
 
-
-__GLOBALS = locals()
-for data in ModelType:
-    __GLOBALS[data.name] = data.value["type"]
+    def __str__(self):
+        return self.value["type"]
 
 
 def get_model(
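For context on the removed lines: mutating locals() at module scope injected each member's type string into the module namespace as a bare global (LLAMA, PHI3, and so on), which linters cannot track; the replacement keeps members namespaced on the enum and lets str() produce the config-style type string where text is needed. A small self-contained sketch of the new shape, with two illustrative stand-in members rather than the full enum:

    import enum

    class ModelType(enum.Enum):
        # Illustrative members; the real enum stores larger metadata dicts.
        LLAMA = {"type": "llama", "name": "Llama"}
        PHI3 = {"type": "phi3", "name": "Phi 3"}

        def __str__(self):
            # Render a member as its config-style type string.
            return self.value["type"]

    assert str(ModelType.LLAMA) == "llama"
    # Members stay namespaced: ModelType.PHI3, not a bare PHI3 global.
    assert ModelType.PHI3.value["name"] == "Phi 3"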
@@ -716,6 +714,7 @@ def get_model(
         or model_type == ModelType.BAICHUAN
         or model_type == ModelType.PHI3
     ):
+        print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
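With the globals hack gone, this dispatch compares directly against enum members, and the added debug print renders the member through the new __str__. A compressed, runnable sketch of the branch; the ModelType stub and the describe helper are hypothetical stand-ins, and the real function dispatches to FlashCausalLM with many more parameters:

    import enum

    class ModelType(enum.Enum):
        BAICHUAN = {"type": "baichuan"}
        PHI3 = {"type": "phi3"}

        def __str__(self):
            return self.value["type"]

    def describe(model_type: ModelType) -> str:
        # Mirrors the dispatch condition: direct equality against members.
        if model_type == ModelType.BAICHUAN or model_type == ModelType.PHI3:
            return f">>> model_type: {str(model_type)}"
        return "unhandled"

    print(describe(ModelType.PHI3))  # -> >>> model_type: phi3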
@@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
-    UnquantizedWeight,
     Weights,
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@@ -6,7 +6,6 @@ from typing import Optional
 from huggingface_hub import hf_hub_download
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
-    UnquantizedWeight,
     WeightsLoader,
 )
 