mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
build fbgemm
This commit is contained in:
parent
80087783a5
commit
985df12c46
14
Dockerfile
14
Dockerfile
@ -161,6 +161,14 @@ COPY server/custom_kernels/ .
|
|||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN python setup.py build
|
RUN python setup.py build
|
||||||
|
|
||||||
|
# Build FBGEMM CUDA kernels
FROM kernel-builder AS fbgemm-builder

WORKDIR /usr/src

# The build-fbgemm recipe runs ./fix_torch90a.sh and applies
# fbgemm_remove_unused.patch from the working directory, so both files have to
# sit next to the Makefile inside this stage.
COPY server/Makefile-fbgemm Makefile
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
COPY server/fix_torch90a.sh fix_torch90a.sh
RUN make build-fbgemm
|
||||||
|
|
||||||
# Build vllm CUDA kernels
|
# Build vllm CUDA kernels
|
||||||
FROM kernel-builder AS vllm-builder
|
FROM kernel-builder AS vllm-builder
|
||||||
|
|
||||||
@ -225,10 +233,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
|
|||||||
# Copy build artifacts from marlin kernels builder
|
# Copy build artifacts from marlin kernels builder
|
||||||
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
# Copy build artifacts from fbgemm builder
|
||||||
# Copy builds artifacts from vllm builder
|
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
|
||||||
|
# Copy build artifacts from vllm builder
|
||||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
||||||
# Copy build artifacts from mamba builder
|
# Copy build artifacts from mamba builder
|
||||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
|
@ -5,6 +5,7 @@ include Makefile-awq
|
|||||||
include Makefile-eetq
|
include Makefile-eetq
|
||||||
include Makefile-selective-scan
|
include Makefile-selective-scan
|
||||||
include Makefile-lorax-punica
|
include Makefile-lorax-punica
|
||||||
|
include Makefile-fbgemm
|
||||||
|
|
||||||
unit-tests:
|
unit-tests:
|
||||||
pytest -s -vv -m "not private" tests
|
pytest -s -vv -m "not private" tests
|
||||||
@ -20,16 +21,15 @@ gen-server:
|
|||||||
|
|
||||||
install-server: gen-server
|
install-server: gen-server
|
||||||
pip install pip --upgrade
|
pip install pip --upgrade
|
||||||
pip install -r requirements_intel.txt
|
pip install -r requirements_cuda.txt
|
||||||
pip install -e ".[accelerate, quantize, peft, outlines]"
|
pip install -e ".[accelerate, quantize, peft, outlines]"
|
||||||
|
|
||||||
|
|
||||||
install: install-cuda
|
install: install-cuda
|
||||||
echo "Installed server"
|
echo "Installed server"
|
||||||
|
|
||||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
|
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
|
||||||
pip install -r requirements_cuda.txt
|
pip install -e ".[bnb]"
|
||||||
pip install -e ".[cuda, bnb]"
|
|
||||||
|
|
||||||
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
||||||
|
|
||||||
@ -37,6 +37,6 @@ run-dev:
|
|||||||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||||
|
|
||||||
export-requirements:
|
export-requirements:
|
||||||
poetry export -o requirements_cuda.txt --without-hashes -E cuda
|
poetry export -o requirements_cuda.txt --without-hashes
|
||||||
poetry export -o requirements_rocm.txt --without-hashes
|
poetry export -o requirements_rocm.txt --without-hashes
|
||||||
poetry export -o requirements_intel.txt --without-hashes
|
poetry export -o requirements_intel.txt --without-hashes
|
||||||
|
15
server/Makefile-fbgemm
Normal file
15
server/Makefile-fbgemm
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# FBGEMM revision that fbgemm_remove_unused.patch is applied on top of.
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca

# CUDA architecture selection shared by the build and install steps, so the two
# invocations of setup.py cannot drift apart.
fbgemm_build_env := CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a"

# Patch torch's CMake arch detection, fetch FBGEMM at the pinned commit, apply
# the local patch and build the "genai" variant of fbgemm_gpu.
# The clone and the patch step are guarded so the target can be re-run on an
# existing checkout (install-fbgemm depends on this target):
#   - the repository is cloned only when ./fbgemm does not exist yet;
#   - `git apply --reverse --check` succeeds only when the patch is already
#     applied, in which case the forward apply is skipped.
build-fbgemm:
	chmod +x fix_torch90a.sh && ./fix_torch90a.sh
	if [ ! -d fbgemm ]; then git clone https://github.com/pytorch/FBGEMM.git fbgemm; fi
	cp fbgemm_remove_unused.patch fbgemm
	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
	(git apply --reverse --check fbgemm_remove_unused.patch 2>/dev/null || git apply fbgemm_remove_unused.patch) && \
	git submodule update --init --recursive && \
	cd fbgemm_gpu && \
	pip install -r requirements.txt && \
	$(fbgemm_build_env) python setup.py --package_variant genai build

# Install the previously built fbgemm_gpu into the active Python environment.
install-fbgemm: build-fbgemm
	cd fbgemm/fbgemm_gpu && \
	$(fbgemm_build_env) python setup.py --package_variant genai install
|
306
server/fbgemm_remove_unused.patch
Normal file
306
server/fbgemm_remove_unused.patch
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
|
||||||
|
index 2244ea6f..96265a48 100644
|
||||||
|
--- a/fbgemm_gpu/CMakeLists.txt
|
||||||
|
+++ b/fbgemm_gpu/CMakeLists.txt
|
||||||
|
@@ -94,14 +94,14 @@ endif()
|
||||||
|
# Build Experimental Modules
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
- # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||||
|
- add_subdirectory(experimental/example)
|
||||||
|
-endif()
|
||||||
|
-
|
||||||
|
-if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
- add_subdirectory(experimental/gemm)
|
||||||
|
-endif()
|
||||||
|
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
+# # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||||
|
+# add_subdirectory(experimental/example)
|
||||||
|
+# endif()
|
||||||
|
+
|
||||||
|
+# if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
+# add_subdirectory(experimental/gemm)
|
||||||
|
+# endif()
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
|
||||||
|
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
index c56773fe..0c0d349e 100644
|
||||||
|
--- a/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
+++ b/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
set(fbgemm_gpu_sources_static_cpu
|
||||||
|
- codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||||
|
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||||
|
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||||
|
- codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||||
|
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||||
|
- src/input_combine_ops/input_combine_cpu.cpp
|
||||||
|
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||||
|
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||||
|
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||||
|
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||||
|
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||||
|
+ # src/input_combine_ops/input_combine_cpu.cpp
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_cpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_meta.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_cpu.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_meta.cpp
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||||
|
- src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||||
|
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||||
|
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||||
|
- src/split_embeddings_cache/lxu_cache.cpp
|
||||||
|
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||||
|
+ # src/sparse_ops/sparse_ops_cpu.cpp
|
||||||
|
+ # src/sparse_ops/sparse_ops_meta.cpp
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||||
|
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||||
|
+ # src/split_embeddings_cache/lxu_cache.cpp
|
||||||
|
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||||
|
+)
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
list(APPEND fbgemm_gpu_sources_static_cpu
|
||||||
|
- codegen/inference/embedding_forward_quantized_host.cpp
|
||||||
|
- codegen/utils/embedding_bounds_check_host.cpp
|
||||||
|
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||||
|
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||||
|
- src/memory_utils/memory_utils.cpp
|
||||||
|
- src/memory_utils/memory_utils_ops.cpp
|
||||||
|
- src/memory_utils/memory_utils_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_host.cpp
|
||||||
|
+ # codegen/utils/embedding_bounds_check_host.cpp
|
||||||
|
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||||
|
+ # src/memory_utils/memory_utils.cpp
|
||||||
|
+ # src/memory_utils/memory_utils_ops.cpp
|
||||||
|
+ # src/memory_utils/memory_utils_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_gpu.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_gpu.cpp
|
||||||
|
- src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||||
|
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||||
|
- src/metric_ops/metric_ops_host.cpp
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||||
|
- src/input_combine_ops/input_combine_gpu.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||||
|
+ # src/sparse_ops/sparse_ops_gpu.cpp
|
||||||
|
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||||
|
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||||
|
+ # src/metric_ops/metric_ops_host.cpp
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||||
|
+ # src/input_combine_ops/input_combine_gpu.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||||
|
+ )
|
||||||
|
|
||||||
|
if(NVML_LIB_PATH OR USE_ROCM)
|
||||||
|
message(STATUS "Adding merge_pooled_embeddings sources")
|
||||||
|
@@ -516,36 +518,36 @@ endif()
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
set(fbgemm_gpu_sources_static_gpu
|
||||||
|
- codegen/utils/embedding_bounds_check.cu
|
||||||
|
- codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||||
|
- src/histogram_binning_calibration_ops.cu
|
||||||
|
- src/input_combine_ops/input_combine.cu
|
||||||
|
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||||
|
- src/memory_utils/memory_utils.cu
|
||||||
|
- src/memory_utils/memory_utils_ops.cu
|
||||||
|
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||||
|
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||||
|
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||||
|
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||||
|
- src/layout_transform_ops/layout_transform_ops.cu
|
||||||
|
- src/metric_ops/metric_ops.cu
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||||
|
+ # codegen/utils/embedding_bounds_check.cu
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||||
|
+ # src/histogram_binning_calibration_ops.cu
|
||||||
|
+ # src/input_combine_ops/input_combine.cu
|
||||||
|
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||||
|
+ # src/memory_utils/memory_utils.cu
|
||||||
|
+ # src/memory_utils/memory_utils_ops.cu
|
||||||
|
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||||
|
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||||
|
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops.cu
|
||||||
|
+ # src/metric_ops/metric_ops.cu
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||||
|
src/quantize_ops/quantize_bfloat16.cu
|
||||||
|
src/quantize_ops/quantize_fp8_rowwise.cu
|
||||||
|
src/quantize_ops/quantize_fused_8bit_rowwise.cu
|
||||||
|
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
src/quantize_ops/quantize_msfp.cu
|
||||||
|
src/quantize_ops/quantize_padded_fp8_rowwise.cu
|
||||||
|
src/quantize_ops/quantize_mx.cu
|
||||||
|
- src/sparse_ops/sparse_async_cumsum.cu
|
||||||
|
- src/sparse_ops/sparse_block_bucketize_features.cu
|
||||||
|
- src/sparse_ops/sparse_bucketize_features.cu
|
||||||
|
- src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||||
|
- src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||||
|
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||||
|
- src/sparse_ops/sparse_group_index.cu
|
||||||
|
- src/sparse_ops/sparse_index_add.cu
|
||||||
|
- src/sparse_ops/sparse_index_select.cu
|
||||||
|
- src/sparse_ops/sparse_invert_permute.cu
|
||||||
|
- src/sparse_ops/sparse_pack_segments_backward.cu
|
||||||
|
- src/sparse_ops/sparse_pack_segments_forward.cu
|
||||||
|
- src/sparse_ops/sparse_permute_1d.cu
|
||||||
|
- src/sparse_ops/sparse_permute_2d.cu
|
||||||
|
- src/sparse_ops/sparse_permute102.cu
|
||||||
|
- src/sparse_ops/sparse_permute_embeddings.cu
|
||||||
|
- src/sparse_ops/sparse_range.cu
|
||||||
|
- src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||||
|
- src/sparse_ops/sparse_segment_sum_csr.cu
|
||||||
|
- src/sparse_ops/sparse_zipf.cu
|
||||||
|
- src/split_embeddings_cache/lfu_cache_find.cu
|
||||||
|
- src/split_embeddings_cache/lfu_cache_populate.cu
|
||||||
|
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||||
|
- src/split_embeddings_cache/lru_cache_find.cu
|
||||||
|
- src/split_embeddings_cache/lru_cache_populate.cu
|
||||||
|
- src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||||
|
- src/split_embeddings_cache/lxu_cache.cu
|
||||||
|
- src/split_embeddings_cache/linearize_cache_indices.cu
|
||||||
|
- src/split_embeddings_cache/reset_weight_momentum.cu
|
||||||
|
- src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||||
|
- src/split_embeddings_utils/get_infos_metadata.cu
|
||||||
|
- src/split_embeddings_utils/radix_sort_pairs.cu
|
||||||
|
- src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||||
|
+ # src/sparse_ops/sparse_async_cumsum.cu
|
||||||
|
+ # src/sparse_ops/sparse_block_bucketize_features.cu
|
||||||
|
+ # src/sparse_ops/sparse_bucketize_features.cu
|
||||||
|
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||||
|
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||||
|
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||||
|
+ # src/sparse_ops/sparse_group_index.cu
|
||||||
|
+ # src/sparse_ops/sparse_index_add.cu
|
||||||
|
+ # src/sparse_ops/sparse_index_select.cu
|
||||||
|
+ # src/sparse_ops/sparse_invert_permute.cu
|
||||||
|
+ # src/sparse_ops/sparse_pack_segments_backward.cu
|
||||||
|
+ # src/sparse_ops/sparse_pack_segments_forward.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute_1d.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute_2d.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute102.cu
|
||||||
|
+ # src/sparse_ops/sparse_permute_embeddings.cu
|
||||||
|
+ # src/sparse_ops/sparse_range.cu
|
||||||
|
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||||
|
+ # src/sparse_ops/sparse_segment_sum_csr.cu
|
||||||
|
+ # src/sparse_ops/sparse_zipf.cu
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_find.cu
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_populate.cu
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_find.cu
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_populate.cu
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||||
|
+ # src/split_embeddings_cache/lxu_cache.cu
|
||||||
|
+ # src/split_embeddings_cache/linearize_cache_indices.cu
|
||||||
|
+ # src/split_embeddings_cache/reset_weight_momentum.cu
|
||||||
|
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||||
|
+ # src/split_embeddings_utils/get_infos_metadata.cu
|
||||||
|
+ # src/split_embeddings_utils/radix_sort_pairs.cu
|
||||||
|
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||||
|
+ )
|
||||||
|
|
||||||
|
set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
|
||||||
|
PROPERTIES COMPILE_OPTIONS
|
||||||
|
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||||
|
index 01f1d6ab..a6b8d7a8 100644
|
||||||
|
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||||
|
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||||
|
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
|
||||||
|
${THIRDPARTY}/json/include
|
||||||
|
${NCCL_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
-set(attention_ops_sources
|
||||||
|
- src/attention/attention.cpp
|
||||||
|
- src/attention/gqa_attn_splitk.cu)
|
||||||
|
+# set(attention_ops_sources
|
||||||
|
+# src/attention/attention.cpp
|
||||||
|
+# src/attention/gqa_attn_splitk.cu)
|
||||||
|
|
||||||
|
set(quantize_ops_sources
|
||||||
|
src/quantize/cutlass_extensions.cu
|
||||||
|
src/quantize/quantize.cu
|
||||||
|
src/quantize/quantize.cpp)
|
||||||
|
|
||||||
|
-set(comm_ops_sources
|
||||||
|
- src/comm/car.cu
|
||||||
|
- src/comm/car.cpp)
|
||||||
|
+# set(comm_ops_sources
|
||||||
|
+# src/comm/car.cu
|
||||||
|
+# src/comm/car.cpp)
|
||||||
|
|
||||||
|
set(experimental_gen_ai_cpp_source_files
|
||||||
|
- ${attention_ops_sources}
|
||||||
|
+ # ${attention_ops_sources}
|
||||||
|
${quantize_ops_sources}
|
||||||
|
- ${comm_ops_sources})
|
||||||
|
+ # ${comm_ops_sources}
|
||||||
|
+)
|
||||||
|
|
||||||
|
set_source_files_properties(${experimental_gen_ai_cpp_source_files}
|
||||||
|
PROPERTIES INCLUDE_DIRECTORIES
|
11
server/fix_torch90a.sh
Executable file
11
server/fix_torch90a.sh
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash

# This script is required to patch torch < 2.4
# It adds the 90a cuda target (H100)
# This target is required to build FBGEMM kernels

# Stop at the first failure (including a failed `import torch` inside the
# pipeline below) instead of running sed against a bogus path.
set -euo pipefail

# Path of the CMake module shipped inside the installed torch package.
torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|')

if [ ! -f "$torch_cuda_arch" ]; then
    echo "fix_torch90a.sh: file not found: $torch_cuda_arch" >&2
    exit 1
fi

# The substitutions are addressed by line number (189, 245, 246), so they only
# take effect on a torch release whose select_compute_arch.cmake has the
# expected layout; on other layouts sed leaves the file unchanged.
sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' "$torch_cuda_arch"
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' "$torch_cuda_arch"
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' "$torch_cuda_arch"
|
1209
server/poetry.lock
generated
1209
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -34,14 +34,12 @@ peft = { version = "^0.10", optional = true }
|
|||||||
torch = { version = "^2.3.0", optional = true }
|
torch = { version = "^2.3.0", optional = true }
|
||||||
scipy = "^1.11.1"
|
scipy = "^1.11.1"
|
||||||
pillow = "^10.0.0"
|
pillow = "^10.0.0"
|
||||||
outlines= { version = "^0.0.46", optional = true }
|
outlines= { version = "^0.0.34", optional = true }
|
||||||
prometheus-client = "^0.20.0"
|
prometheus-client = "^0.20.0"
|
||||||
py-cpuinfo = "^9.0.0"
|
py-cpuinfo = "^9.0.0"
|
||||||
fbgemm-gpu = { version = "0.8.0rc4", optional = true }
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
torch = ["torch"]
|
torch = ["torch"]
|
||||||
cuda = ["fbgemm-gpu"]
|
|
||||||
accelerate = ["accelerate"]
|
accelerate = ["accelerate"]
|
||||||
bnb = ["bitsandbytes"]
|
bnb = ["bitsandbytes"]
|
||||||
peft = ["peft"]
|
peft = ["peft"]
|
||||||
|
@ -1,51 +1,48 @@
|
|||||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fbgemm-gpu==0.8.0rc4 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
|
@ -1,50 +1,48 @@
|
|||||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
|
@ -1,50 +1,48 @@
|
|||||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
|
||||||
|
@ -1,16 +1,19 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils.weights import Weight
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import fbgemm_gpu.experimental.gen_ai
|
import fbgemm_gpu.experimental.gen_ai
|
||||||
|
|
||||||
HAS_FBGEMM = True
|
major, _ = torch.cuda.get_device_capability()
|
||||||
|
HAS_FBGEMM_MM = major == 9
|
||||||
|
HAS_FBGEMM_DYN = major >= 8
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
HAS_FBGEMM = False
|
HAS_FBGEMM_MM = False
|
||||||
|
HAS_FBGEMM_DYN = False
|
||||||
|
|
||||||
|
|
||||||
def get_fp8_linear() -> torch.nn.Module:
|
def get_fp8_linear() -> torch.nn.Module:
|
||||||
@ -30,10 +33,7 @@ def get_fp8_linear() -> torch.nn.Module:
|
|||||||
|
|
||||||
|
|
||||||
def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
|
def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
|
||||||
if HAS_FBGEMM:
|
if HAS_FBGEMM_DYN:
|
||||||
if scale_upper_bound.device != weight.device:
|
|
||||||
scale_upper_bound = scale_upper_bound.to(weight.device)
|
|
||||||
|
|
||||||
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||||
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
||||||
)
|
)
|
||||||
@ -55,11 +55,17 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fp8Weight(Weight):
|
class Fp8Weight:
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
weight_scale: Optional[torch.Tensor] = None
|
||||||
|
input_scale: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
return get_fp8_linear()(self.weight, bias)
|
if self.weight_scale is None:
|
||||||
|
return get_fp8_linear().from_unquant(self.weight, bias)
|
||||||
|
return get_fp8_linear().from_fp8(
|
||||||
|
self.weight, self.weight_scale, self.input_scale, bias, bias.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8Linear(torch.nn.Module):
|
class Fp8Linear(torch.nn.Module):
|
||||||
@ -87,17 +93,17 @@ class Fp8Linear(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_fp8(cls, weight, bias, dtype):
|
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
||||||
return cls(
|
return cls(
|
||||||
qweight=weight.weight,
|
qweight=weight,
|
||||||
scale=weight.weight_scale,
|
scale=scale,
|
||||||
scale_upper_bound=weight.input_scale,
|
scale_upper_bound=input_scale,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
if HAS_FBGEMM:
|
if HAS_FBGEMM_MM:
|
||||||
qinput, scale = fp8_quantize(
|
qinput, scale = fp8_quantize(
|
||||||
input, scale_upper_bound=self.scale_upper_bound
|
input, scale_upper_bound=self.scale_upper_bound
|
||||||
)
|
)
|
||||||
|
@ -139,6 +139,6 @@ def get_loader(
|
|||||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
||||||
)
|
)
|
||||||
elif quantize is None:
|
elif quantize is None:
|
||||||
return DefaultWeightsLoader(UnquantizedWeight)
|
return DefaultWeightsLoader()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown quantization method: {quantize}")
|
raise ValueError(f"Unknown quantization method: {quantize}")
|
||||||
|
@ -2,15 +2,14 @@ import torch
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum, auto
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from text_generation_server.layers.fp8 import Fp8Weight
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
class WeightsLoader(ABC):
|
class WeightsLoader(ABC):
|
||||||
"""
|
"""
|
||||||
@ -101,7 +100,7 @@ class UnquantizedWeight:
|
|||||||
class DefaultWeightsLoader(WeightsLoader):
|
class DefaultWeightsLoader(WeightsLoader):
|
||||||
"""Weight loader that loads (unquantized) Torch tensors."""
|
"""Weight loader that loads (unquantized) Torch tensors."""
|
||||||
|
|
||||||
def __init__(self, weight_class):
|
def __init__(self, weight_class: Optional = None):
|
||||||
"""Create a loader. Weights will be wrapped using the given `weights_class`,
|
"""Create a loader. Weights will be wrapped using the given `weights_class`,
|
||||||
normally this will be `UnquantizedWeight`, but a quantizer-specific class
|
normally this will be `UnquantizedWeight`, but a quantizer-specific class
|
||||||
such as `Fp8Weight` can be used to quantize the weights during loading.
|
such as `Fp8Weight` can be used to quantize the weights during loading.
|
||||||
@ -122,51 +121,63 @@ class DefaultWeightsLoader(WeightsLoader):
|
|||||||
prefix: str,
|
prefix: str,
|
||||||
block_sizes: Union[int, List[int]],
|
block_sizes: Union[int, List[int]],
|
||||||
):
|
):
|
||||||
|
|
||||||
return self.weight_class(
|
|
||||||
weights.get_packed_sharded(
|
|
||||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
w = weights.get_packed_sharded(
|
w = weights.get_packed_sharded(
|
||||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||||
)
|
)
|
||||||
# FP8 branch
|
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# FP8 branch
|
||||||
scale = weights.get_packed_sharded(
|
scale = weights.get_packed_sharded(
|
||||||
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
|
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
|
||||||
)
|
)
|
||||||
input_scale = weights.get_tensor(f"{prefix}.input_scale")
|
input_scale = weights.get_tensor(f"{prefix}.input_scale")
|
||||||
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||||
return w
|
|
||||||
|
if self.weight_class is None:
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
return self.weight_class(w)
|
||||||
|
|
||||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
|
||||||
return self.weight_class(torch.cat(w, dim=dim))
|
|
||||||
|
|
||||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||||
w = torch.cat(w, dim=dim)
|
w = torch.cat(w, dim=dim)
|
||||||
|
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||||
|
)
|
||||||
|
|
||||||
scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
|
scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
|
||||||
scale = torch.cat(scale, dim=0)
|
scale = torch.cat(scale, dim=0)
|
||||||
input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
|
input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
|
||||||
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||||
return w
|
|
||||||
|
if self.weight_class is None:
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
return self.weight_class(w)
|
||||||
|
|
||||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||||
return self.weight_class(
|
|
||||||
weights.get_sharded(f"{prefix}.weight", dim=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||||
|
)
|
||||||
|
|
||||||
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
|
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
|
||||||
input_scale = weights.get_tensor(f"{prefix}.input_scale")
|
input_scale = weights.get_tensor(f"{prefix}.input_scale")
|
||||||
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||||
return w
|
|
||||||
|
if self.weight_class is None:
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
return self.weight_class(w)
|
||||||
|
|
||||||
|
|
||||||
class Weights:
|
class Weights:
|
||||||
|
Loading…
Reference in New Issue
Block a user