mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
avoid circular import and fix dockerfile
This commit is contained in:
parent
985df12c46
commit
10cd8ab4a6
@ -167,6 +167,9 @@ FROM kernel-builder AS fbgemm-builder
|
|||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
COPY server/Makefile-fbgemm Makefile
|
COPY server/Makefile-fbgemm Makefile
|
||||||
|
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
|
||||||
|
COPY server/fix_torch90a.sh fix_torch90a.sh
|
||||||
|
|
||||||
RUN make build-fbgemm
|
RUN make build-fbgemm
|
||||||
|
|
||||||
# Build vllm CUDA kernels
|
# Build vllm CUDA kernels
|
||||||
|
@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Union
|
|||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from text_generation_server.layers.fp8 import Fp8Weight
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
@ -126,10 +125,15 @@ class DefaultWeightsLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FIXME: here to avoid circular import
|
||||||
|
from text_generation_server.layers.fp8 import Fp8Weight
|
||||||
|
|
||||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||||
)
|
)
|
||||||
|
# FIXME: here to avoid circular import
|
||||||
|
from text_generation_server.layers.fp8 import Fp8Weight
|
||||||
|
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
scale = weights.get_packed_sharded(
|
scale = weights.get_packed_sharded(
|
||||||
@ -148,6 +152,9 @@ class DefaultWeightsLoader(WeightsLoader):
|
|||||||
|
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FIXME: here to avoid circular import
|
||||||
|
from text_generation_server.layers.fp8 import Fp8Weight
|
||||||
|
|
||||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||||
@ -166,6 +173,9 @@ class DefaultWeightsLoader(WeightsLoader):
|
|||||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FIXME: here to avoid circular import
|
||||||
|
from text_generation_server.layers.fp8 import Fp8Weight
|
||||||
|
|
||||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||||
|
Loading…
Reference in New Issue
Block a user