build fbgemm

This commit is contained in:
OlivierDehaene 2024-07-19 18:26:50 +02:00
parent 80087783a5
commit 985df12c46
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
13 changed files with 1043 additions and 764 deletions

View File

@ -161,6 +161,14 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build FBGEMM CUDA kernels
FROM kernel-builder AS fbgemm-builder
WORKDIR /usr/src
COPY server/Makefile-fbgemm Makefile
RUN make build-fbgemm
# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder
@ -225,10 +233,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy builds artifacts from vllm builder
# Copy build artifacts from fbgemm builder
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages

View File

@ -5,6 +5,7 @@ include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
unit-tests:
pytest -s -vv -m "not private" tests
@ -20,16 +21,15 @@ gen-server:
install-server: gen-server
pip install pip --upgrade
pip install -r requirements_intel.txt
pip install -r requirements_cuda.txt
pip install -e ".[accelerate, quantize, peft, outlines]"
install: install-cuda
echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
pip install -r requirements_cuda.txt
pip install -e ".[cuda, bnb]"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
pip install -e ".[bnb]"
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
@ -37,6 +37,6 @@ run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
export-requirements:
poetry export -o requirements_cuda.txt --without-hashes -E cuda
poetry export -o requirements_cuda.txt --without-hashes
poetry export -o requirements_rocm.txt --without-hashes
poetry export -o requirements_intel.txt --without-hashes

15
server/Makefile-fbgemm Normal file
View File

@ -0,0 +1,15 @@
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca
build-fbgemm:
chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
cp fbgemm_remove_unused.patch fbgemm && \
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \
git submodule update --init --recursive && \
cd fbgemm_gpu && \
pip install -r requirements.txt && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
install-fbgemm: build-fbgemm
cd fbgemm/fbgemm_gpu && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install

View File

