Avoid a circular import and fix the Dockerfile

This commit is contained in:
OlivierDehaene 2024-07-19 18:56:41 +02:00
parent 985df12c46
commit 10cd8ab4a6
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
2 changed files with 14 additions and 1 deletions

View File

@@ -167,6 +167,9 @@ FROM kernel-builder AS fbgemm-builder
 WORKDIR /usr/src
 COPY server/Makefile-fbgemm Makefile
+COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
+COPY server/fix_torch90a.sh fix_torch90a.sh
+
 RUN make build-fbgemm

 # Build vllm CUDA kernels

View File

@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Union
 from safetensors import safe_open
 from dataclasses import dataclass
-from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.utils.import_utils import SYSTEM
@@ -126,10 +125,15 @@ class DefaultWeightsLoader(WeightsLoader):
         )
         if w.dtype == torch.float8_e4m3fn:
+            # FIXME: here to avoid circular import
+            from text_generation_server.layers.fp8 import Fp8Weight
+
             if self.weight_class is not None and self.weight_class != Fp8Weight:
                 raise RuntimeError(
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
                 )
+            # FIXME: here to avoid circular import
+            from text_generation_server.layers.fp8 import Fp8Weight

             # FP8 branch
             scale = weights.get_packed_sharded(
@@ -148,6 +152,9 @@ class DefaultWeightsLoader(WeightsLoader):
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
+            # FIXME: here to avoid circular import
+            from text_generation_server.layers.fp8 import Fp8Weight
+
             if self.weight_class is not None and self.weight_class != Fp8Weight:
                 raise RuntimeError(
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
@@ -166,6 +173,9 @@ class DefaultWeightsLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
+            # FIXME: here to avoid circular import
+            from text_generation_server.layers.fp8 import Fp8Weight
+
             if self.weight_class is not None and self.weight_class != Fp8Weight:
                 raise RuntimeError(
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"