mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
avoid circular import and fix dockerfile
This commit is contained in:
parent
985df12c46
commit
10cd8ab4a6
@ -167,6 +167,9 @@ FROM kernel-builder AS fbgemm-builder
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-fbgemm Makefile
|
||||
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
|
||||
COPY server/fix_torch90a.sh fix_torch90a.sh
|
||||
|
||||
RUN make build-fbgemm
|
||||
|
||||
# Build vllm CUDA kernels
|
||||
|
@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Union
|
||||
from safetensors import safe_open
|
||||
from dataclasses import dataclass
|
||||
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
@ -126,10 +125,15 @@ class DefaultWeightsLoader(WeightsLoader):
|
||||
)
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FIXME: here to avoid circular import
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
|
||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||
raise RuntimeError(
|
||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||
)
|
||||
# FIXME: here to avoid circular import
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
|
||||
# FP8 branch
|
||||
scale = weights.get_packed_sharded(
|
||||
@ -148,6 +152,9 @@ class DefaultWeightsLoader(WeightsLoader):
|
||||
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FIXME: here to avoid circular import
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
|
||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||
raise RuntimeError(
|
||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||
@ -166,6 +173,9 @@ class DefaultWeightsLoader(WeightsLoader):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FIXME: here to avoid circular import
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
|
||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||
raise RuntimeError(
|
||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||
|
Loading…
Reference in New Issue
Block a user