diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index abe161db..6c968053 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -183,4 +183,3 @@ jobs: export HF_TOKEN=${{ secrets.HF_TOKEN }} echo $DOCKER_IMAGE pytest -s -vv integration-tests ${PYTEST_FLAGS} - diff --git a/Dockerfile b/Dockerfile index 3f2e8ef0..54ddd5ef 100644 --- a/Dockerfile +++ b/Dockerfile @@ -161,6 +161,17 @@ COPY server/custom_kernels/ . # Build specific version of transformers RUN python setup.py build +# Build FBGEMM CUDA kernels +FROM kernel-builder AS fbgemm-builder + +WORKDIR /usr/src + +COPY server/Makefile-fbgemm Makefile +COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch +COPY server/fix_torch90a.sh fix_torch90a.sh + +RUN make build-fbgemm + # Build vllm CUDA kernels FROM kernel-builder AS vllm-builder @@ -225,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 # Copy build artifacts from marlin kernels builder COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - -# Copy builds artifacts from vllm builder +# Copy build artifacts from fbgemm builder +COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages diff --git a/server/Makefile b/server/Makefile index 33940655..209fc44e 100644 --- a/server/Makefile +++ b/server/Makefile @@ -5,6 +5,7 @@ include Makefile-awq include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica +include Makefile-fbgemm unit-tests: pytest -s -vv -m "not private" tests @@ -27,8 +28,9 @@ install-server: gen-server install: install-cuda echo "Installed server" -install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention +install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm pip install -e ".[bnb]" + pip install nvidia-nccl-cu12==2.22.3 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm @@ -36,5 +38,6 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements_cuda.txt --without-hashes --with cuda + poetry export -o requirements_cuda.txt --without-hashes poetry export -o requirements_rocm.txt --without-hashes + poetry export -o requirements_intel.txt --without-hashes diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm new file mode 100644 index 00000000..38f8f31f --- /dev/null +++ b/server/Makefile-fbgemm @@ -0,0 +1,15 @@ +fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca + +build-fbgemm: + chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \ + git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ + cp fbgemm_remove_unused.patch 
fbgemm && \ + cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \ + git submodule update --init --recursive && \ + cd fbgemm_gpu && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build + +install-fbgemm: build-fbgemm + cd fbgemm/fbgemm_gpu && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install diff --git a/server/fbgemm_remove_unused.patch b/server/fbgemm_remove_unused.patch new file mode 100644 index 00000000..ad6af811 --- /dev/null +++ b/server/fbgemm_remove_unused.patch @@ -0,0 +1,306 @@ +diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt +index 2244ea6f..96265a48 100644 +--- a/fbgemm_gpu/CMakeLists.txt ++++ b/fbgemm_gpu/CMakeLists.txt +@@ -94,14 +94,14 @@ endif() + # Build Experimental Modules + ################################################################################ + +-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) +- # TODO: Figure out NCCL/RCCL integration with ROCm +- add_subdirectory(experimental/example) +-endif() +- +-if(NOT FBGEMM_CPU_ONLY) +- add_subdirectory(experimental/gemm) +-endif() ++# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) ++# # TODO: Figure out NCCL/RCCL integration with ROCm ++# add_subdirectory(experimental/example) ++# endif() ++ ++# if(NOT FBGEMM_CPU_ONLY) ++# add_subdirectory(experimental/gemm) ++# endif() + + if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) + # CUTLASS currently doesn't build on ROCm and CK hasnt yet been added: +diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake +index c56773fe..0c0d349e 100644 +--- a/fbgemm_gpu/FbgemmGpu.cmake ++++ b/fbgemm_gpu/FbgemmGpu.cmake +@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources} + ################################################################################ + + set(fbgemm_gpu_sources_static_cpu +- codegen/training/forward/embedding_forward_split_cpu.cpp +- codegen/inference/embedding_forward_quantized_host_cpu.cpp +- codegen/training/backward/embedding_backward_dense_host_cpu.cpp +- codegen/utils/embedding_bounds_check_host_cpu.cpp +- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +- src/input_combine_ops/input_combine_cpu.cpp +- src/layout_transform_ops/layout_transform_ops_cpu.cpp ++ # codegen/training/forward/embedding_forward_split_cpu.cpp ++ # codegen/inference/embedding_forward_quantized_host_cpu.cpp ++ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp ++ # codegen/utils/embedding_bounds_check_host_cpu.cpp ++ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp ++ # 
src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp ++ # src/input_combine_ops/input_combine_cpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_cpu.cpp + src/quantize_ops/quantize_ops_cpu.cpp + src/quantize_ops/quantize_ops_meta.cpp +- src/sparse_ops/sparse_ops_cpu.cpp +- src/sparse_ops/sparse_ops_meta.cpp +- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp +- src/split_embeddings_cache/linearize_cache_indices.cpp +- src/split_embeddings_cache/lfu_cache_populate_byte.cpp +- src/split_embeddings_cache/lru_cache_populate_byte.cpp +- src/split_embeddings_cache/lxu_cache.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++ # src/sparse_ops/sparse_ops_cpu.cpp ++ # src/sparse_ops/sparse_ops_meta.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp ++ # src/split_embeddings_cache/linearize_cache_indices.cpp ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lru_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lxu_cache.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++) + + if(NOT FBGEMM_CPU_ONLY) + list(APPEND fbgemm_gpu_sources_static_cpu +- codegen/inference/embedding_forward_quantized_host.cpp +- codegen/utils/embedding_bounds_check_host.cpp +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp +- src/layout_transform_ops/layout_transform_ops_gpu.cpp +- src/memory_utils/memory_utils.cpp +- src/memory_utils/memory_utils_ops.cpp +- src/memory_utils/memory_utils_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp ++ # codegen/inference/embedding_forward_quantized_host.cpp ++ # codegen/utils/embedding_bounds_check_host.cpp ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_gpu.cpp ++ # src/memory_utils/memory_utils.cpp ++ # src/memory_utils/memory_utils_ops.cpp ++ # src/memory_utils/memory_utils_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp + src/quantize_ops/quantize_ops_gpu.cpp +- src/sparse_ops/sparse_ops_gpu.cpp +- src/split_embeddings_utils/split_embeddings_utils.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cu +- src/metric_ops/metric_ops_host.cpp +- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp +- src/input_combine_ops/input_combine_gpu.cpp +- codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ # src/sparse_ops/sparse_ops_gpu.cpp ++ # src/split_embeddings_utils/split_embeddings_utils.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cu ++ # src/metric_ops/metric_ops_host.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp ++ # src/input_combine_ops/input_combine_gpu.cpp ++ 
# codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ ) + + if(NVML_LIB_PATH OR USE_ROCM) + message(STATUS "Adding merge_pooled_embeddings sources") +@@ -516,36 +518,36 @@ endif() + + if(NOT FBGEMM_CPU_ONLY) + set(fbgemm_gpu_sources_static_gpu +- codegen/utils/embedding_bounds_check.cu +- codegen/inference/embedding_forward_quantized_split_lookup.cu +- src/embedding_inplace_ops/embedding_inplace_update.cu +- src/histogram_binning_calibration_ops.cu +- src/input_combine_ops/input_combine.cu +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu +- src/memory_utils/memory_utils.cu +- src/memory_utils/memory_utils_ops.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu +- src/jagged_tensor_ops/dense_to_jagged_forward.cu +- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu +- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu +- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu +- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu +- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu +- src/jagged_tensor_ops/jagged_softmax_backward.cu +- src/jagged_tensor_ops/jagged_softmax_forward.cu +- src/jagged_tensor_ops/jagged_tensor_ops.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu +- src/jagged_tensor_ops/jagged_unique_indices.cu +- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +- src/layout_transform_ops/layout_transform_ops.cu +- src/metric_ops/metric_ops.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu +- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu ++ # codegen/utils/embedding_bounds_check.cu ++ # codegen/inference/embedding_forward_quantized_split_lookup.cu ++ # src/embedding_inplace_ops/embedding_inplace_update.cu ++ # src/histogram_binning_calibration_ops.cu ++ # src/input_combine_ops/input_combine.cu ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu ++ # src/memory_utils/memory_utils.cu ++ # src/memory_utils/memory_utils_ops.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu ++ # src/jagged_tensor_ops/dense_to_jagged_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu ++ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_softmax_backward.cu ++ # src/jagged_tensor_ops/jagged_softmax_forward.cu ++ # src/jagged_tensor_ops/jagged_tensor_ops.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu ++ # src/jagged_tensor_ops/jagged_unique_indices.cu ++ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu ++ # src/layout_transform_ops/layout_transform_ops.cu ++ # src/metric_ops/metric_ops.cu ++ # 
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu + src/quantize_ops/quantize_bfloat16.cu + src/quantize_ops/quantize_fp8_rowwise.cu + src/quantize_ops/quantize_fused_8bit_rowwise.cu +@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY) + src/quantize_ops/quantize_msfp.cu + src/quantize_ops/quantize_padded_fp8_rowwise.cu + src/quantize_ops/quantize_mx.cu +- src/sparse_ops/sparse_async_cumsum.cu +- src/sparse_ops/sparse_block_bucketize_features.cu +- src/sparse_ops/sparse_bucketize_features.cu +- src/sparse_ops/sparse_batched_unary_embeddings.cu +- src/sparse_ops/sparse_compute_frequency_sequence.cu +- src/sparse_ops/sparse_expand_into_jagged_permute.cu +- src/sparse_ops/sparse_group_index.cu +- src/sparse_ops/sparse_index_add.cu +- src/sparse_ops/sparse_index_select.cu +- src/sparse_ops/sparse_invert_permute.cu +- src/sparse_ops/sparse_pack_segments_backward.cu +- src/sparse_ops/sparse_pack_segments_forward.cu +- src/sparse_ops/sparse_permute_1d.cu +- src/sparse_ops/sparse_permute_2d.cu +- src/sparse_ops/sparse_permute102.cu +- src/sparse_ops/sparse_permute_embeddings.cu +- src/sparse_ops/sparse_range.cu +- src/sparse_ops/sparse_reorder_batched_ad.cu +- src/sparse_ops/sparse_segment_sum_csr.cu +- src/sparse_ops/sparse_zipf.cu +- src/split_embeddings_cache/lfu_cache_find.cu +- src/split_embeddings_cache/lfu_cache_populate.cu +- src/split_embeddings_cache/lfu_cache_populate_byte.cu +- src/split_embeddings_cache/lru_cache_find.cu +- src/split_embeddings_cache/lru_cache_populate.cu +- src/split_embeddings_cache/lru_cache_populate_byte.cu +- src/split_embeddings_cache/lxu_cache.cu +- src/split_embeddings_cache/linearize_cache_indices.cu +- src/split_embeddings_cache/reset_weight_momentum.cu +- src/split_embeddings_utils/generate_vbe_metadata.cu +- src/split_embeddings_utils/get_infos_metadata.cu +- src/split_embeddings_utils/radix_sort_pairs.cu +- src/split_embeddings_utils/transpose_embedding_input.cu) ++ # src/sparse_ops/sparse_async_cumsum.cu ++ # src/sparse_ops/sparse_block_bucketize_features.cu ++ # src/sparse_ops/sparse_bucketize_features.cu ++ # src/sparse_ops/sparse_batched_unary_embeddings.cu ++ # src/sparse_ops/sparse_compute_frequency_sequence.cu ++ # src/sparse_ops/sparse_expand_into_jagged_permute.cu ++ # src/sparse_ops/sparse_group_index.cu ++ # src/sparse_ops/sparse_index_add.cu ++ # src/sparse_ops/sparse_index_select.cu ++ # src/sparse_ops/sparse_invert_permute.cu ++ # src/sparse_ops/sparse_pack_segments_backward.cu ++ # src/sparse_ops/sparse_pack_segments_forward.cu ++ # src/sparse_ops/sparse_permute_1d.cu ++ # src/sparse_ops/sparse_permute_2d.cu ++ # src/sparse_ops/sparse_permute102.cu ++ # src/sparse_ops/sparse_permute_embeddings.cu ++ # src/sparse_ops/sparse_range.cu ++ # src/sparse_ops/sparse_reorder_batched_ad.cu ++ # src/sparse_ops/sparse_segment_sum_csr.cu ++ # src/sparse_ops/sparse_zipf.cu ++ # src/split_embeddings_cache/lfu_cache_find.cu ++ # src/split_embeddings_cache/lfu_cache_populate.cu ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cu ++ # src/split_embeddings_cache/lru_cache_find.cu ++ # src/split_embeddings_cache/lru_cache_populate.cu ++ # src/split_embeddings_cache/lru_cache_populate_byte.cu ++ # src/split_embeddings_cache/lxu_cache.cu ++ # src/split_embeddings_cache/linearize_cache_indices.cu ++ # src/split_embeddings_cache/reset_weight_momentum.cu ++ # 
src/split_embeddings_utils/generate_vbe_metadata.cu ++ # src/split_embeddings_utils/get_infos_metadata.cu ++ # src/split_embeddings_utils/radix_sort_pairs.cu ++ # src/split_embeddings_utils/transpose_embedding_input.cu) ++ ) + + set_source_files_properties(${fbgemm_gpu_sources_static_gpu} + PROPERTIES COMPILE_OPTIONS +diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +index 01f1d6ab..a6b8d7a8 100644 +--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt ++++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories + ${THIRDPARTY}/json/include + ${NCCL_INCLUDE_DIRS}) + +-set(attention_ops_sources +- src/attention/attention.cpp +- src/attention/gqa_attn_splitk.cu) ++# set(attention_ops_sources ++# src/attention/attention.cpp ++# src/attention/gqa_attn_splitk.cu) + + set(quantize_ops_sources + src/quantize/cutlass_extensions.cu + src/quantize/quantize.cu + src/quantize/quantize.cpp) + +-set(comm_ops_sources +- src/comm/car.cu +- src/comm/car.cpp) ++# set(comm_ops_sources ++# src/comm/car.cu ++# src/comm/car.cpp) + + set(experimental_gen_ai_cpp_source_files +- ${attention_ops_sources} ++ # ${attention_ops_sources} + ${quantize_ops_sources} +- ${comm_ops_sources}) ++ # ${comm_ops_sources} ++) + + set_source_files_properties(${experimental_gen_ai_cpp_source_files} + PROPERTIES INCLUDE_DIRECTORIES diff --git a/server/fix_torch90a.sh b/server/fix_torch90a.sh new file mode 100755 index 00000000..5e444828 --- /dev/null +++ b/server/fix_torch90a.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# This script is required to patch torch < 2.4 +# It adds the 90a cuda target (H100) +# This target is required to build FBGEMM kernels + +torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|') + +sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch +sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch +sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index fe839cf4..8ec2a5ae 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -8,6 +8,7 @@ from typing import Optional from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.log import log_master app = typer.Typer() @@ -87,15 +88,17 @@ def serve( ) if len(lora_adapter_ids) > 0: - logger.warning( - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." + log_master( + logger.warning, + f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.", ) # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled # and warn the user if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: - logger.warning( - f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs." + log_master( + logger.warning, + f"LoRa adapter are not supported with CUDA Graphs. 
Disabling CUDA Graphs.", ) global CUDA_GRAPHS CUDA_GRAPHS = None diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 99c490d5..54da63e8 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -3,6 +3,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils.log import log_master from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -136,7 +137,10 @@ if ENGINE != "triton": try: import flash_attn_2_cuda - logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + log_master( + logger.info, + "ROCm: using Flash Attention 2 Composable Kernel implementation.", + ) except ImportError as e: if major >= 8: architecture_suffix = f"-{SYSTEM}" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index 925b0b2d..aae2bd1a 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -4,19 +4,11 @@ from functools import lru_cache import bitsandbytes as bnb import torch from bitsandbytes.nn import Int8Params, Params4bit -from loguru import logger -from text_generation_server.utils.weights import Weight - - -@lru_cache(1) -def warn_deprecate_bnb(): - logger.warning( - "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" - ) +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class BNBWeight(Weight): +class BNBWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module): @dataclass -class BNBFP4Weight(Weight): +class BNBFP4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -90,7 +82,7 @@ class BNBFP4Weight(Weight): @dataclass -class BNBNF4Weight(Weight): +class BNBNF4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index f003f914..b1e5235a 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -2,11 +2,11 @@ from dataclasses import dataclass import torch from EETQ import quant_weights, w8_a16_gemm -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class EETQWeight(Weight): +class EETQWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index b56f568a..cdf16d6b 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,8 +1,29 @@ -from dataclasses import dataclass - import torch + +from dataclasses import dataclass +from typing import Optional, Union, List +from loguru import logger + from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import ( + Weight, + WeightsLoader, + UnquantizedWeight, + Weights, +) +from text_generation_server.utils.log import 
log_master, log_once + +FBGEMM_MM_AVAILABLE = False +FBGEMM_DYN_AVAILABLE = False +try: + import fbgemm_gpu.experimental.gen_ai + + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + FBGEMM_MM_AVAILABLE = major == 9 + FBGEMM_DYN_AVAILABLE = major >= 8 +except (ImportError, ModuleNotFoundError): + log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") def get_fp8_linear() -> torch.nn.Module: @@ -21,12 +42,17 @@ def get_fp8_linear() -> torch.nn.Module: return Fp8Linear -def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - device = weight.device +def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): + if FBGEMM_DYN_AVAILABLE: + qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( + weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype + ) + return qweight, scale + # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12) + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) @@ -38,27 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): return qweight, scale +class HybridFP8UnquantLoader(WeightsLoader): + """Weight loader that loads FP8 and unquantized Torch tensors.""" + + def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + self.activation_scale_ub = activation_scale_ub + self.to_fp8 = to_fp8 + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False + ) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + w = torch.cat(w, dim=dim) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = [ + weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=0) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False) + return Fp8Weight( + 
weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + @dataclass class Fp8Weight(Weight): weight: torch.Tensor + dtype: torch.dtype + weight_scale: Optional[torch.Tensor] = None + activation_scale_ub: Optional[float] = None def get_linear(self, bias: torch.Tensor): - return get_fp8_linear()(self.weight, bias) + if self.weight_scale is None: + return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear().from_fp8( + self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + ) class Fp8Linear(torch.nn.Module): def __init__( self, - weight, + qweight, + scale, + scale_upper_bound, bias, + dtype, ) -> None: super().__init__() - self.dtype = weight.dtype - self.qweight, self.scale = fp8_quantize(weight) + self.dtype = dtype + self.qweight = qweight + self.scale = scale + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) self.bias = bias if bias is not None else None + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scale = fp8_quantize(weight) + return cls( + qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + ) + + @classmethod + def from_fp8(cls, weight, scale, input_scale, bias, dtype): + return cls( + qweight=weight, + scale=scale, + scale_upper_bound=input_scale, + bias=bias, + dtype=dtype, + ) + def forward(self, input: torch.Tensor) -> torch.Tensor: + if FBGEMM_MM_AVAILABLE: + qinput, scale = fp8_quantize( + input, scale_upper_bound=self.scale_upper_bound + ) + + y = torch.ops.fbgemm.f8f8bf16_rowwise( + qinput, + self.qweight, + scale, + self.scale, + use_fast_accum=True, + bias=self.bias, + ) + return y.to(self.dtype) + qinput, scale = fp8_quantize(input) output, _ = torch._scaled_mm( qinput, diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 4d45822b..dc3b832f 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -9,11 +9,12 @@ from loguru import logger from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.gptq import GPTQWeight +from text_generation_server.utils.log import log_master try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error("exllamav2_kernels not installed.") + log_master(logger.warning, "exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a913ff57..40271c35 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module): def __init__( self, - weight: torch.Tensor, + qweight: torch.Tensor, + scale: torch.Tensor, bias: Optional[torch.Tensor], ) -> None: super().__init__() @@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - qweight, scale = fp8_quantize(weight) scale = scale.to(torch.float16) qweight, scales = repack_fp8_for_marlin(qweight, scale) @@ -529,6 +529,15 @@ class 
GPTQMarlinFP8Linear(nn.Module): out_features // 64 * 16, dtype=torch.int, device=qweight.device ) + @classmethod + def from_unquant(cls, weight, bias, _dtype): + qweight, scale = fp8_quantize(weight) + return cls(qweight=qweight, scale=scale, bias=bias) + + @classmethod + def from_fp8(cls, weight, scale, _input_scale, bias, _dtype): + return cls(qweight=weight, scale=scale, bias=bias) + def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 690a8887..a43cdfed 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( ) from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_master # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -47,9 +48,7 @@ torch.set_grad_enabled(False) __all__ = [ "Model", - "BLOOMSharded", "CausalLM", - "GalacticaSharded", "Seq2SeqLM", "get_model", ] @@ -125,7 +124,7 @@ try: ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") SUPPORTS_WINDOWING = False FLASH_ATTENTION = False @@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True try: from text_generation_server.models.mamba import Mamba except ImportError as e: - logger.warning(f"Could not import Mamba: {e}") + log_master(logger.warning, f"Could not import Mamba: {e}") MAMBA_AVAILABLE = False if MAMBA_AVAILABLE: @@ -311,6 +310,12 @@ def get_model( if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 + elif quantize == "fp8": + from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + + if FBGEMM_MM_AVAILABLE: + # fbgemm kernels are fp8xfp8->bf16 + dtype = torch.bfloat16 else: # Keep it as default for now and let # every model resolve their own default dtype. @@ -433,7 +438,9 @@ def get_model( speculate = get_speculate() if speculate > 0: - logger.info(f"Using speculation {method} with {speculate} input ids.") + log_master( + logger.info, f"Using speculation {method} with {speculate} input ids." + ) if model_type is None: # TODO: fix how we determine model type for Mamba @@ -448,10 +455,10 @@ def get_model( if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) if method in {"gptq", "awq", "exl2"}: - logger.info(f"Auto selecting quantization method {method}") + log_master(logger.info, f"Auto selecting quantization method {method}") quantize = method else: - logger.info(f"Unknown quantization method {method}") + log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( @@ -593,7 +600,7 @@ def get_model( ) except RuntimeError as e: # Lots of legacy models with various weight names. 
- logger.warning(f"Couldn't load flash gpt2 variant: {e}") + log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") return CausalLM.fallback( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 5237a484..f7980d2d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,7 +33,6 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -42,16 +41,15 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.fp8 import Fp8Weight from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import ( - DefaultWeightsLoader, UnquantizedWeight, Weights, ) +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader if SYSTEM == "rocm": try: @@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id): @contextmanager def no_fp8(weights: Weights): + """De-activate fp8 auto conversion for the duration of this context manager""" weights_loader = weights.weights_loader - if ( - isinstance(weights_loader, DefaultWeightsLoader) - and weights_loader.weight_class is Fp8Weight - ): - weights_loader = DefaultWeightsLoader(UnquantizedWeight) + if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8: + weights_loader = HybridFP8UnquantLoader( + weights_loader.activation_scale_ub, to_fp8=False + ) with weights.use_loader(weights_loader): yield @@ -418,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.layers = nn.ModuleList( + + # Skip fp8 quant for first and last layers + self.layers = nn.ModuleList() + with no_fp8(weights): + self.layers.append( + FlashLlamaLayer( + index=0, + prefix=( + "model.layers.0" if not prefix else "{prefix}.model.layers.0" + ), + config=config, + weights=weights, + ) + ) + + self.layers.extend( [ FlashLlamaLayer( index=layer_id, @@ -430,9 +443,26 @@ class FlashLlamaModel(torch.nn.Module): config=config, weights=weights, ) - for layer_id in range(config.num_hidden_layers) + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1) ] ) + + with no_fp8(weights): + last_layer_id = config.num_hidden_layers - 1 + self.layers.append( + FlashLlamaLayer( + index=last_layer_id, + prefix=( + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}.model.layers.{last_layer_id}" + ), + config=config, + weights=weights, + ) + ) + self.norm = FastRMSNorm.load( prefix="model.norm" if not prefix else f"{prefix}.model.norm", weights=weights, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2888f1f7..cfffafa1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from 
text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model +from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - hub, ) from text_generation_server.models.types import ( Batch, @@ -1156,31 +1155,36 @@ class FlashCausalLM(Model): f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) - logger.info( - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`." + log_master( + logger.info, + f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) if os.path.isfile(tunableop_filepath): - logger.info( - f"The file {tunableop_filepath} already exists and will be reused." + log_master( + logger.info, + f"The file {tunableop_filepath} already exists and will be reused.", ) torch.cuda.tunable.read_file(tunableop_filepath) os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) for seqlen in tuning_sequences: - logger.info(f"Warming up TunableOp for seqlen={seqlen}") + log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) torch.cuda.tunable.tuning_enable(False) else: - logger.info( - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp." + log_master( + logger.info, + "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. 
If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", ) if CUDA_GRAPHS: try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") + log_master( + logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" + ) # Warmup cuda graphs for bs in CUDA_GRAPHS: if self.speculate is None or self.speculate + 1 <= bs: @@ -1188,7 +1192,9 @@ class FlashCausalLM(Model): except torch.cuda.OutOfMemoryError: logger.exception(f"Decode cuda graph warmup failed") else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") + log_master( + logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." + ) return int(num_blocks * BLOCK_SIZE) @@ -1540,8 +1546,7 @@ class FlashCausalLM(Model): left = 0 if n_accepted_ids > 1: - if RANK == 0: - logger.debug(f"Speculated ids {n_accepted_ids - 1}") + log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}") current_stopped = False for j in range(index, index + n_accepted_ids): diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 06035ccd..ac42df30 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,15 +1,16 @@ import torch import os from loguru import logger -from typing import Dict +from typing import Dict, Optional + +from text_generation_server.utils.log import log_master MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: - logger.info("Using FLASH_DECODING") - + log_master(logger.info, "Using FLASH_DECODING") cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: @@ -26,11 +27,9 @@ else: if cuda_graphs is not None: cuda_graphs.sort(reverse=True) - CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. -global MODEL_ID MODEL_ID = None @@ -41,8 +40,7 @@ def set_model_id(model_id: str): # NOTE: eventually we should move this into the router and pass back the # index in all cases. -global ADAPTER_TO_INDEX -ADAPTER_TO_INDEX: Dict[str, int] = None +ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None def set_adapter_to_index(adapter_to_index: Dict[str, int]): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 09130b85..e7748bb9 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import ( AdapterParameters, AdapterSource, ) +from text_generation_server.utils.log import log_master from loguru import logger @@ -204,8 +205,9 @@ class Model(ABC): f"order to use the dynamic adapter loading feature." 
) - logger.info( - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + log_master( + logger.info, + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", ) weight_names = tuple([v[0] for v in self.target_to_layer.values()]) ( @@ -240,8 +242,9 @@ class Model(ABC): layer_weights.add_adapter(adapter_index, adapter_weights) if len(unused_weight_names) > 0: - logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + log_master( + logger.warning, + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", ) if adapter_tokenizer is not None: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index f869f8b5..308d5a3d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,4 +1,3 @@ -from itertools import repeat import torch from PIL import Image from io import BytesIO @@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) +from text_generation_server.utils.log import log_master from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info( - f"Found {num_features} features in image of resolution {height}x{width}" + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", ) return "" * num_features diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 36d63e86..82aeba6c 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -56,7 +56,7 @@ def initialize_torch_distributed(): backend = "nccl" options = ProcessGroupNCCL.Options() options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) + options._timeout = timedelta(seconds=120) else: backend = "gloo" options = None @@ -76,7 +76,7 @@ def initialize_torch_distributed(): backend="ccl", world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: @@ -84,7 +84,7 @@ def initialize_torch_distributed(): backend=backend, world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py index b1456f1e..4385c71e 100644 --- a/server/text_generation_server/utils/log.py +++ b/server/text_generation_server/utils/log.py @@ -1,6 +1,15 @@ from functools import lru_cache +from text_generation_server.utils.dist import RANK @lru_cache(10) -def log_once(log, msg: str): - log(msg) +def log_once(log, msg: str, master=True): + if master: + log_master(log, msg) + else: + log(msg) + + +def log_master(log, msg: str): + if RANK == 0: + log(msg) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index e8e22db8..c3c038fe 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -11,6 +11,7 @@ from text_generation_server.utils.weights import ( ) +# TODO: Split this config to have a 
single config type per quant method @dataclass class _QuantizerConfig: bits: int @@ -21,6 +22,11 @@ class _QuantizerConfig: sym: bool +@dataclass +class _FP8QuantizerConfig: + activation_scale_ub: float + + # We should probably do this with Pytantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. def _get_quantizer_config(model_id, revision): @@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision): filename = hf_hub_download(model_id, filename=filename, revision=revision) with open(filename, "r") as f: data = json.load(f) + + # FP8 config + if data["quantization_config"]["quant_method"] == "fbgemm_fp8": + return _FP8QuantizerConfig( + activation_scale_ub=data["quantization_config"]["activation_scale_ub"] + ) + bits = data["quantization_config"]["bits"] groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models @@ -99,6 +112,12 @@ def get_loader( if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + return GPTQWeightsLoader( bits=quantizer_config.bits, desc_act=quantizer_config.desc_act, @@ -127,18 +146,28 @@ def get_loader( from text_generation_server.layers.exl2 import Exl2WeightsLoader return Exl2WeightsLoader() - elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Weight - - return DefaultWeightsLoader(Fp8Weight) elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
+ ) + return MarlinWeightsLoader( bits=quantizer_config.bits, is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) - elif quantize is None: - return DefaultWeightsLoader(UnquantizedWeight) + elif quantize == "fp8" or quantize is None: + from text_generation_server.layers.fp8 import HybridFP8UnquantLoader + + # Since the default for the quantize config is _QuantizerConfig, + # we need to add this check to not get an attribute error + activation_scale_ub = None + if isinstance(quantizer_config, _FP8QuantizerConfig): + activation_scale_ub = quantizer_config.activation_scale_ub + + return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") else: raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 91592df0..66bb6051 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,12 +1,12 @@ +import torch + from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass -from enum import Enum, auto from pathlib import Path -from typing import Dict, List, Optional, Union - -import torch +from typing import Dict, List, Optional, Union, Type from safetensors import safe_open +from dataclasses import dataclass + from text_generation_server.utils.import_utils import SYSTEM @@ -84,7 +84,7 @@ class Weight(ABC): @dataclass -class UnquantizedWeight: +class UnquantizedWeight(Weight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -99,7 +99,7 @@ class UnquantizedWeight: class DefaultWeightsLoader(WeightsLoader): """Weight loader that loads (unquantized) Torch tensors.""" - def __init__(self, weight_class): + def __init__(self, weight_class: Type[UnquantizedWeight]): """Create a loader. Weights will be wrapped using the given `weights_class`, normally this will be `UnquantizedWeight`, but a quantizer-specific class such as `Fp8Weight` can be used to quantize the weights during loading. @@ -208,20 +208,29 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True): + def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert # u4 which are disguised as int32. Exl2 uses int16 - # as well. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + # as well. FP8 uses torch.float8_e4m3fn + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -241,12 +250,16 @@ class Weights: raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert # u4 which are disguised as int32. exl2 uses int16. - if tensor.dtype not in (torch.int16, torch.int32): + # FP8 uses torch.float8_e4m3fn. 
+ if ( + tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -255,10 +268,14 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim) + return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) def get_packed_sharded( - self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + self, + tensor_name: str, + dim: int, + block_sizes: Union[int, List[int]], + to_dtype=True, ) -> torch.Tensor: """ Get a shard from a tensor that packs multiple tensors. @@ -304,7 +321,16 @@ class Weights: tensor = tensor.to(device=self.device) # Avoid casting quantizer dtypes. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) return tensor
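
As a side note on the quantization math introduced in server/text_generation_server/layers/fp8.py above: when the FBGEMM kernels are not installed, fp8_quantize falls back to simple per-tensor scaling. The sketch below is illustrative only (the helper name and the example tensors are not part of the patch) and assumes a PyTorch build that ships torch.float8_e4m3fn.

    import torch

    def fp8_quantize_fallback(weight: torch.Tensor, scale_upper_bound=None,
                              qdtype=torch.float8_e4m3fn):
        # Per-tensor scale: dtype max divided by the absolute maximum of the
        # weight, optionally capped by the activation scale upper bound.
        finfo = torch.finfo(qdtype)
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        # Scale and saturate to the representable float8 range before casting,
        # since the default cast is unsaturated.
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
        return qweight, scale

    w = torch.randn(16, 32, dtype=torch.float16)
    qw, s = fp8_quantize_fallback(w)
    print(qw.dtype, qw.shape, s.item())

When FBGEMM is available, this path is replaced by torch.ops.fbgemm.quantize_fp8_per_row together with the f8f8bf16_rowwise GEMM, which is also why get_model in models/__init__.py switches the default dtype to bfloat16 when FBGEMM_MM_AVAILABLE is set.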