@ -0,0 +1,306 @@
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 2244ea6f..96265a48 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -94,14 +94,14 @@ endif()
# Build Experimental Modules
################################################################################
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
- # TODO: Figure out NCCL/RCCL integration with ROCm
- add_subdirectory(experimental/example)
-endif()
-
-if(NOT FBGEMM_CPU_ONLY)
- add_subdirectory(experimental/gemm)
-endif()
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
+# # TODO: Figure out NCCL/RCCL integration with ROCm
+# add_subdirectory(experimental/example)
+# endif()
+
+# if(NOT FBGEMM_CPU_ONLY)
+# add_subdirectory(experimental/gemm)
+# endif()
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index c56773fe..0c0d349e 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
################################################################################
set(fbgemm_gpu_sources_static_cpu
- codegen/training/forward/embedding_forward_split_cpu.cpp
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
- codegen/utils/embedding_bounds_check_host_cpu.cpp
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
- src/input_combine_ops/input_combine_cpu.cpp
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
+ # src/input_combine_ops/input_combine_cpu.cpp
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
src/quantize_ops/quantize_ops_cpu.cpp
src/quantize_ops/quantize_ops_meta.cpp
- src/sparse_ops/sparse_ops_cpu.cpp
- src/sparse_ops/sparse_ops_meta.cpp
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
- src/split_embeddings_cache/linearize_cache_indices.cpp
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
- src/split_embeddings_cache/lxu_cache.cpp
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
+ # src/sparse_ops/sparse_ops_cpu.cpp
+ # src/sparse_ops/sparse_ops_meta.cpp
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
+ # src/split_embeddings_cache/lxu_cache.cpp
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
+)
if(NOT FBGEMM_CPU_ONLY)
list(APPEND fbgemm_gpu_sources_static_cpu
- codegen/inference/embedding_forward_quantized_host.cpp
- codegen/utils/embedding_bounds_check_host.cpp
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
- src/memory_utils/memory_utils.cpp
- src/memory_utils/memory_utils_ops.cpp
- src/memory_utils/memory_utils_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
+ # codegen/inference/embedding_forward_quantized_host.cpp
+ # codegen/utils/embedding_bounds_check_host.cpp
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
+ # src/memory_utils/memory_utils.cpp
+ # src/memory_utils/memory_utils_ops.cpp
+ # src/memory_utils/memory_utils_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
src/quantize_ops/quantize_ops_gpu.cpp
- src/sparse_ops/sparse_ops_gpu.cpp
- src/split_embeddings_utils/split_embeddings_utils.cpp
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
- src/metric_ops/metric_ops_host.cpp
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
- src/input_combine_ops/input_combine_gpu.cpp
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
+ # src/sparse_ops/sparse_ops_gpu.cpp
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
+ # src/metric_ops/metric_ops_host.cpp
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
+ # src/input_combine_ops/input_combine_gpu.cpp
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
+ )
if(NVML_LIB_PATH OR USE_ROCM)
message(STATUS "Adding merge_pooled_embeddings sources")
@@ -516,36 +518,36 @@ endif()
if(NOT FBGEMM_CPU_ONLY)
set(fbgemm_gpu_sources_static_gpu
- codegen/utils/embedding_bounds_check.cu
- codegen/inference/embedding_forward_quantized_split_lookup.cu
- src/embedding_inplace_ops/embedding_inplace_update.cu
- src/histogram_binning_calibration_ops.cu
- src/input_combine_ops/input_combine.cu
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
- src/memory_utils/memory_utils.cu
- src/memory_utils/memory_utils_ops.cu
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
- src/jagged_tensor_ops/jagged_softmax_backward.cu
- src/jagged_tensor_ops/jagged_softmax_forward.cu
- src/jagged_tensor_ops/jagged_tensor_ops.cu
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
- src/jagged_tensor_ops/jagged_unique_indices.cu
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
- src/layout_transform_ops/layout_transform_ops.cu
- src/metric_ops/metric_ops.cu
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
+ # codegen/utils/embedding_bounds_check.cu
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
+ # src/histogram_binning_calibration_ops.cu
+ # src/input_combine_ops/input_combine.cu
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
+ # src/memory_utils/memory_utils.cu
+ # src/memory_utils/memory_utils_ops.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+ # src/layout_transform_ops/layout_transform_ops.cu
+ # src/metric_ops/metric_ops.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
src/quantize_ops/quantize_bfloat16.cu
src/quantize_ops/quantize_fp8_rowwise.cu
src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
src/quantize_ops/quantize_msfp.cu
src/quantize_ops/quantize_padded_fp8_rowwise.cu
src/quantize_ops/quantize_mx.cu
- src/sparse_ops/sparse_async_cumsum.cu
- src/sparse_ops/sparse_block_bucketize_features.cu
- src/sparse_ops/sparse_bucketize_features.cu
- src/sparse_ops/sparse_batched_unary_embeddings.cu
- src/sparse_ops/sparse_compute_frequency_sequence.cu
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
- src/sparse_ops/sparse_group_index.cu
- src/sparse_ops/sparse_index_add.cu
- src/sparse_ops/sparse_index_select.cu
- src/sparse_ops/sparse_invert_permute.cu
- src/sparse_ops/sparse_pack_segments_backward.cu
- src/sparse_ops/sparse_pack_segments_forward.cu
- src/sparse_ops/sparse_permute_1d.cu
- src/sparse_ops/sparse_permute_2d.cu
- src/sparse_ops/sparse_permute102.cu
- src/sparse_ops/sparse_permute_embeddings.cu
- src/sparse_ops/sparse_range.cu
- src/sparse_ops/sparse_reorder_batched_ad.cu
- src/sparse_ops/sparse_segment_sum_csr.cu
- src/sparse_ops/sparse_zipf.cu
- src/split_embeddings_cache/lfu_cache_find.cu
- src/split_embeddings_cache/lfu_cache_populate.cu
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
- src/split_embeddings_cache/lru_cache_find.cu
- src/split_embeddings_cache/lru_cache_populate.cu
- src/split_embeddings_cache/lru_cache_populate_byte.cu
- src/split_embeddings_cache/lxu_cache.cu
- src/split_embeddings_cache/linearize_cache_indices.cu
- src/split_embeddings_cache/reset_weight_momentum.cu
- src/split_embeddings_utils/generate_vbe_metadata.cu
- src/split_embeddings_utils/get_infos_metadata.cu
- src/split_embeddings_utils/radix_sort_pairs.cu
- src/split_embeddings_utils/transpose_embedding_input.cu)
+ # src/sparse_ops/sparse_async_cumsum.cu
+ # src/sparse_ops/sparse_block_bucketize_features.cu
+ # src/sparse_ops/sparse_bucketize_features.cu
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
+ # src/sparse_ops/sparse_group_index.cu
+ # src/sparse_ops/sparse_index_add.cu
+ # src/sparse_ops/sparse_index_select.cu
+ # src/sparse_ops/sparse_invert_permute.cu
+ # src/sparse_ops/sparse_pack_segments_backward.cu
+ # src/sparse_ops/sparse_pack_segments_forward.cu
+ # src/sparse_ops/sparse_permute_1d.cu
+ # src/sparse_ops/sparse_permute_2d.cu
+ # src/sparse_ops/sparse_permute102.cu
+ # src/sparse_ops/sparse_permute_embeddings.cu
+ # src/sparse_ops/sparse_range.cu
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
+ # src/sparse_ops/sparse_segment_sum_csr.cu
+ # src/sparse_ops/sparse_zipf.cu
+ # src/split_embeddings_cache/lfu_cache_find.cu
+ # src/split_embeddings_cache/lfu_cache_populate.cu
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
+ # src/split_embeddings_cache/lru_cache_find.cu
+ # src/split_embeddings_cache/lru_cache_populate.cu
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
+ # src/split_embeddings_cache/lxu_cache.cu
+ # src/split_embeddings_cache/linearize_cache_indices.cu
+ # src/split_embeddings_cache/reset_weight_momentum.cu
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
+ # src/split_embeddings_utils/get_infos_metadata.cu
+ # src/split_embeddings_utils/radix_sort_pairs.cu
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
+ )
set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
PROPERTIES COMPILE_OPTIONS
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
index 01f1d6ab..a6b8d7a8 100644
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
${THIRDPARTY}/json/include
${NCCL_INCLUDE_DIRS})
-set(attention_ops_sources
- src/attention/attention.cpp
- src/attention/gqa_attn_splitk.cu)
+# set(attention_ops_sources
+# src/attention/attention.cpp
+# src/attention/gqa_attn_splitk.cu)
set(quantize_ops_sources
src/quantize/cutlass_extensions.cu
src/quantize/quantize.cu
src/quantize/quantize.cpp)
-set(comm_ops_sources
- src/comm/car.cu
- src/comm/car.cpp)
+# set(comm_ops_sources
+# src/comm/car.cu
+# src/comm/car.cpp)
set(experimental_gen_ai_cpp_source_files
- ${attention_ops_sources}
+ # ${attention_ops_sources}
${quantize_ops_sources}
- ${comm_ops_sources})
+ # ${comm_ops_sources}
+)
set_source_files_properties(${experimental_gen_ai_cpp_source_files}
PROPERTIES INCLUDE_DIRECTORIES

11
server/fix_torch90a.sh Executable file
View File

@ -0,0 +1,11 @@
#!/bin/bash
# This script is required to patch torch < 2.4
# It adds the 90a cuda target (H100)
# This target is required to build FBGEMM kernels
torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|')
sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch

1209
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -34,14 +34,12 @@ peft = { version = "^0.10", optional = true }
torch = { version = "^2.3.0", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
outlines= { version = "^0.0.46", optional = true }
outlines= { version = "^0.0.34", optional = true }
prometheus-client = "^0.20.0"
py-cpuinfo = "^9.0.0"
fbgemm-gpu = { version = "0.8.0rc4", optional = true }
[tool.poetry.extras]
torch = ["torch"]
cuda = ["fbgemm-gpu"]
accelerate = ["accelerate"]
bnb = ["bitsandbytes"]
peft = ["peft"]

View File

@ -1,51 +1,48 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
fbgemm-gpu==0.8.0rc4 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,50 +1,48 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,50 +1,48 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,16 +1,19 @@
import torch
from dataclasses import dataclass
from typing import Optional
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weight
try:
import fbgemm_gpu.experimental.gen_ai
HAS_FBGEMM = True
major, _ = torch.cuda.get_device_capability()
HAS_FBGEMM_MM = major == 9
HAS_FBGEMM_DYN = major >= 8
except (ImportError, ModuleNotFoundError):
HAS_FBGEMM = False
HAS_FBGEMM_MM = False
HAS_FBGEMM_DYN = False
def get_fp8_linear() -> torch.nn.Module:
@ -30,10 +33,7 @@ def get_fp8_linear() -> torch.nn.Module:
def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
if HAS_FBGEMM:
if scale_upper_bound.device != weight.device:
scale_upper_bound = scale_upper_bound.to(weight.device)
if HAS_FBGEMM_DYN:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
@ -55,11 +55,17 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
@dataclass
class Fp8Weight(Weight):
class Fp8Weight:
weight: torch.Tensor
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
def get_linear(self, bias: torch.Tensor):
return get_fp8_linear()(self.weight, bias)
if self.weight_scale is None:
return get_fp8_linear().from_unquant(self.weight, bias)
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.input_scale, bias, bias.dtype
)
class Fp8Linear(torch.nn.Module):
@ -87,17 +93,17 @@ class Fp8Linear(torch.nn.Module):
)
@classmethod
def from_fp8(cls, weight, bias, dtype):
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
return cls(
qweight=weight.weight,
scale=weight.weight_scale,
scale_upper_bound=weight.input_scale,
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
bias=bias,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if HAS_FBGEMM:
if HAS_FBGEMM_MM:
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
)

