diff --git a/Dockerfile b/Dockerfile
index aa754124..54ddd5ef 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -167,6 +167,9 @@ FROM kernel-builder AS fbgemm-builder
 WORKDIR /usr/src
 
 COPY server/Makefile-fbgemm Makefile
+COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
+COPY server/fix_torch90a.sh fix_torch90a.sh
+
 RUN make build-fbgemm
 
 # Build vllm CUDA kernels
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 85360478..491f92ea 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Union
 
 from safetensors import safe_open
 from dataclasses import dataclass
 
-from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.utils.import_utils import SYSTEM
 
@@ -126,10 +125,15 @@ class DefaultWeightsLoader(WeightsLoader):
         )
 
         if w.dtype == torch.float8_e4m3fn:
+            # FIXME: here to avoid circular import
+            from text_generation_server.layers.fp8 import Fp8Weight
+
             if self.weight_class is not None and self.weight_class != Fp8Weight:
                 raise RuntimeError(
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
                 )
 
+            # FIXME: here to avoid circular import
+            from text_generation_server.layers.fp8 import Fp8Weight
             # FP8 branch
             scale = weights.get_packed_sharded(
@@ -148,6 +152,9 @@ class DefaultWeightsLoader(WeightsLoader):
 
             # FP8 branch
             if w.dtype == torch.float8_e4m3fn:
+                # FIXME: here to avoid circular import
+                from text_generation_server.layers.fp8 import Fp8Weight
+
                 if self.weight_class is not None and self.weight_class != Fp8Weight:
                     raise RuntimeError(
                         f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
@@ -166,6 +173,9 @@ class DefaultWeightsLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
+            # FIXME: here to avoid circular import
+            from text_generation_server.layers.fp8 import Fp8Weight
+
             if self.weight_class is not None and self.weight_class != Fp8Weight:
                 raise RuntimeError(
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"