mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
build fbgemm
This commit is contained in:
parent
80087783a5
commit
985df12c46
14
Dockerfile
14
Dockerfile
@ -161,6 +161,14 @@ COPY server/custom_kernels/ .
|
||||
# Build specific version of transformers
|
||||
RUN python setup.py build
|
||||
|
||||
# Build FBGEMM CUDA kernels
FROM kernel-builder AS fbgemm-builder

WORKDIR /usr/src

# `make build-fbgemm` executes ./fix_torch90a.sh and applies
# fbgemm_remove_unused.patch from the working directory, so both files must be
# copied next to the Makefile or the build fails on its first command.
COPY server/Makefile-fbgemm Makefile
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
COPY server/fix_torch90a.sh fix_torch90a.sh

RUN make build-fbgemm
|
||||
|
||||
# Build vllm CUDA kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
|
||||
@ -225,10 +233,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
|
||||
# Copy build artifacts from marlin kernels builder
|
||||
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
# Copy build artifacts from fbgemm builder
|
||||
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy build artifacts from mamba builder
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
|
@ -5,6 +5,7 @@ include Makefile-awq
|
||||
include Makefile-eetq
|
||||
include Makefile-selective-scan
|
||||
include Makefile-lorax-punica
|
||||
include Makefile-fbgemm
|
||||
|
||||
unit-tests:
|
||||
pytest -s -vv -m "not private" tests
|
||||
@ -20,16 +21,15 @@ gen-server:
|
||||
|
||||
install-server: gen-server
|
||||
pip install pip --upgrade
|
||||
pip install -r requirements_intel.txt
|
||||
pip install -r requirements_cuda.txt
|
||||
pip install -e ".[accelerate, quantize, peft, outlines]"
|
||||
|
||||
|
||||
install: install-cuda
|
||||
echo "Installed server"
|
||||
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
|
||||
pip install -r requirements_cuda.txt
|
||||
pip install -e ".[cuda, bnb]"
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
|
||||
pip install -e ".[bnb]"
|
||||
|
||||
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
||||
|
||||
@ -37,6 +37,6 @@ run-dev:
|
||||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||
|
||||
export-requirements:
|
||||
poetry export -o requirements_cuda.txt --without-hashes -E cuda
|
||||
poetry export -o requirements_cuda.txt --without-hashes
|
||||
poetry export -o requirements_rocm.txt --without-hashes
|
||||
poetry export -o requirements_intel.txt --without-hashes
|
||||
|
15
server/Makefile-fbgemm
Normal file
15
server/Makefile-fbgemm
Normal file
@ -0,0 +1,15 @@
|
||||
# Pinned FBGEMM revision; fbgemm_remove_unused.patch is applied on top of it.
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca

# CUDA architecture selection shared by the build and install steps
# (sm_80 and sm_90a), kept in one place so both steps always agree.
fbgemm_build_env := CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a"

# Neither target names a file it produces, so they must always run.
.PHONY: build-fbgemm install-fbgemm

# Patch torch's CMake arch detection, fetch FBGEMM at the pinned commit,
# strip the unused sources and build the genai package variant.
# Expects fix_torch90a.sh and fbgemm_remove_unused.patch in the working directory.
# Safe to re-run: an existing checkout is reused and an already-applied patch
# is detected instead of failing.
build-fbgemm:
	chmod +x fix_torch90a.sh && ./fix_torch90a.sh
	if [ ! -d fbgemm/.git ]; then git clone https://github.com/pytorch/FBGEMM.git fbgemm; fi
	cp fbgemm_remove_unused.patch fbgemm
	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
	  { git apply --reverse --check fbgemm_remove_unused.patch 2>/dev/null || git apply fbgemm_remove_unused.patch; } && \
	  git submodule update --init --recursive && \
	  cd fbgemm_gpu && \
	  pip install -r requirements.txt && \
	  $(fbgemm_build_env) python setup.py --package_variant genai build

# Install the genai variant into the active Python environment.
# Depends on build-fbgemm so the checkout and build tree exist first.
install-fbgemm: build-fbgemm
	cd fbgemm/fbgemm_gpu && \
	  $(fbgemm_build_env) python setup.py --package_variant genai install
|
306
server/fbgemm_remove_unused.patch
Normal file
306
server/fbgemm_remove_unused.patch
Normal file
@ -0,0 +1,306 @@
|
||||
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
|
||||
index 2244ea6f..96265a48 100644
|
||||
--- a/fbgemm_gpu/CMakeLists.txt
|
||||
+++ b/fbgemm_gpu/CMakeLists.txt
|
||||
@@ -94,14 +94,14 @@ endif()
|
||||
# Build Experimental Modules
|
||||
################################################################################
|
||||
|
||||
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||
- # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||
- add_subdirectory(experimental/example)
|
||||
-endif()
|
||||
-
|
||||
-if(NOT FBGEMM_CPU_ONLY)
|
||||
- add_subdirectory(experimental/gemm)
|
||||
-endif()
|
||||
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||
+# # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||
+# add_subdirectory(experimental/example)
|
||||
+# endif()
|
||||
+
|
||||
+# if(NOT FBGEMM_CPU_ONLY)
|
||||
+# add_subdirectory(experimental/gemm)
|
||||
+# endif()
|
||||
|
||||
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
|
||||
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
|
||||
index c56773fe..0c0d349e 100644
|
||||
--- a/fbgemm_gpu/FbgemmGpu.cmake
|
||||
+++ b/fbgemm_gpu/FbgemmGpu.cmake
|
||||
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
|
||||
################################################################################
|
||||
|
||||
set(fbgemm_gpu_sources_static_cpu
|
||||
- codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||
- codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||
- src/input_combine_ops/input_combine_cpu.cpp
|
||||
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||
+ # src/input_combine_ops/input_combine_cpu.cpp
|
||||
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||
src/quantize_ops/quantize_ops_cpu.cpp
|
||||
src/quantize_ops/quantize_ops_meta.cpp
|
||||
- src/sparse_ops/sparse_ops_cpu.cpp
|
||||
- src/sparse_ops/sparse_ops_meta.cpp
|
||||
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||
- src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||
- src/split_embeddings_cache/lxu_cache.cpp
|
||||
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||
+ # src/sparse_ops/sparse_ops_cpu.cpp
|
||||
+ # src/sparse_ops/sparse_ops_meta.cpp
|
||||
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||
+ # src/split_embeddings_cache/lxu_cache.cpp
|
||||
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||
+)
|
||||
|
||||
if(NOT FBGEMM_CPU_ONLY)
|
||||
list(APPEND fbgemm_gpu_sources_static_cpu
|
||||
- codegen/inference/embedding_forward_quantized_host.cpp
|
||||
- codegen/utils/embedding_bounds_check_host.cpp
|
||||
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||
- src/memory_utils/memory_utils.cpp
|
||||
- src/memory_utils/memory_utils_ops.cpp
|
||||
- src/memory_utils/memory_utils_ops_cpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||
+ # codegen/inference/embedding_forward_quantized_host.cpp
|
||||
+ # codegen/utils/embedding_bounds_check_host.cpp
|
||||
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||
+ # src/memory_utils/memory_utils.cpp
|
||||
+ # src/memory_utils/memory_utils_ops.cpp
|
||||
+ # src/memory_utils/memory_utils_ops_cpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||
src/quantize_ops/quantize_ops_gpu.cpp
|
||||
- src/sparse_ops/sparse_ops_gpu.cpp
|
||||
- src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||
- src/metric_ops/metric_ops_host.cpp
|
||||
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||
- src/input_combine_ops/input_combine_gpu.cpp
|
||||
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||
+ # src/sparse_ops/sparse_ops_gpu.cpp
|
||||
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||
+ # src/metric_ops/metric_ops_host.cpp
|
||||
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||
+ # src/input_combine_ops/input_combine_gpu.cpp
|
||||
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||
+ )
|
||||
|
||||
if(NVML_LIB_PATH OR USE_ROCM)
|
||||
message(STATUS "Adding merge_pooled_embeddings sources")
|
||||
@@ -516,36 +518,36 @@ endif()
|
||||
|
||||
if(NOT FBGEMM_CPU_ONLY)
|
||||
set(fbgemm_gpu_sources_static_gpu
|
||||
- codegen/utils/embedding_bounds_check.cu
|
||||
- codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||
- src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||
- src/histogram_binning_calibration_ops.cu
|
||||
- src/input_combine_ops/input_combine.cu
|
||||
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||
- src/memory_utils/memory_utils.cu
|
||||
- src/memory_utils/memory_utils_ops.cu
|
||||
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||
- src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||
- src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||
- src/layout_transform_ops/layout_transform_ops.cu
|
||||
- src/metric_ops/metric_ops.cu
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||
+ # codegen/utils/embedding_bounds_check.cu
|
||||
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||
+ # src/histogram_binning_calibration_ops.cu
|
||||
+ # src/input_combine_ops/input_combine.cu
|
||||
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||
+ # src/memory_utils/memory_utils.cu
|
||||
+ # src/memory_utils/memory_utils_ops.cu
|
||||
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||
+ # src/layout_transform_ops/layout_transform_ops.cu
|
||||
+ # src/metric_ops/metric_ops.cu
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||
src/quantize_ops/quantize_bfloat16.cu
|
||||
src/quantize_ops/quantize_fp8_rowwise.cu
|
||||
src/quantize_ops/quantize_fused_8bit_rowwise.cu
|
||||
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
|
||||
src/quantize_ops/quantize_msfp.cu
|
||||
src/quantize_ops/quantize_padded_fp8_rowwise.cu
|
||||
src/quantize_ops/quantize_mx.cu
|
||||
- src/sparse_ops/sparse_async_cumsum.cu
|
||||
- src/sparse_ops/sparse_block_bucketize_features.cu
|
||||
- src/sparse_ops/sparse_bucketize_features.cu
|
||||
- src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||
- src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||
- src/sparse_ops/sparse_group_index.cu
|
||||
- src/sparse_ops/sparse_index_add.cu
|
||||
- src/sparse_ops/sparse_index_select.cu
|
||||
- src/sparse_ops/sparse_invert_permute.cu
|
||||
- src/sparse_ops/sparse_pack_segments_backward.cu
|
||||
- src/sparse_ops/sparse_pack_segments_forward.cu
|
||||
- src/sparse_ops/sparse_permute_1d.cu
|
||||
- src/sparse_ops/sparse_permute_2d.cu
|
||||
- src/sparse_ops/sparse_permute102.cu
|
||||
- src/sparse_ops/sparse_permute_embeddings.cu
|
||||
- src/sparse_ops/sparse_range.cu
|
||||
- src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||
- src/sparse_ops/sparse_segment_sum_csr.cu
|
||||
- src/sparse_ops/sparse_zipf.cu
|
||||
- src/split_embeddings_cache/lfu_cache_find.cu
|
||||
- src/split_embeddings_cache/lfu_cache_populate.cu
|
||||
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||
- src/split_embeddings_cache/lru_cache_find.cu
|
||||
- src/split_embeddings_cache/lru_cache_populate.cu
|
||||
- src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||
- src/split_embeddings_cache/lxu_cache.cu
|
||||
- src/split_embeddings_cache/linearize_cache_indices.cu
|
||||
- src/split_embeddings_cache/reset_weight_momentum.cu
|
||||
- src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||
- src/split_embeddings_utils/get_infos_metadata.cu
|
||||
- src/split_embeddings_utils/radix_sort_pairs.cu
|
||||
- src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||
+ # src/sparse_ops/sparse_async_cumsum.cu
|
||||
+ # src/sparse_ops/sparse_block_bucketize_features.cu
|
||||
+ # src/sparse_ops/sparse_bucketize_features.cu
|
||||
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
|
||||
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
|
||||
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
|
||||
+ # src/sparse_ops/sparse_group_index.cu
|
||||
+ # src/sparse_ops/sparse_index_add.cu
|
||||
+ # src/sparse_ops/sparse_index_select.cu
|
||||
+ # src/sparse_ops/sparse_invert_permute.cu
|
||||
+ # src/sparse_ops/sparse_pack_segments_backward.cu
|
||||
+ # src/sparse_ops/sparse_pack_segments_forward.cu
|
||||
+ # src/sparse_ops/sparse_permute_1d.cu
|
||||
+ # src/sparse_ops/sparse_permute_2d.cu
|
||||
+ # src/sparse_ops/sparse_permute102.cu
|
||||
+ # src/sparse_ops/sparse_permute_embeddings.cu
|
||||
+ # src/sparse_ops/sparse_range.cu
|
||||
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
|
||||
+ # src/sparse_ops/sparse_segment_sum_csr.cu
|
||||
+ # src/sparse_ops/sparse_zipf.cu
|
||||
+ # src/split_embeddings_cache/lfu_cache_find.cu
|
||||
+ # src/split_embeddings_cache/lfu_cache_populate.cu
|
||||
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
|
||||
+ # src/split_embeddings_cache/lru_cache_find.cu
|
||||
+ # src/split_embeddings_cache/lru_cache_populate.cu
|
||||
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
|
||||
+ # src/split_embeddings_cache/lxu_cache.cu
|
||||
+ # src/split_embeddings_cache/linearize_cache_indices.cu
|
||||
+ # src/split_embeddings_cache/reset_weight_momentum.cu
|
||||
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
|
||||
+ # src/split_embeddings_utils/get_infos_metadata.cu
|
||||
+ # src/split_embeddings_utils/radix_sort_pairs.cu
|
||||
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
|
||||
+ )
|
||||
|
||||
set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
|
||||
PROPERTIES COMPILE_OPTIONS
|
||||
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||
index 01f1d6ab..a6b8d7a8 100644
|
||||
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
|
||||
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
|
||||
${THIRDPARTY}/json/include
|
||||
${NCCL_INCLUDE_DIRS})
|
||||
|
||||
-set(attention_ops_sources
|
||||
- src/attention/attention.cpp
|
||||
- src/attention/gqa_attn_splitk.cu)
|
||||
+# set(attention_ops_sources
|
||||
+# src/attention/attention.cpp
|
||||
+# src/attention/gqa_attn_splitk.cu)
|
||||
|
||||
set(quantize_ops_sources
|
||||
src/quantize/cutlass_extensions.cu
|
||||
src/quantize/quantize.cu
|
||||
src/quantize/quantize.cpp)
|
||||
|
||||
-set(comm_ops_sources
|
||||
- src/comm/car.cu
|
||||
- src/comm/car.cpp)
|
||||
+# set(comm_ops_sources
|
||||
+# src/comm/car.cu
|
||||
+# src/comm/car.cpp)
|
||||
|
||||
set(experimental_gen_ai_cpp_source_files
|
||||
- ${attention_ops_sources}
|
||||
+ # ${attention_ops_sources}
|
||||
${quantize_ops_sources}
|
||||
- ${comm_ops_sources})
|
||||
+ # ${comm_ops_sources}
|
||||
+)
|
||||
|
||||
set_source_files_properties(${experimental_gen_ai_cpp_source_files}
|
||||
PROPERTIES INCLUDE_DIRECTORIES
|
11
server/fix_torch90a.sh
Executable file
11
server/fix_torch90a.sh
Executable file
@ -0,0 +1,11 @@
|
||||
#!/bin/bash

# This script is required to patch torch < 2.4
# It adds the 90a cuda target (H100)
# This target is required to build FBGEMM kernels

# Stop at the first failure (e.g. torch not importable) instead of running
# sed against a bogus path and reporting only the last command's status.
set -euo pipefail

# Locate select_compute_arch.cmake inside the installed torch package.
torch_dir=$(python -c "import os, torch; print(os.path.dirname(torch.__file__))")
torch_cuda_arch="${torch_dir}/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake"

if [ ! -f "$torch_cuda_arch" ]; then
    echo "fix_torch90a.sh: file not found: $torch_cuda_arch" >&2
    exit 1
fi

# The edits are addressed by line number (189, 245, 246) and allow an optional
# "a" suffix in the architecture patterns found on those lines.
# NOTE(review): the line numbers are tied to the torch release being patched — confirm on torch upgrades.
sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' "$torch_cuda_arch"
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' "$torch_cuda_arch"
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' "$torch_cuda_arch"
|
1209
server/poetry.lock
generated
1209
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -34,14 +34,12 @@ peft = { version = "^0.10", optional = true }
|
||||
torch = { version = "^2.3.0", optional = true }
|
||||
scipy = "^1.11.1"
|
||||
pillow = "^10.0.0"
|
||||
outlines= { version = "^0.0.46", optional = true }
|
||||
outlines= { version = "^0.0.34", optional = true }
|
||||
prometheus-client = "^0.20.0"
|
||||
py-cpuinfo = "^9.0.0"
|
||||
fbgemm-gpu = { version = "0.8.0rc4", optional = true }
|
||||
|
||||
[tool.poetry.extras]
|
||||
torch = ["torch"]
|
||||
cuda = ["fbgemm-gpu"]
|
||||
accelerate = ["accelerate"]
|
||||
bnb = ["bitsandbytes"]
|
||||
peft = ["peft"]
|
||||
|
@ -1,51 +1,48 @@
|
||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fbgemm-gpu==0.8.0rc4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -1,50 +1,48 @@
|
||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -1,50 +1,48 @@
|
||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==71.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -1,16 +1,19 @@
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import Weight
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai
|
||||
|
||||
HAS_FBGEMM = True
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
HAS_FBGEMM_MM = major == 9
|
||||
HAS_FBGEMM_DYN = major >= 8
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
HAS_FBGEMM = False
|
||||
HAS_FBGEMM_MM = False
|
||||
HAS_FBGEMM_DYN = False
|
||||
|
||||
|
||||
def get_fp8_linear() -> torch.nn.Module:
|
||||
@ -30,10 +33,7 @@ def get_fp8_linear() -> torch.nn.Module:
|
||||
|
||||
|
||||
def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
|
||||
if HAS_FBGEMM:
|
||||
if scale_upper_bound.device != weight.device:
|
||||
scale_upper_bound = scale_upper_bound.to(weight.device)
|
||||
|
||||
if HAS_FBGEMM_DYN:
|
||||
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
||||
)
|
||||
@ -55,11 +55,17 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fp8Weight(Weight):
|
||||
class Fp8Weight:
|
||||
weight: torch.Tensor
|
||||
weight_scale: Optional[torch.Tensor] = None
|
||||
input_scale: Optional[torch.Tensor] = None
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return get_fp8_linear()(self.weight, bias)
|
||||
if self.weight_scale is None:
|
||||
return get_fp8_linear().from_unquant(self.weight, bias)
|
||||
return get_fp8_linear().from_fp8(
|
||||
self.weight, self.weight_scale, self.input_scale, bias, bias.dtype
|
||||
)
|
||||
|
||||
|
||||
class Fp8Linear(torch.nn.Module):
|
||||
@ -87,17 +93,17 @@ class Fp8Linear(torch.nn.Module):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_fp8(cls, weight, bias, dtype):
|
||||
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
||||
return cls(
|
||||
qweight=weight.weight,
|
||||
scale=weight.weight_scale,
|
||||
scale_upper_bound=weight.input_scale,
|
||||
qweight=weight,
|
||||
scale=scale,
|
||||
scale_upper_bound=input_scale,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if HAS_FBGEMM:
|
||||
if HAS_FBGEMM_MM:
|
||||
qinput, scale = fp8_quantize(
|
||||
input, scale_upper_bound=self.scale_upper_bound
|
||||
)
|
||||
|
@ -139,6 +139,6 @@ def get_loader(
|
||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
||||
)
|
||||
elif quantize is None:
|
||||
return DefaultWeightsLoader(UnquantizedWeight)
|
||||
return DefaultWeightsLoader()
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization method: {quantize}")
|
||||
|
@ -2,15 +2,14 @@ import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from safetensors import safe_open
|
||||
from dataclasses import dataclass
|
||||
|
||||
from text_generation_server.layers.fp8 import Fp8Weight
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
class WeightsLoader(ABC):
|
||||
"""
|
||||
@ -101,7 +100,7 @@ class UnquantizedWeight:
|
||||
class DefaultWeightsLoader(WeightsLoader):
|
||||
"""Weight loader that loads (unquantized) Torch tensors."""
|
||||
|
||||
def __init__(self, weight_class):
|
||||
def __init__(self, weight_class: Optional = None):
|
||||
"""Create a loader. Weights will be wrapped using the given `weights_class`,
|
||||
normally this will be `UnquantizedWeight`, but a quantizer-specific class
|
||||
such as `Fp8Weight` can be used to quantize the weights during loading.
|
||||
@ -122,51 +121,63 @@ class DefaultWeightsLoader(WeightsLoader):
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
|
||||
return self.weight_class(
|
||||
weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
),
|
||||
)
|
||||
|
||||
w = weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
# FP8 branch
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||
raise RuntimeError(
|
||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||
)
|
||||
|
||||
# FP8 branch
|
||||
scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
input_scale = weights.get_tensor(f"{prefix}.input_scale")
|
||||
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||
return w
|
||||
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||
|
||||
if self.weight_class is None:
|
||||
return UnquantizedWeight(w)
|
||||
return self.weight_class(w)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
return self.weight_class(torch.cat(w, dim=dim))
|
||||
|
||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
w = torch.cat(w, dim=dim)
|
||||
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||
raise RuntimeError(
|
||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||
)
|
||||
|
||||
scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
|
||||
scale = torch.cat(scale, dim=0)
|
||||
input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
|
||||
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||
return w
|
||||
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||
|
||||
if self.weight_class is None:
|
||||
return UnquantizedWeight(w)
|
||||
return self.weight_class(w)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
return self.weight_class(
|
||||
weights.get_sharded(f"{prefix}.weight", dim=1),
|
||||
)
|
||||
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_class is not None and self.weight_class != Fp8Weight:
|
||||
raise RuntimeError(
|
||||
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
|
||||
)
|
||||
|
||||
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
|
||||
input_scale = weights.get_tensor(f"{prefix}.input_scale")
|
||||
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||
return w
|
||||
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
|
||||
|
||||
if self.weight_class is None:
|
||||
return UnquantizedWeight(w)
|
||||
return self.weight_class(w)
|
||||
|
||||
|
||||
class Weights:
|
||||
|
Loading…
Reference in New Issue
Block a user