View File

@ -139,6 +139,6 @@ def get_loader(
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
elif quantize is None:
return DefaultWeightsLoader(UnquantizedWeight)
return DefaultWeightsLoader()
else:
raise ValueError(f"Unknown quantization method: {quantize}")

View File

@ -2,15 +2,14 @@ import torch
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Dict, List, Optional, Union
from text_generation_server.utils.import_utils import SYSTEM
from safetensors import safe_open
from dataclasses import dataclass
from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.utils.import_utils import SYSTEM
class WeightsLoader(ABC):
"""
@ -101,7 +100,7 @@ class UnquantizedWeight:
class DefaultWeightsLoader(WeightsLoader):
"""Weight loader that loads (unquantized) Torch tensors."""
def __init__(self, weight_class):
def __init__(self, weight_class: Optional = None):
"""Create a loader. Weights will be wrapped using the given `weights_class`,
normally this will be `UnquantizedWeight`, but a quantizer-specific class
such as `Fp8Weight` can be used to quantize the weights during loading.
@ -122,51 +121,63 @@ class DefaultWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
return self.weight_class(
weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
),
)
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_class is not None and self.weight_class != Fp8Weight:
raise RuntimeError(
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
)
# FP8 branch
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
)
input_scale = weights.get_tensor(f"{prefix}.input_scale")
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
return w
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
if self.weight_class is None:
return UnquantizedWeight(w)
return self.weight_class(w)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
return self.weight_class(torch.cat(w, dim=dim))
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
w = torch.cat(w, dim=dim)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_class is not None and self.weight_class != Fp8Weight:
raise RuntimeError(
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
)
scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
scale = torch.cat(scale, dim=0)
input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
return w
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
if self.weight_class is None:
return UnquantizedWeight(w)
return self.weight_class(w)
def get_weights_row(self, weights: "Weights", prefix: str):
return self.weight_class(
weights.get_sharded(f"{prefix}.weight", dim=1),
)
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_class is not None and self.weight_class != Fp8Weight:
raise RuntimeError(
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
)
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
input_scale = weights.get_tensor(f"{prefix}.input_scale")
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
return w
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
if self.weight_class is None:
return UnquantizedWeight(w)
return self.weight_class(w)
class Weights: