diff --git a/.gitignore b/.gitignore
index 2a771499..7d6c7564 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,3 +26,5 @@ server/fbgemmm
 
 # Gaudi auto-generated files
 hl-smi_log*.txt
+.graph_dumps
+out
diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index ba5d6ec3..91b172ed 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -1,3 +1,7 @@
+# Those arguments are required to build the image
+ARG HABANA_VERSION
+ARG PYTORCH_VERSION
+
 # Rust builder
 FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
 WORKDIR /usr/src
@@ -48,7 +52,10 @@ COPY launcher launcher
 RUN cargo build --profile release-opt
 
 # Text Generation Inference base image
-FROM vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest AS base
+ARG HABANA_VERSION
+ARG PYTORCH_VERSION
+
+FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
 
 ENV ATTENTION=default
 ENV PREFIX_CACHING=0
@@ -80,12 +87,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 COPY proto proto
 COPY backends/gaudi/server server
 COPY backends/gaudi/server/Makefile server/Makefile
+ARG HABANA_VERSION
 RUN cd server && \
     make gen-server && \
     pip install --no-deps -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
     pip install outlines~=0.0.34 && \
-    pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0 && \
+    pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
@@ -108,6 +116,10 @@ ENTRYPOINT ["./entrypoint.sh"]
 
 # Final image
 FROM base
 
+ENV HF_HUB_ENABLE_HF_TRANSFER 1
+ENV HABANA_VISIBLE_DEVICES all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
+
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
index 8162972d..ce3be25d 100644
--- a/backends/gaudi/Makefile
+++ b/backends/gaudi/Makefile
@@ -2,34 +2,33 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := "${mkfile_dir}/../.."
+HABANA_VERSION := 1.19.0
+PYTORCH_VERSION := 2.5.1
+
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir}
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
 
 run-local-dev-container:
 	docker run -it \
 		--runtime=habana \
-		-e HABANA_VISIBLE_DEVICES=all \
-		-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
-		-e LOG_LEVEL=debug \
-		-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
-		-e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
-		-e ENABLE_HPU_GRAPH=true \
-		-e LIMIT_HPU_GRAPH=true \
-		-e USE_FLASH_ATTENTION=true \
-		-e FLASH_ATTENTION_RECOMPUTE=true \
-		-e PORT=8080 \
+		--ipc=host \
 		--cap-add=sys_nice \
 		--net=host \
-		--ipc=host \
+		-e HABANA_VISIBLE_DEVICES=all \
+		-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
+		-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
+		-e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
+		-e LOG_LEVEL=debug \
+		-e PORT=8080 \
 		-v /home/ubuntu/.cache/huggingface:/data \
 		-v $(PWD):/text-generation-inference \
 		-w /text-generation-inference \
-		vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
+		vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest
 
 install-dependencies:
-	pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0
+	pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
 	pip install outlines~=0.0.34
 	curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
diff --git a/backends/gaudi/README.md b/backends/gaudi/README.md
index c2bff110..765ad447 100644
--- a/backends/gaudi/README.md
+++ b/backends/gaudi/README.md
@@ -6,7 +6,7 @@ This is the TGI backend for Intel Gaudi.
This backend is composed of the tgi ser ## Build your own image -The simplest way to build TGI with the gaudi backend is to use the provided `Makefile`: +The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`: Option 1: From the project root directory: ```bash @@ -20,25 +20,39 @@ make image ``` You can now run the server with the following command: + +Option 1: Sharded: ```bash model=meta-llama/Llama-3.1-8B-Instruct hf_token=$(cat ${HOME}/.cache/huggingface/token) volume=${HOME}/.cache/huggingface -docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \ --e LOG_LEVEL=debug \ --e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ --e HF_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true \ --e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice \ ---ipc=host tgi-gaudi --model-id $model --sharded true \ ---num-shard 8 --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048 --max-batch-total-tokens 8192 +docker run --runtime=habana --ipc=host --cap-add=sys_nice \ + -p 8080:80 -v $volume:/data \ + -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \ + tgi-gaudi --model-id $model \ + --sharded true --num-shard 8 \ + --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048 +``` + +Option 2: Non-sharded: +```bash +model=meta-llama/Llama-3.1-8B-Instruct +hf_token=$(cat ${HOME}/.cache/huggingface/token) +volume=${HOME}/.cache/huggingface + +docker run --runtime=habana --ipc=host --cap-add=sys_nice \ + -p 8080:80 -v $volume:/data \ + -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \ + tgi-gaudi --model-id $model \ + --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048 ``` ## Contributing ### Local Development -This is useful if you want to run the server in locally for better debugging. +This is useful if you want to run the server locally for better debugging. 
```bash make -C backends/gaudi run-local-dev-container ``` diff --git a/backends/gaudi/server/Makefile-exllamav2 b/backends/gaudi/server/Makefile-exllamav2 deleted file mode 100644 index 0d4cc385..00000000 --- a/backends/gaudi/server/Makefile-exllamav2 +++ /dev/null @@ -1,12 +0,0 @@ -exllamav2_commit := v0.1.8 - -build-exllamav2: - git clone https://github.com/turboderp/exllamav2.git exllamav2 && \ - cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \ - git submodule update --init --recursive && \ - pip install -r requirements.txt && \ - CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build - -install-exllamav2: build-exllamav2 - cd exllamav2/ && \ - CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install diff --git a/backends/gaudi/server/Makefile-flashinfer b/backends/gaudi/server/Makefile-flashinfer deleted file mode 100644 index f0a27622..00000000 --- a/backends/gaudi/server/Makefile-flashinfer +++ /dev/null @@ -1,2 +0,0 @@ -install-flashinfer: - pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 diff --git a/backends/gaudi/server/Makefile-lorax-punica b/backends/gaudi/server/Makefile-lorax-punica deleted file mode 100644 index 72f06f76..00000000 --- a/backends/gaudi/server/Makefile-lorax-punica +++ /dev/null @@ -1,12 +0,0 @@ -lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc - -build-lorax-punica: - if [ ! -d 'lorax-punica' ]; then \ - git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ - fi - cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) - cd lorax-punica && git submodule update --init --recursive - cd lorax-punica/server/punica_kernels && python setup.py build - -install-lorax-punica: build-lorax-punica - cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/backends/gaudi/server/custom_kernels/custom_kernels/fused_attention_cuda.cu b/backends/gaudi/server/custom_kernels/custom_kernels/fused_attention_cuda.cu deleted file mode 100644 index 60f9f028..00000000 --- a/backends/gaudi/server/custom_kernels/custom_kernels/fused_attention_cuda.cu +++ /dev/null @@ -1,250 +0,0 @@ -#include -#include -#include -#include -#include - -#include - -/** -* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda -* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu -**/ - -// Available in pytorch main -//#define DISPATCH_CASE_FLOATING_TYPES(...) 
\ -// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ -// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ -// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ -// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ - -/* -* Forward passes -*/ - -/** -* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype -**/ -template -__global__ void forward_masked_softmax_kernel( - const torch::PackedTensorAccessor32 attention_scores, // [B, KV] - const torch::PackedTensorAccessor32 mask, // [B, KV] - torch::PackedTensorAccessor32 result, // [B, KV] - const int64_t effective_kv_length, - const dim3 blockDim, - const int64_t rows_per_block, - const int64_t kv_length, - const int64_t batch_size -) { - const auto row_id = threadIdx.x / effective_kv_length; - const auto effective_kv_length_id = threadIdx.x % effective_kv_length; - const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; - auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; - kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_; - const auto kv_length_end = kv_length_end_; - - const auto batch_id = blockIdx.x * rows_per_block + row_id; - - // We need 2 float storage for each row, one for max computation, the other for normalizing exponential - extern __shared__ float temp_storage[]; - const auto row_id_mem_offset = row_id * 2; - if (effective_kv_length_id == 0) { - temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); - temp_storage[row_id_mem_offset + 1] = 0; - } - __syncthreads(); - - // Compute mask and max - if (batch_id < batch_size) { - float thread_max = -std::numeric_limits::infinity(); - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - if (mask[batch_id][kv_length_id] == 0) { - const float candidate = attention_scores[batch_id][kv_length_id]; - thread_max = (thread_max < candidate) ? candidate : thread_max; - } - } - if (thread_max != -std::numeric_limits::infinity()) { - // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot - gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); - } - } - - __syncthreads(); - - // Compute exp(elt - max) masked - float exponential[min_kv_length_shard_size_per_thread]; - if (batch_id < batch_size) { - float thread_add = 0; - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - if (mask[batch_id][kv_length_id] == 0) { - exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); - thread_add = thread_add + exponential[kv_length_id - kv_length_start]; - } else { - exponential[kv_length_id - kv_length_start] = 0.; - } - } - if (thread_add > 0) { - // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot - gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); - } - } - - __syncthreads(); - - // Compute softmax - if (batch_id < batch_size) { - // If sum of all exponential is 0, we set the softmax values to 0 - if (temp_storage[row_id_mem_offset + 1] == 0.) 
{ - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - result[batch_id][kv_length_id] = 0.; - } - } else { - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); - } - } - } -} - -#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::tuple>, at::Tensor> forward( - const at::Tensor query, - const at::Tensor key, - const at::Tensor value, - const std::optional> layer_past, - const at::Tensor attention_mask, - const std::optional head_mask, - const float inv_norm_factor, - const int num_heads, - const bool use_cache -) { - auto query_layer = query; - auto key_layer = key; - auto value_layer = value; - - if (layer_past) { - const auto past_key = (*layer_past).at(0); - const auto past_value = (*layer_past).at(1); - key_layer = at::cat({past_key, key_layer}, 2); - value_layer = at::cat({past_value, value_layer}, 2); - } - - std::optional> present; - if (use_cache) { - present = {key_layer, value_layer}; - } else { - present = {}; - } - - const auto batch_size = query_layer.size(0); - const auto q_length = query_layer.size(2); - const auto attn_head_size = query_layer.size(3); - const auto batch_size_times_num_heads = batch_size * num_heads; - const auto kv_length = key_layer.size(2); - - const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size}); - auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2); - auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}); - - auto query_scaled = query_view * inv_norm_factor; - auto attention_scores = at::bmm(query_scaled, key_view); - - // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` - at::Tensor attention_probs; - if (true) { - // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors - const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); - const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); - - // Custom kernel - attention_probs = at::empty_like(attention_scores_2d); - - // Check that inputs and contiguous + cuda tensors - CHECK_INPUT(attention_scores_2d); - CHECK_INPUT(attention_mask_2d); - - // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out - // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { - /* - * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ - * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf - * - SMs: 108 - * - TPCs: 56 (What's that?) 
- * - Memory size: 40 GB - * - L2 Cache size: 40960 KB (shared across all SMs) - * - L1/Shared memory size: 192 KB (shared across all threads within a SM) - * - Max Threads / SM: 2048 - * - Max Thread Blocks / SM: 32 - */ - - /* - * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block - * with multiple threads as we need to `sync_threads` to run exponential sum. - * We maximise the usage of threads within a single block - */ - // TODO @thomasw21 figure out everything warp related: - // - why do they have to be power of 2 - // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 - const auto MAX_THREADS_PER_SM = 1024; - // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` - const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; - // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` - const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; - const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; - const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; - - const dim3 gridDim(num_blocks); // Number of blocks that run - const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block - const int shared_mem_forward = rows_per_block * 2 * sizeof(float); - - // 192 * 2 ** 10 - // const auto MAX_L1_MEMORY = 196608; - // const auto MAX_SMs = 108; - // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); - // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); - // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. 
Raising as require requested threads is higher."); - - forward_masked_softmax_kernel<<>>( - attention_scores_2d.packed_accessor32(), - attention_mask_2d.packed_accessor32(), - attention_probs.packed_accessor32(), - effective_kv_length, - blockDim, - rows_per_block, - kv_length, - batch_size_times_num_heads * q_length - ); - }); - attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); - } else { - // Pytorch C++ API - auto input_dtype = attention_scores.scalar_type(); - if (input_dtype == at::ScalarType::Float) { - attention_scores = attention_scores.to(at::ScalarType::Float); - }; - // TODO @thomasw21 Figure out how to get minimum value - auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); - attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); - } - - auto context_layer = attention_probs.bmm(value_view); - - // `_merge_heads` - context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size}); - context_layer = context_layer.permute({0, 2, 1, 3}); - context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads}); - - return std::make_tuple(context_layer, present, attention_probs); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "forward", - &forward, - "GPT-Neox attention mechanism forward (CUDA)" - ); -} diff --git a/backends/gaudi/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu b/backends/gaudi/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu deleted file mode 100644 index 8206c3e0..00000000 --- a/backends/gaudi/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu +++ /dev/null @@ -1,250 +0,0 @@ -#include -#include -#include -#include -#include - -#include - -/** -* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda -* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu -**/ - -// Available in pytorch main -//#define DISPATCH_CASE_FLOATING_TYPES(...) \ -// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ -// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ -// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ -// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ - -/* -* Forward passes -*/ - -/** -* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype -**/ -template -__global__ void forward_masked_softmax_kernel( - const torch::PackedTensorAccessor32 attention_scores, // [B, KV] - const torch::PackedTensorAccessor32 mask, // [B, KV] - torch::PackedTensorAccessor32 result, // [B, KV] - const int64_t effective_kv_length, - const dim3 blockDim, - const int64_t rows_per_block, - const int64_t kv_length, - const int64_t batch_size -) { - const auto row_id = threadIdx.x / effective_kv_length; - const auto effective_kv_length_id = threadIdx.x % effective_kv_length; - const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; - auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; - kv_length_end_ = (kv_length_end_ > kv_length) ? 
kv_length : kv_length_end_; - const auto kv_length_end = kv_length_end_; - - const auto batch_id = blockIdx.x * rows_per_block + row_id; - - // We need 2 float storage for each row, one for max computation, the other for normalizing exponential - extern __shared__ float temp_storage[]; - const auto row_id_mem_offset = row_id * 2; - if (effective_kv_length_id == 0) { - temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); - temp_storage[row_id_mem_offset + 1] = 0; - } - __syncthreads(); - - // Compute mask and max - if (batch_id < batch_size) { - float thread_max = -std::numeric_limits::infinity(); - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - if (mask[batch_id][kv_length_id] == 0) { - const float candidate = attention_scores[batch_id][kv_length_id]; - thread_max = (thread_max < candidate) ? candidate : thread_max; - } - } - if (thread_max != -std::numeric_limits::infinity()) { - // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot - gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); - } - } - - __syncthreads(); - - // Compute exp(elt - max) masked - float exponential[min_kv_length_shard_size_per_thread]; - if (batch_id < batch_size) { - float thread_add = 0; - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - if (mask[batch_id][kv_length_id] == 0) { - exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); - thread_add = thread_add + exponential[kv_length_id - kv_length_start]; - } else { - exponential[kv_length_id - kv_length_start] = 0.; - } - } - if (thread_add > 0) { - // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot - gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); - } - } - - __syncthreads(); - - // Compute softmax - if (batch_id < batch_size) { - // If sum of all exponential is 0, we set the softmax values to 0 - if (temp_storage[row_id_mem_offset + 1] == 0.) 
{ - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - result[batch_id][kv_length_id] = 0.; - } - } else { - for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { - result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); - } - } - } -} - -#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::tuple>, at::Tensor> forward( - const at::Tensor fused_qkv, - const std::optional> layer_past, - const at::Tensor alibi, - const at::Tensor attention_mask, - const std::optional head_mask, - const float beta, - const float inv_norm_factor, - const int num_heads, - const bool use_cache -) { - const auto batch_size = fused_qkv.size(0); - const auto q_length = fused_qkv.size(1); - const auto three_times_hidden_size = fused_qkv.size(2); - const auto head_dim = three_times_hidden_size / (3 * num_heads); - const auto batch_size_times_num_heads = batch_size * num_heads; - - // `split_heads` - const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim}); - const auto tensor_list = fused_qkv_view.split(head_dim, -1); - const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); - auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length}); - auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); - - if (layer_past) { - const auto past_key = (*layer_past).at(0); - const auto past_value = (*layer_past).at(1); - key_layer = at::cat({past_key, key_layer}, 2); - value_layer = at::cat({past_value, value_layer}, 1); - } - - std::optional> present; - if (use_cache) { - present = {key_layer, value_layer}; - } else { - present = {}; - } - - auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor); - - // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` - at::Tensor attention_probs; - if (true) { - const auto kv_length = key_layer.size(2); - - // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors - const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); - const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); - - // Custom kernel - attention_probs = at::empty_like(attention_scores_2d); - - // Check that inputs and contiguous + cuda tensors - CHECK_INPUT(attention_scores_2d); - CHECK_INPUT(attention_mask_2d); - - // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out - // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { - /* - * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ - * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf - * - SMs: 108 - * - TPCs: 56 (What's that?) 
- * - Memory size: 40 GB - * - L2 Cache size: 40960 KB (shared across all SMs) - * - L1/Shared memory size: 192 KB (shared across all threads within a SM) - * - Max Threads / SM: 2048 - * - Max Thread Blocks / SM: 32 - */ - - /* - * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block - * with multiple threads as we need to `sync_threads` to run exponential sum. - * We maximise the usage of threads within a single block - */ - // TODO @thomasw21 figure out everything warp related: - // - why do they have to be power of 2 - // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 - const auto MAX_THREADS_PER_SM = 1024; - // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` - const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; - // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` - const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; - const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; - const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; - - const dim3 gridDim(num_blocks); // Number of blocks that run - const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block - const int shared_mem_forward = rows_per_block * 2 * sizeof(float); - - // 192 * 2 ** 10 - // const auto MAX_L1_MEMORY = 196608; - // const auto MAX_SMs = 108; - // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); - // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); - // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. 
Raising as require requested threads is higher."); - - forward_masked_softmax_kernel<<>>( - attention_scores_2d.packed_accessor32(), - attention_mask_2d.packed_accessor32(), - attention_probs.packed_accessor32(), - effective_kv_length, - blockDim, - rows_per_block, - kv_length, - batch_size_times_num_heads * q_length - ); - }); - attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); - } else { - // Pytorch C++ API - auto input_dtype = attention_scores.scalar_type(); - if (input_dtype == at::ScalarType::Float) { - attention_scores = attention_scores.to(at::ScalarType::Float); - }; - // TODO @thomasw21 Figure out how to get minimum value - auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); - attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); - } - - auto context_layer = attention_probs.bmm(value_layer); - - // `_merge_heads` - context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim}); - context_layer = context_layer.permute({0, 2, 1, 3}); - context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3}); - - return std::make_tuple(context_layer, present, attention_probs); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "forward", - &forward, - "Bloom attention mechanism forward (CUDA)" - ); -} diff --git a/backends/gaudi/server/custom_kernels/setup.py b/backends/gaudi/server/custom_kernels/setup.py deleted file mode 100644 index e0b83987..00000000 --- a/backends/gaudi/server/custom_kernels/setup.py +++ /dev/null @@ -1,21 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -extra_compile_args = ["-std=c++17"] - -setup( - name="custom_kernels", - ext_modules=[ - CUDAExtension( - name="custom_kernels.fused_bloom_attention_cuda", - sources=["custom_kernels/fused_bloom_attention_cuda.cu"], - extra_compile_args=extra_compile_args, - ), - CUDAExtension( - name="custom_kernels.fused_attention_cuda", - sources=["custom_kernels/fused_attention_cuda.cu"], - extra_compile_args=extra_compile_args, - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cu_compat.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cu_compat.cuh deleted file mode 100644 index c5258813..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cu_compat.cuh +++ /dev/null @@ -1,58 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_compat_cuh -#define _cuda_compat_cuh - -// atomicAdd for half types, to support CC < 7.x - -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; - - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); -} - -// atomicAdd for half2 types - -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); -} - -// - -#if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) - -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } - -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif - -#endif -#endif - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cu deleted file mode 100644 index ee2cbee2..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cu +++ /dev/null @@ -1,71 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#define _cuda_buffers_cu -#include "cuda_buffers.cuh" - -CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; -// __constant__ half2 q4_table[16][256]; -// half2 q4_table_host[16][256]; -// bool q4_table_init = false; - -CudaBuffers::CudaBuffers -( - int _device, - half* _temp_state, - half* _temp_dq -) : - device(_device), - temp_state(_temp_state), - temp_dq(_temp_dq) -{ - cudaSetDevice(_device); - - cudaStreamCreate(&alt_stream_1); - cudaStreamCreate(&alt_stream_2); - cudaStreamCreate(&alt_stream_3); - cudaEventCreate(&alt_stream_1_done); - cudaEventCreate(&alt_stream_2_done); - cudaEventCreate(&alt_stream_3_done); -} - -CudaBuffers::~CudaBuffers() -{ - cudaStreamDestroy(alt_stream_1); - cudaStreamDestroy(alt_stream_2); - cudaStreamDestroy(alt_stream_3); - cudaEventDestroy(alt_stream_1_done); - cudaEventDestroy(alt_stream_2_done); - cudaEventDestroy(alt_stream_3_done); -} - -CudaBuffers* get_buffers(const int device_index) -{ - return g_buffers[device_index]; -} - -void prepare_buffers_cuda -( - int _device, - half* _temp_state, - half* _temp_dq -) -{ - CudaBuffers* buffers = new CudaBuffers - ( - _device, - _temp_state, - _temp_dq - ); - - g_buffers[_device] = buffers; -} - -void cleanup_buffers_cuda() -{ - for (int i = 0; i < CUDA_MAX_DEVICES; i++) - { - if (!g_buffers[i]) continue; - delete g_buffers[i]; - g_buffers[i] = NULL; - } -} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh deleted file mode 100644 index afb60a01..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh +++ /dev/null @@ -1,52 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_buffers_cuh -#define _cuda_buffers_cuh - -#include -#include -#include -#include - -const int CUDA_MAX_DEVICES = 16; - -// #ifndef _cuda_buffers_cu -// extern __constant__ half2 q4_table[16][256]; -// #endif - -class CudaBuffers -{ -public: - int device; - - half* temp_state; // [max_hidden_rows * intermediate_size] - half* temp_dq; // size of largest quant tensor * 8 - - cudaStream_t alt_stream_1; - cudaStream_t 
alt_stream_2; - cudaStream_t alt_stream_3; - cudaEvent_t alt_stream_1_done; - cudaEvent_t alt_stream_2_done; - cudaEvent_t alt_stream_3_done; - - CudaBuffers - ( - int _device, - half* _temp_state, - half* _temp_dq - ); - ~CudaBuffers(); -}; - -CudaBuffers* get_buffers(const int device_index); - -void prepare_buffers_cuda -( - int _device, - half* _temp_state, - half* _temp_dq -); - -void cleanup_buffers_cuda(); - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu deleted file mode 100644 index c25b0206..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu +++ /dev/null @@ -1,61 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "column_remap.cuh" -#include "../util.cuh" - -const int SHUF_BLOCKSIZE_X = 256; -const int SHUF_BLOCKSIZE_Y = 16; - -__global__ void column_remap_kernel -( - const half* __restrict__ x, - half* __restrict__ x_new, - const int x_width, - const int x_height, - const uint32_t* x_map -) -{ - int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; - - int x_stride = x_width; - int x_idx = x_row * x_stride + x_column; - - int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); - int x_idx_end = x_row_end * x_stride + x_column; - - int s_column = x_map[x_column]; - int s_idx = x_row * x_stride + s_column; - - while (x_idx < x_idx_end) - { - x_new[x_idx] = x[s_idx]; - x_idx += x_stride; - s_idx += x_stride; - } -} - -// Remap columns in x to correspond to sequential group index before matmul -// -// perform x -> seq_x such that seq_x @ seq_w == x @ w - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -) -{ - dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); - - dim3 blocks - ( - (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, - (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, - 1 - ); - - column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); -} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh deleted file mode 100644 index 0364e38c..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh +++ /dev/null @@ -1,19 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _column_remap_cuh -#define _column_remap_cuh - -#include -#include -#include - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -); - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu deleted file mode 100644 index 1b0f7956..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu +++ /dev/null @@ -1,256 +0,0 @@ -#include "q4_matmul.cuh" -#include "column_remap.cuh" -#include -#include "../util.cuh" -#include "../matrix.cuh" -#include "../cu_compat.cuh" -#include "../cuda_buffers.cuh" -#if defined(USE_ROCM) -#include "../hip_compat.cuh" -#endif - -const int THREADS_X = 32; // Block size and thread count along columns in w and out -const int THREADS_Y = 1; // Block size and thread count along rows in x and out - -typedef void 
(*fp_q4_matmul_kernel) -( - const half*, - const uint32_t*, - half*, - const half*, - const uint32_t*, - const int, - const int, - const int, - const int, - const int, - const uint32_t*, - bool -); - -template -__global__ void q4_matmul_kernel -( - const half* __restrict__ x, - const uint32_t* __restrict__ w, - half* __restrict__ out, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int dim, - const int width, - const int groupsize, - const int block_size_z, - const uint32_t* __restrict__ x_map, - bool no_zero -) -{ - // Start of block - - int x_column = block_size_z * blockIdx.z; - int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - int x_row = THREADS_Y * blockIdx.y + threadIdx.y; - - int iterations = (x_column_end - x_column) / 8; - - // Views - - MatrixView_half x_(x, height, dim); - MatrixView_half w_scales_(w_scales, dim / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); - MatrixView_q4_column w_(w, dim, width); - MatrixView_half_rw out_(out, height, width); - - // Zero output - - if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) - { - *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; - __syncthreads(); - } - - // Loop over part of x row (and w column) - - half2 acc = {}; - half acc_h = {}; - - if constexpr (use_groupsize) - { - // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this - // could be slightly faster - - for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) - { - if constexpr (use_half2) - { - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - else - { - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - } - } - else - { - // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache - - for (int k = x_column; k < x_column + iterations * 8; k += 8) - { - if constexpr (use_half2) - { - int group = k / groupsize; - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - else - { - int group = k / groupsize; - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - } - } - - // Add to block result - - if constexpr (use_half2) - { - half result = __hadd(__low2half(acc), __high2half(acc)); - atomicAdd(out_.item_ptr(x_row, 
w_column), result); - } - else - { - atomicAdd(out_.item_ptr(x_row, w_column), acc_h); - } -} - -fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) -{ - // - if (tuningParams->matmul_no_half2) { - if (block_size_z % groupsize == 0) { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } else { - if (block_size_z % groupsize == 0) - { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } -}; - -// Compute y = x @ w - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero, - cudaStream_t alt_stream -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - - uint32_t* x_map = w->cuda_x_map; - const half* x_mapped = x; - if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) - { - CudaBuffers* buffers = get_buffers(w->device); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - x_map = NULL; - } - - int block_size_z; - if (w->width == 4096) block_size_z = 384; // 7B - else if (w->width == 11008) block_size_z = 256; - else if (w->width == 5120) block_size_z = 384; // 13B - else if (w->width == 13824) block_size_z = 256; - else if (w->width == 6656) block_size_z = 256; // 33B - else if (w->width == 17920) block_size_z = 128; - else block_size_z = 256; - - //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); - - dim3 threads(THREADS_X, THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height + threads.y - 1) / threads.y, - (dim + block_size_z - 1) / block_size_z - ); - - fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); -} - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - bool no_zero, - const cublasHandle_t handle -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - CudaBuffers* buffers = get_buffers(w->device); - - const half* x_mapped = x; - if (w->cuda_x_map) - { - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - } - - w->reconstruct(buffers->temp_dq); - - const half alpha = __float2half(1.0f); - const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); - cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); - -// const float alpha = 1.0f; -// const float beta = no_zero ? 
1.0f : 0.0f; -// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, -// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); -} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh deleted file mode 100644 index 4c7a6669..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh +++ /dev/null @@ -1,37 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matmul_cuh -#define _q4_matmul_cuh - -#include -#include -#include -#include -#include - -#include "q4_matrix.cuh" -#include "../tuning.h" - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero, - cudaStream_t alt_stream -); - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - bool no_zero, - const cublasHandle_t handle -); - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu deleted file mode 100644 index 1f32e6b8..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu +++ /dev/null @@ -1,220 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include -#include "q4_matrix.cuh" -#include -#include "../util.cuh" -#include "../matrix.cuh" - -using namespace std; - -const int UNSHUF_BLOCKSIZE_X = 64; - -const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column -const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows - -vector g_q4_matrices; - -void g_q4_keep_matrix(Q4Matrix* m) -{ - g_q4_matrices.push_back(m); -} - -void g_q4_free_matrices() -{ - for (const auto& m : g_q4_matrices) delete m; - g_q4_matrices.clear(); -} - -Q4Matrix::Q4Matrix -( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device -) : - height(_height), - width(_width), - groups(_groups), - device(_device) -{ - cudaSetDevice(device); - - cuda_qweight = _qweight; - cuda_qzeros = _qzeros; - cuda_scales = _scales; - - groupsize = height / groups; - - if (_g_idx) make_sequential(_g_idx); -} - -Q4Matrix::~Q4Matrix() -{ -} - -// Make sequential - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const uint32_t* __restrict__ x_map, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - int w_new2_row = blockIdx.y; - - int x_map_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = x_map[x_map_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - 
-void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) -{ - uint32_t* cuda_new_qweight = NULL; - cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = cpu_g_idx[row]; - uint32_t target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Move to CUDA - - cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); - dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); - - // Replace qweights - - cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); -} - -__global__ void reconstruct_kernel -( - const uint32_t* __restrict__ w, - half* __restrict__ out, // (y) - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int width, - const int groupsize -) -{ - // Start of block - - int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; - int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; - - // Views - - MatrixView_q4_column w_(w, height, width); - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, height / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); - - // Groupsize version - - int group = row / groupsize; - - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F; - - uint32_t w_read = w_.item_uint32_t(row, column); - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += 4) - { - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } -} - -void Q4Matrix::reconstruct(half* out) -{ - dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height / 8 + threads.y - 1) / threads.y, - 1 - ); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); -} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh deleted file mode 100644 index 49431dc9..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh +++ /dev/null @@ 
-1,53 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matrix_cuh -#define _q4_matrix_cuh - -#include -#include -#include - -class Q4Matrix -{ -public: - - int device; - - int height; - int width; - int groups; - int groupsize; - - uint32_t* cuda_qweight = NULL; - uint32_t* cuda_qzeros = NULL; - half* cuda_scales = NULL; - uint32_t* cuda_x_map = NULL; - - Q4Matrix - ( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device - ); - - ~Q4Matrix(); - - void reconstruct(half* out); - -private: - - void make_sequential(const uint32_t* cpu_g_idx); - -}; - -void g_q4_keep_matrix(Q4Matrix* m); -void g_q4_free_matrices(); - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/exllama_ext.cpp b/backends/gaudi/server/exllama_kernels/exllama_kernels/exllama_ext.cpp deleted file mode 100644 index f2df80e8..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/exllama_ext.cpp +++ /dev/null @@ -1,253 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include -#include -#include -#include -#include -#include -#include -#include "util.cuh" -#include "tuning.h" -#include "cuda_buffers.cuh" -#include "cuda_func/q4_matrix.cuh" -#include "cuda_func/q4_matmul.cuh" -#include "cuda_func/column_remap.cuh" - -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. - -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } -} - -// Some decluttering macros - -#define STRINGIFY_(__x) #__x -#define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ - TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) - -#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - 
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; -} - - -// Tuning parameters - -ExLlamaTuning tuningParams; - -void set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; -} - - -// Release all unmanaged objects allocated by the extension - -void cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); -} - - -// Prepare buffers for forward pass - -void prepare_buffers -( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq -) -{ - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); - - prepare_buffers_cuda - ( - device_index, - (half*) temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); -} - - -// Create Q4Matrix, return handle - -uintptr_t make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix - ( - height, - width, - groups, - - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? 
NULL : (uint32_t*) g_idx.data_ptr(), - - device - ); - - g_q4_keep_matrix(m); - return reinterpret_cast (m); -} - - -// Matmul half @ quant -> half - -void q4_matmul -( - torch::Tensor x, - uintptr_t w, - torch::Tensor out -) -{ - Q4Matrix* wm = reinterpret_cast (w); - - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - false, - stream - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - false, - at::cuda::getCurrentCUDABlasHandle() - ); - } -} - - -// Remap columns in half tensor - -void column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - - int height = x.size(0); - int width = x.size(1); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - column_remap_cuda - ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() - ); -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); -} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/hip_compat.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/hip_compat.cuh deleted file mode 100644 index f2a3dcad..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/hip_compat.cuh +++ /dev/null @@ -1,52 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _hip_compat_cuh -#define _hip_compat_cuh - -// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6. -__device__ __forceinline__ __half __compat_hrcp(__half x) { - return __half_raw{ - static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; -} - -__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { - return _Float16_2{ - _Float16_2{static_cast<_Float16>(1.0f), - static_cast<_Float16>(1.0f)} / x.data}; -} - -#define hrcp __compat_hrcp -#define h2rcp __compat_h2rcp - -// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf. -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); -} -#define hipblasHgemm __compat_hipblasHgemm - -// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. 
-#define rocblas_handle hipblasHandle_t -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_get_stream hipblasGetStream -#define rocblas_set_stream hipblasSetStream -#define rocblas_hgemm __compat_hipblasHgemm - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/matrix.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/matrix.cuh deleted file mode 100644 index 2fd5ab0b..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/matrix.cuh +++ /dev/null @@ -1,294 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _matrix_cuh -#define _matrix_cuh - -#include -#include - -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } -}; - -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } -}; - -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } -}; - -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } -}; - -// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant 
zero/scale - -__device__ __forceinline__ half2 dot_product_8 -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - -// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) -// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; -// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; -// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; - - half2 tmp = __hmul2(*h_ptr++, v_01); - tmp = __hfma2(*h_ptr++, v_23, tmp); - tmp = __hfma2(*h_ptr++, v_45, tmp); - tmp = __hfma2(*h_ptr++, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) 
- const int count -) -{ - const half* h_ptr = h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(*h_ptr++, v_0); - tmp = __hfma(*h_ptr++, v_1, tmp); - tmp = __hfma(*h_ptr++, v_2, tmp); - tmp = __hfma(*h_ptr++, v_3, tmp); - tmp = __hfma(*h_ptr++, v_4, tmp); - tmp = __hfma(*h_ptr++, v_5, tmp); - tmp = __hfma(*h_ptr++, v_6, tmp); - tmp = __hfma(*h_ptr++, v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map - -__device__ __forceinline__ half2 dot_product_8_x_map -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - - half h_0 = h_ptr[*x_map_ptr++]; - half h_1 = h_ptr[*x_map_ptr++]; - half h_2 = h_ptr[*x_map_ptr++]; - half h_3 = h_ptr[*x_map_ptr++]; - half h_4 = h_ptr[*x_map_ptr++]; - half h_5 = h_ptr[*x_map_ptr++]; - half h_6 = h_ptr[*x_map_ptr++]; - half h_7 = h_ptr[*x_map_ptr++]; - - half2 h_01 = __halves2half2(h_0, h_1); - half2 h_23 = __halves2half2(h_2, h_3); - half2 h_45 = __halves2half2(h_4, h_5); - half2 h_67 = __halves2half2(h_6, h_7); - - half2 tmp = __hmul2(h_01, v_01); - tmp = __hfma2(h_23, v_23, tmp); - tmp = __hfma2(h_45, v_45, tmp); - tmp = __hfma2(h_67, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_x_map_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 
1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); - tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/tuning.h b/backends/gaudi/server/exllama_kernels/exllama_kernels/tuning.h deleted file mode 100644 index 770ca46a..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/tuning.h +++ /dev/null @@ -1,13 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _tuning_h -#define _tuning_h - -struct ExLlamaTuning -{ - int matmul_recons_thd; - bool matmul_fused_remap; - bool matmul_no_half2; -}; - -#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/util.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/util.cuh deleted file mode 100644 index 7b397573..00000000 --- a/backends/gaudi/server/exllama_kernels/exllama_kernels/util.cuh +++ /dev/null @@ -1,33 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _util_cuh -#define _util_cuh - -#include -#include -#include -#include - -#if defined(USE_ROCM) -#define cudaUnspecified hipErrorUnknown -#else -#define cudaUnspecified cudaErrorApiFailureBase -#endif - -// React to failure on return code != cudaSuccess - -#define _cuda_check(fn) \ -do { \ - {_cuda_err = fn;} \ - if (_cuda_err != cudaSuccess) goto _cuda_fail; \ -} while(false) - -// React to failure on return code == 0 - -#define _alloc_check(fn) \ -do { \ - if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ - else _cuda_err = cudaSuccess; \ -} while(false) - -#endif diff --git a/backends/gaudi/server/exllama_kernels/setup.py b/backends/gaudi/server/exllama_kernels/setup.py deleted file mode 100644 index cc307bf0..00000000 --- a/backends/gaudi/server/exllama_kernels/setup.py +++ /dev/null @@ -1,32 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -import torch - -extra_cuda_cflags = [] -extra_cflags = [] -if torch.version.hip: - extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] - extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] - -extra_compile_args = { - "cxx": extra_cflags, - "nvcc": extra_cuda_cflags, -} - -setup( - name="exllama_kernels", - ext_modules=[ - CUDAExtension( - name="exllama_kernels", - sources=[ - "exllama_kernels/exllama_ext.cpp", - 
"exllama_kernels/cuda_buffers.cu", - "exllama_kernels/cuda_func/column_remap.cu", - "exllama_kernels/cuda_func/q4_matmul.cu", - "exllama_kernels/cuda_func/q4_matrix.cu", - ], - extra_compile_args=extra_compile_args, - ) - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/config.h b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/config.h deleted file mode 100644 index 32a1a37d..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/config.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _config_h -#define _config_h - -#define MAX_Q_GEMM_ROWS 50 -#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS - -#define QMODE_2BIT 1 -#define QMODE_3BIT 1 -#define QMODE_4BIT 1 -#define QMODE_5BIT 1 -#define QMODE_6BIT 0 -#define QMODE_8BIT 0 - - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h deleted file mode 100644 index 919703a8..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _util_h -#define _util_h - -#define DBGS(__x) printf("%s\n", __x) -#define DBGI(__x) printf("%s: %i\n", #__x, __x) -#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) -#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) -#define DBGF(__x) printf("%s: %f\n", #__x, __x) -#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) -#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh deleted file mode 100644 index 12684ff8..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh +++ /dev/null @@ -1,56 +0,0 @@ -#ifndef _compat_cuh -#define _compat_cuh - -// atomicAdd for half types, to support CC < 7.x - -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; - - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); -} - -// atomicAdd for half2 types - -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); -} - -// - -#if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) - -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } - -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif - -#endif -#endif - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh deleted file mode 100644 index a72bc7bc..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh +++ /dev/null @@ -1,121 +0,0 @@ -#ifndef _matrix_view_cuh -#define _matrix_view_cuh - -#include -#include - -#include "quant/qdq_util.cuh" - -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } - - __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const - { - half2* ptr = (half2*) item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __low2half(i01); - items[1] = __high2half(i01); - items[2] = __low2half(i23); - items[3] = __high2half(i23); - } - __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2float(__low2half(i01)); - items[1] = __half2float(__high2half(i01)); - items[2] = __half2float(__low2half(i23)); - items[3] = __half2float(__high2half(i23)); - } - - __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2half2(__low2half(i01)); - items[1] = __half2half2(__high2half(i01)); - items[2] = __half2half2(__low2half(i23)); - items[3] = __half2half2(__high2half(i23)); - } -}; - -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - 
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } - - __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) - { - half2 v01 = __halves2half2(v0, v1); - half2 v23 = __halves2half2(v2, v3); - half2* ptr = (half2*) item_ptr(row, column); - ptr[0] = v01; - ptr[1] = v23; - } -}; - -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - items[2] = (d >> 8) & 0x0f; - items[3] = (d >> 12) & 0x0f; - } -}; - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu deleted file mode 100644 index 5b99f1ba..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu +++ /dev/null @@ -1,220 +0,0 @@ -#include "q_gemm.cuh" -#include "util.cuh" -#include "matrix_view.cuh" -#include "../config.h" - -#include "quant/qdq_2.cuh" -#include "quant/qdq_3.cuh" -#include "quant/qdq_4.cuh" -#include "quant/qdq_5.cuh" -#include "quant/qdq_6.cuh" -#include "quant/qdq_8.cuh" - -#define GPTQ_BLOCK_KN_SIZE 128 -#define GPTQ_BLOCK_M_SIZE_MAX 8 -#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32) - -#define EXL2_BLOCK_KN_SIZE 64 -#define EXL2_BLOCK_M_SIZE_MAX 8 -#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) - -#define CLEAR_N_SIZE 256 - -#include "q_gemm_kernel.cuh" -#include "q_gemm_kernel_gptq.cuh" - -void gemm_half_q_half_cuda_part -( - const half* a, - QMatrix* b, - half* c, - int size_m, - int size_n, - int size_k, - int m_count, - bool clear, - const half* r_weights, - int r_weights_stride, - bool mul_r_weights -) -{ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (!b->is_gptq) - { - dim3 blockDim, gridDim; - blockDim.x = EXL2_BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); - gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE); - - fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights); - - kernel<<>> - ( - a, - b->cuda_q_weight, - b->cuda_q_scale, - b->cuda_q_scale_max, - c, - size_m, - size_n, - size_k, - b->groups, - b->cuda_q_group_map, - b->cuda_q_perm, - b->rows_8, - b->rows_6, - b->rows_5, - 
b->rows_4, - b->rows_3, - b->rows_2, - clear, - r_weights, - r_weights_stride - ); - } - else - { - dim3 blockDim, gridDim; - blockDim.x = GPTQ_BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4); - gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE); - - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights); - -// DBGX((uint64_t) r_weights); -// if (r_weights) -// print_global_mem(r_weights, 1, 1, 1); -// DBGI(r_weights_stride); - - kernel<<>> - ( - a, - b->cuda_q_weight, - b->cuda_gptq_qzeros, - b->cuda_gptq_scales, - c, - size_m, - size_n, - size_k, - b->groups, - b->gptq_groupsize, - b->cuda_q_perm, - b->rows_4, - clear, - r_weights, - r_weights_stride - ); - } -} - -void gemm_half_q_half_cuda -( - cublasHandle_t cublas_handle, - const half* a, - QMatrix* b, - half* c, - int size_m, - int size_n, - int size_k, - bool clear, - half* temp_dq, - bool force_cuda, - const half* r_weights, - const int r_weights_stride, - bool mul_r_weights -) -{ - if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) - { - // Reconstruct FP16 matrix, then cuBLAS - - if (!temp_dq) temp_dq = b->temp_dq; - b->reconstruct(temp_dq); - - //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); - - const half alpha = __float2half(1.0f); - const half beta = clear ? __float2half(0.0f) : __float2half(1.0f); - cublasHgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - size_n, size_m, size_k, - &alpha, temp_dq, size_n, - a, size_k, - &beta, c, size_n); - - //const float alpha = 1.0f; - //const float beta = clear ? 0.0f : 1.0f; - //cublasSgemmEx(cublas_handle, - // CUBLAS_OP_N, - // CUBLAS_OP_N, - // size_n, size_m, size_k, - // &alpha, temp_dq, CUDA_R_16F, size_n, - // a, CUDA_R_16F, size_k, - // &beta, c, CUDA_R_16F, size_n); - - //const float alpha = 1.0f; - //const float beta = clear ? 0.0f : 1.0f; - //cublasGemmEx(cublas_handle, - // CUBLAS_OP_N, CUBLAS_OP_N, - // size_n, size_m, size_k, - // &alpha, temp_dq, CUDA_R_16F, size_n, - // a, CUDA_R_16F, size_k, - // &beta, c, CUDA_R_16F, size_n, - // CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP); - } - else - { - // Quantized matmul - - int block_m_size_max = b->is_gptq ? 
GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX; - int max_chunks = size_m / block_m_size_max; - int last_chunk = max_chunks * block_m_size_max; - int last_chunk_size = size_m - last_chunk; - - if (max_chunks) - { - gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights); - } - - if (last_chunk_size) - { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights); - } - } -} - -__global__ void clear_kernel -( - half* __restrict__ c, - const int size_m, - const int size_n -) -{ - int m = blockIdx.y; - int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8; - if (n >= size_n) return; - int4* c_ptr = (int4*)(c + m * size_n + n); - *c_ptr = {}; -} - -void clear_tensor_cuda -( - half* c, - int size_m, - int size_n -) -{ -// dim3 blockDim, gridDim; -// blockDim.x = CLEAR_N_SIZE; -// blockDim.y = 1; -// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); -// gridDim.y = size_m; -// clear_kernel<<>>(c, size_m, size_n); -} diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh deleted file mode 100644 index e49457f3..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _q_gemm_cuh -#define _q_gemm_cuh - -#include -#include -#include -#include -#include - -#include "q_matrix.cuh" - -void gemm_half_q_half_cuda -( - cublasHandle_t cublas_handle, - const half* a, - QMatrix* b, - half* c, - int size_m, - int size_n, - int size_k, - bool clear = false, - half* reconstruct = NULL, - bool force_cuda = false, - const half* r_weights = NULL, - const int r_weights_stride = 0, - bool mul_r_weights = false -); - -void clear_tensor_cuda -( - half* c, - int size_m, - int size_n -); - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh deleted file mode 100644 index 9cd2ba01..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh +++ /dev/null @@ -1,580 +0,0 @@ -#include "compat.cuh" - -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); -} - -__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); -} - -__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); -} - -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const 
half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); -} - -__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); -} - -__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); -} - -__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) -{ - // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 - - float result = {}; - #pragma unroll - for (int i = 0; i < 4; i++) - { - half2 w01 = dq[i]; - float w0 = __low2float(w01); - float w1 = __high2float(w01); - float x0 = __half2float(*a_ptr++); - float x1 = __half2float(*a_ptr++); - result = fma(w0, x0, result); - result = fma(w1, x1, result); - } - float qs = __half2float(qs_h); - result *= qs; - half result_h = __float2half_rn(result); - return __hadd(result_h, g_result); -} - -__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); -} - -__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); -} - - -typedef void (*fp_gemm_half_q_half_kernel) -( - const half*, - const uint32_t*, - const uint32_t*, - const half*, - half*, - const int, - const int, - const int, - const int, - const uint16_t*, - const uint16_t*, - const int, - const int, - const int, - const int, - const int, - const int, - const bool, - const half*, - const int -); - -template -__global__ void gemm_half_q_half_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, - const uint32_t* __restrict__ b_q_scale, - const half* __restrict__ b_q_scale_max, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const uint16_t* __restrict__ b_q_group_map, - const uint16_t* __restrict__ b_q_perm, - const int rows_8, - const int rows_6, - const int rows_5, - const int rows_4, - const int rows_3, - const int rows_2, - const bool clear, - const half* r_weights, - const int r_weights_stride -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - 
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); - - int t = threadIdx.x; - - // Block - - int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; - - int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k); - int n = offset_n + t * 4; - - // Read weights - - half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; - if constexpr (use_r_weights) - { - uint16_t any_w = 0; - const half* w_ptr = r_weights; - for (int m = 0; m < m_count; ++m) - { - weights[m].as_half = *w_ptr; - w_ptr += r_weights_stride; - any_w |= weights[m].as_uint16; - } - if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) - } - - // Preload block_a - - __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - half a0 = a_ptr[b_q_perm[offset_k + t]]; -// half a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } - } - - // Clear - - if (n >= size_n) return; - - if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0; - } - - __syncthreads(); - - // Find initial group - - //int group = offset_k / groupsize; - int group = b_q_group_map[offset_k * 2]; - -// if (offset_m == 0 && t == 0) -// DBGI2(offset_k, group); - - // Preload scales - - half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; - - //int groups_in_block = DIVIDE((end_k - offset_k), groupsize); - int temp_k = offset_k; - for (int g = 0; temp_k < end_k; g++) - { - int qscales[4]; - b_q_scale_.item4(qscales, group + g, n); - qscales[0]++; - qscales[1]++; - qscales[2]++; - qscales[3]++; - half maxscale = b_q_scale_max[group + g]; - scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); - scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); - scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); - scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); - temp_k += b_q_group_map[temp_k * 2 + 1]; - } - - // a, b offset - - int pre_rows_8 = min(rows_8, offset_k); - int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; - int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; - int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; - int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; - int pre_rows_2 = offset_k > rows_3 ? 
min(rows_2, offset_k) - rows_3 : 0; - int qk = 0; - qk += pre_rows_8 / 32 * 8; - qk += pre_rows_6 / 32 * 6; - qk += pre_rows_5 / 32 * 5; - qk += pre_rows_4 / 32 * 4; - qk += pre_rows_3 / 32 * 3; - qk += pre_rows_2 / 32 * 2; - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = EXL2_BLOCK_KN_SIZE; - - // Initial group - - int scales_idx = 0; - half qs_h0 = scales[scales_idx][0]; - half qs_h1 = scales[scales_idx][1]; - half qs_h2 = scales[scales_idx][2]; - half qs_h3 = scales[scales_idx][3]; - int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; - - // Column result - - half block_c[m_count][4] = {}; - - // Dequantize groups - - int k = offset_k; - - while (k < rows_8 && k < end_k) - { - if (k == nextgroup) - { - group++; - scales_idx++; - qs_h0 = scales[scales_idx][0]; - qs_h1 = scales[scales_idx][1]; - qs_h2 = scales[scales_idx][2]; - qs_h3 = scales[scales_idx][3]; - nextgroup += b_q_group_map[k * 2 + 1]; - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n); - - for (int m = 0; m < m_count; m++) - { - if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } - block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); - block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); - block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); - block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); - } - a_ptr += 8; - } - k += 32; - } - - while (k < rows_6 && k < end_k) - { - if (k == nextgroup) - { - group++; - scales_idx++; - qs_h0 = scales[scales_idx][0]; - qs_h1 = scales[scales_idx][1]; - qs_h2 = scales[scales_idx][2]; - qs_h3 = scales[scales_idx][3]; - nextgroup += b_q_group_map[k * 2 + 1]; - } - - #pragma unroll - for (int j = 0; j < 2; j++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][8]; - dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); - dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); - dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); - dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); - - for (int m = 0; m < m_count; m++) - { - if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } - block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); - block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); - block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); - block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); - } - a_ptr += 16; - } - k += 32; - } - - while (k < rows_5 && k < end_k) - { - if (k == nextgroup) - { - group++; - scales_idx++; - qs_h0 = scales[scales_idx][0]; - qs_h1 = scales[scales_idx][1]; - qs_h2 = scales[scales_idx][2]; - qs_h3 = scales[scales_idx][3]; - nextgroup += b_q_group_map[k * 2 + 1]; - } - - #pragma unroll - for (int j = 0; j < 
1; j++) - { - int4 load_int4[5]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[3] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[4] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n); - dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n); - dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n); - dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n); - - for (int m = 0; m < m_count; m++) - { - if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } - block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); - block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); - block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); - block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); - } - a_ptr += 32; - } - - k += 32; - } - - while (k < rows_4 && k < end_k) - { - if (k == nextgroup) - { - group++; - scales_idx++; - qs_h0 = scales[scales_idx][0]; - qs_h1 = scales[scales_idx][1]; - qs_h2 = scales[scales_idx][2]; - qs_h3 = scales[scales_idx][3]; - nextgroup += b_q_group_map[k * 2 + 1]; - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - int4 load_int4[1]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_4bit_8(load_int4[0].x, dq[0], size_n); - dequant_4bit_8(load_int4[0].y, dq[1], size_n); - dequant_4bit_8(load_int4[0].z, dq[2], size_n); - dequant_4bit_8(load_int4[0].w, dq[3], size_n); - - for (int m = 0; m < m_count; m++) - { - if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } - block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); - block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); - block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); - block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); - } - a_ptr += 8; - } - k += 32; - } - - while (k < rows_3 && k < end_k) - { - if (k == nextgroup) - { - group++; - scales_idx++; - qs_h0 = scales[scales_idx][0]; - qs_h1 = scales[scales_idx][1]; - qs_h2 = scales[scales_idx][2]; - qs_h3 = scales[scales_idx][3]; - nextgroup += b_q_group_map[k * 2 + 1]; - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); - - for (int m = 0; m < m_count; m++) - { - if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } - block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); - block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); - block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); - block_c[m][3] = 
dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); - } - a_ptr += 32; - } - k += 32; - } - - while (k < rows_2 && k < end_k) - { - if (k == nextgroup) - { - group++; - scales_idx++; - qs_h0 = scales[scales_idx][0]; - qs_h1 = scales[scales_idx][1]; - qs_h2 = scales[scales_idx][2]; - qs_h3 = scales[scales_idx][3]; - nextgroup += b_q_group_map[k * 2 + 1]; - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - int4 load_int4[1]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][8]; - dequant_2bit_16(load_int4[0].x, dq[0], size_n); - dequant_2bit_16(load_int4[0].y, dq[1], size_n); - dequant_2bit_16(load_int4[0].z, dq[2], size_n); - dequant_2bit_16(load_int4[0].w, dq[3], size_n); - - for (int m = 0; m < m_count; m++) - { - if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } - block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); - block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); - block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); - block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); - } - - a_ptr += 16; - } - k += 16; - } - - // Accumulate column sums in c - - for (int m = 0; m < m_count; m++) - { - half2* out = (half2*)c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - - if constexpr (mul_r_weights) - { - half2 w_mul2 = __half2half2(weights[m].as_half); - result01 = __hmul2(result01, w_mul2); - result23 = __hmul2(result23, w_mul2); - } - - atomicAdd(out , result01); - atomicAdd(out + 1, result23); -// *out = result01; -// *(out + 1) = result23; - } -} - -template -struct map_m_count_exl2 { - static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) - { - #if EXL2_BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>; - #endif - #if EXL2_BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>; - #endif - #if EXL2_BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>; - #endif - #if EXL2_BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>; - #endif - #if EXL2_BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>; - #endif - #if EXL2_BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>; - #endif - #if EXL2_BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>; - #endif - #if EXL2_BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>; - #endif - return NULL; - } -}; - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights) -{ - if (!r_weights && !mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); - if (!r_weights && mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); - if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count); - if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count); - return NULL; -} diff --git 
a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh deleted file mode 100644 index f816fd9d..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh +++ /dev/null @@ -1,273 +0,0 @@ -#include "compat.cuh" - -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hadd2(result, g_result); -} - -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __half2float(__low2half(result)) + __half2float(__high2half(result)); -} - -__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return result; -} - -typedef void (*fp_gemm_half_q_half_gptq_kernel) -( - const half*, - const uint32_t*, - const uint32_t*, - const half*, - half*, - const int, - const int, - const int, - const int, - const int, - const uint16_t*, - const int, - const bool, - const half*, - const int -); - -template -__global__ void gemm_half_q_half_gptq_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int groupsize, - const uint16_t* __restrict__ b_q_perm, - const int rows_4, - const bool clear, - const half* r_weights, - const int r_weights_stride -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - - int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE; - - int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Read weights - - half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; - if constexpr (use_r_weights) - { - uint16_t any_w = 0; - const half* w_ptr = r_weights; - for (int m = 0; m < m_count; ++m) - { - weights[m].as_half = *w_ptr; - w_ptr += r_weights_stride; - any_w |= weights[m].as_uint16; - } - if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) 
- } - - // Preload block_a - - __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } - } - - // Zero output - - if (n >= size_n) return; - - if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; - } - - __syncthreads(); - - // Find initial group - - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = GPTQ_BLOCK_KN_SIZE; - - // Initial group - - int zeros[4]; - half2 scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); - -// __syncthreads(); - - // Column result - - half2 block_c[m_count][4] = {}; - - // Dequantize and multiply - - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][4]; - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } - block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); - } - - b_ptr += size_n; - a_ptr += 8; - } - - k += 32; - } - - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0])); - half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1])); - half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2])); - half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3])); - half2 result01 = __halves2half2(result0, result1); - half2 result23 = __halves2half2(result2, result3); - - if constexpr (mul_r_weights) - { - half2 w_mul2 = 
__half2half2(weights[m].as_half); - result01 = __hmul2(result01, w_mul2); - result23 = __hmul2(result23, w_mul2); - } - - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } -} - -template -struct map_m_count_gptq { - static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count) - { - #if GPTQ_BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>; - #endif - #if GPTQ_BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>; - #endif - #if GPTQ_BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>; - #endif - #if GPTQ_BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>; - #endif - #if GPTQ_BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>; - #endif - #if GPTQ_BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>; - #endif - #if GPTQ_BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>; - #endif - #if GPTQ_BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>; - #endif - return NULL; - } -}; - -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights) -{ - if (!r_weights && !mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); - if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); - return NULL; -} diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu deleted file mode 100644 index f7a91e29..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu +++ /dev/null @@ -1,650 +0,0 @@ -#include "q_matrix.cuh" -#include "matrix_view.cuh" -#include "util.cuh" - -#include "quant/qdq_2.cuh" -#include "quant/qdq_3.cuh" -#include "quant/qdq_4.cuh" -#include "quant/qdq_5.cuh" -#include "quant/qdq_6.cuh" -#include "quant/qdq_8.cuh" - -#define BLOCK_KN_SIZE 128 - -#define THREADS_X 32 -#define THREADS_Y 32 - -// Shuffle quantized data on load - -__global__ void shuffle_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n, - const int rows_8, - const int rows_6, - const int rows_5, - const int rows_4, - const int rows_3, - const int rows_2 -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } - while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; } - while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; } - while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } - while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } - while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; 
k += 16; } -} - - -// QMatrix constructor - -QMatrix::QMatrix -( - const int _device, - const int _height, - const int _width, - const int _groups, - - uint32_t* _q_weight, - uint16_t* _q_perm, - uint16_t* _q_invperm, - uint32_t* _q_scale, - half* _q_scale_max, - uint16_t* _q_groups, - uint16_t* _q_group_map, - - uint32_t* _gptq_qzeros, - half* _gptq_scales, - uint32_t* _gptq_g_idx, - - half* _temp_dq -) : - device(_device), - height(_height), - width(_width), - groups(_groups), - temp_dq(_temp_dq) -{ - cudaSetDevice(device); - - failed = false; - - cuda_q_weight = _q_weight; - cuda_q_perm = _q_perm; - cuda_q_invperm = _q_invperm; - cuda_q_scale = _q_scale; - cuda_q_scale_max = _q_scale_max; - cuda_q_groups = _q_groups; - cuda_q_group_map = _q_group_map; - cuda_gptq_qzeros = _gptq_qzeros; - cuda_gptq_scales = _gptq_scales; - - is_gptq = (_gptq_qzeros != NULL); - - if (is_gptq) - { - gptq_groupsize = 1; - while (gptq_groupsize * groups < height) gptq_groupsize *= 2; - } - - // Create group map - - rows_8 = 0; - rows_6 = 0; - rows_5 = 0; - rows_4 = 0; - rows_3 = 0; - rows_2 = 0; - - if (!is_gptq) - { - uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); - cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); - - int row = 0; - for (int i = 0; i < groups; i++) - { - int bits = cpu_q_groups[i * 2]; - - int rows; - if (i < groups - 1) - { - int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; - rows = qrows * 32 / bits; - } - else rows = height - row; - - if (bits == 8) rows_8 += rows; - if (bits == 6) rows_6 += rows; - if (bits == 5) rows_5 += rows; - if (bits == 4) rows_4 += rows; - if (bits == 3) rows_3 += rows; - if (bits == 2) rows_2 += rows; - row += rows; - } - - free(cpu_q_groups); - - rows_6 += rows_8; - rows_5 += rows_6; - rows_4 += rows_5; - rows_3 += rows_4; - rows_2 += rows_3; - } - else - { - rows_4 = height; - rows_3 = height; - rows_2 = height; - - if (_gptq_g_idx) - { - if (!make_sequential(_gptq_g_idx)) - { - failed = true; - //printf("FAIL\n"); - return; - } - } - } - -// DBGI(rows_8); -// DBGI(rows_6); -// DBGI(rows_5); -// DBGI(rows_4); -// DBGI(rows_3); -// DBGI(rows_2); - - // Shuffle quantized data - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = 1; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); -} - -QMatrix::~QMatrix() -{ -} - -// Reconstruct b[k,n] (GPTQ) - -__global__ void reconstruct_gptq_kernel -( - const uint32_t* __restrict__ b_q_weight, - const uint16_t* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - //const uint16_t* __restrict__ b_q_groups, - const int size_k, - const int size_n, - const int groupsize, - const int groups, - half* __restrict__ b, - const int rows_4 -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - - __shared__ uint16_t perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } - - // Column - - int n = offset_n + t * 4; - if (n >= 
size_n) return; - - // Find initial group - - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // b offset - - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - - // Initial zeros/scale - - int zeros[4]; - half2 scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); - - __syncthreads(); - - int k = offset_k; - int lk = 0; - - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); - } - - for (int p = 0; p < 4; p++) - { - half2 dq[4][4]; - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - } - k += 32; - } -} - - -// Reconstruct b[k,n] - -__global__ void reconstruct_kernel -( - const uint32_t* __restrict__ b_q_weight, - const uint16_t* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_q_scale, - const half* __restrict__ b_q_scale_max, - const uint16_t* __restrict__ b_q_group_map, - const int size_k, - const int size_n, - //const int groupsize, - const int groups, - half* __restrict__ b, - const int rows_8, - const int rows_6, - const int rows_5, - const int rows_4, - const int rows_3, - const int rows_2 -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x; - - // Preload remapping table - - int t = threadIdx.x; - __shared__ uint16_t perm[BLOCK_KN_SIZE]; - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - - // Column - - int n = offset_n + t; - if (n >= size_n) return; - - // Find initial group - - // int group = offset_k / groupsize; - int group = b_q_group_map[offset_k * 2]; - - int pre_rows_8 = min(rows_8, offset_k); 
- int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; - int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; - int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; - int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; - int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; - int qk = 0; - qk += pre_rows_8 / 32 * 8; - qk += pre_rows_6 / 32 * 6; - qk += pre_rows_5 / 32 * 5; - qk += pre_rows_4 / 32 * 4; - qk += pre_rows_3 / 32 * 3; - qk += pre_rows_2 / 32 * 2; - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - - half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); - half2 qs_h2 = __halves2half2(qs_h, qs_h); - int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - int k = offset_k; - int lk = 0; - - __syncthreads(); - - while (k < rows_8 && k < end_k) - { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 4; p++) - { - half2 dq[4]; - uint32_t q_0 = *b_ptr; b_ptr += size_n; - uint32_t q_1 = *b_ptr; b_ptr += size_n; - dequant_8bit_8(q_0, q_1, dq, size_n); - for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); - half* dqh = (half*) dq; - for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); - } - k += 32; - } - - while (k < rows_6 && k < end_k) - { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 2; p++) - { - half2 dq[8]; - uint32_t q_0 = *b_ptr; b_ptr += size_n; - uint32_t q_1 = *b_ptr; b_ptr += size_n; - uint32_t q_2 = *b_ptr; b_ptr += size_n; - dequant_6bit_16(q_0, q_1, q_2, dq, size_n); - for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); - half* dqh = (half*) dq; - for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); - } - k += 32; - } - - while (k < rows_5 && k < end_k) - { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 1; p++) - { - half2 dq[16]; - uint32_t q_0 = *b_ptr; b_ptr += size_n; - uint32_t q_1 = *b_ptr; b_ptr += size_n; - uint32_t q_2 = *b_ptr; b_ptr += size_n; - uint32_t q_3 = *b_ptr; b_ptr += size_n; - uint32_t q_4 = *b_ptr; b_ptr += size_n; - dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n); - for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); - half* dqh = (half*) dq; - for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); - } - k += 32; - } - - while (k < rows_4 && k < end_k) - { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 4; p++) - { - half2 dq[4]; - uint32_t q_0 = *b_ptr; b_ptr += size_n; - dequant_4bit_8(q_0, dq, size_n); - for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); - half* dqh = (half*) dq; - for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); - } - k += 32; - } - - while (k < rows_3 && k < end_k) - { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 1; p++) - { - half2 dq[16]; - 
uint32_t q_0 = *b_ptr; b_ptr += size_n; - uint32_t q_1 = *b_ptr; b_ptr += size_n; - uint32_t q_2 = *b_ptr; b_ptr += size_n; - dequant_3bit_32(q_0, q_1, q_2, dq, size_n); - for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); - half* dqh = (half*) dq; - for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); - } - k += 32; - } - - while (k < rows_2 && k < end_k) - { - if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } - for (int p = 0; p < 1; p++) - { - half2 dq[8]; - uint32_t q_0 = *b_ptr; b_ptr += size_n; - dequant_2bit_16(q_0, dq, size_n); - for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); - half* dqh = (half*) dq; - for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); - } - k += 16; - } -} - -void QMatrix::reconstruct(half* out) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (!is_gptq) - { - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - reconstruct_kernel<<>> - ( - cuda_q_weight, - cuda_q_perm, - cuda_q_scale, - cuda_q_scale_max, - cuda_q_group_map, - height, - width, - //groupsize, - groups, - out, - rows_8, - rows_6, - rows_5, - rows_4, - rows_3, - rows_2 - ); - } - else - { - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4); - reconstruct_gptq_kernel<<>> - ( - cuda_q_weight, - cuda_q_perm, - cuda_gptq_qzeros, - cuda_gptq_scales, - //const uint16_t* __restrict__ b_q_groups, - height, - width, - gptq_groupsize, - groups, - out, - rows_4 - ); - } -} - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const uint16_t* __restrict__ q_perm, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - - int w_new2_row = blockIdx.y; - - int q_perm_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - -bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) -{ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - uint32_t* cuda_new_qweight = NULL; - cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - if (err != cudaSuccess) { - cudaError_t cuda_status = cudaGetLastError(); // Clear error - return false; - } - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = cpu_g_idx[row]; - uint32_t 
target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Reduce to uint16_t - - uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map; - uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv; - for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row]; - for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row]; - - // Move to CUDA - - cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = height / 8; - - make_sequential_kernel<<>> - ( - cuda_q_weight, - cuda_new_qweight, - cuda_q_perm, - height / 8, - width - ); - - // Replace qweights - - cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); - - return true; -} diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh deleted file mode 100644 index d36b8d66..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef _q_matrix_cuh -#define _q_matrix_cuh - -#include -#include -#include -#include - -#define MAX_SUPERGROUPS 16 - -class QMatrix -{ -public: - - int device; - bool is_gptq; - - int height; - int width; - int groups; - int gptq_groupsize; - - int rows_8; - int rows_6; - int rows_5; - int rows_4; - int rows_3; - int rows_2; - - uint32_t* cuda_q_weight = NULL; - uint16_t* cuda_q_perm = NULL; - uint16_t* cuda_q_invperm = NULL; - uint32_t* cuda_q_scale = NULL; - half* cuda_q_scale_max = NULL; - uint16_t* cuda_q_groups = NULL; - uint16_t* cuda_q_group_map = NULL; - uint32_t* cuda_gptq_qzeros = NULL; - half* cuda_gptq_scales = NULL; - - half* temp_dq; - - bool failed; - - QMatrix - ( - const int _device, - const int _height, - const int _width, - const int _groups, - - uint32_t* _q_weight, - uint16_t* _q_perm, - uint16_t* _q_invperm, - uint32_t* _q_scale, - half* _q_scale_max, - uint16_t* _q_groups, - uint16_t* _q_group_map, - - uint32_t* _gptq_qzeros, - half* _gptq_scales, - uint32_t* _gptq_g_idx, - - half* _temp_dq - ); - - ~QMatrix(); - - void reconstruct(half* out); - bool make_sequential(const uint32_t* cpu_g_idx); - -private: - -}; - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh deleted file mode 100644 index 90c18a0c..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef _qdq_2_cuh -#define _qdq_2_cuh - -#include "qdq_util.cuh" -#include "../../config.h" - -#if QMODE_2BIT == 1 - -// Permutation: -// -// ffddbb99 77553311 eeccaa88 66442200 - -__forceinline__ __device__ void shuffle_2bit_16 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - uint32_t qa0 = qa & 0x03; - 
uint32_t qa1 = (qa & 0x0c) >> 2; - qa >>= 4; - qb |= (qa1 << (i * 2 + 16)); - qb |= (qa0 << (i * 2)); - } - q[0] = qb; -} - -__forceinline__ __device__ void dequant_2bit_16 -( - const uint32_t q_0, - half2 (&dq)[8], - int stride -) -{ - const uint32_t c0 = 0x64006400; - const half y4_ = __float2half_rn(1.0f / 4.0f); - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y4 = __halves2half2(y4_, y4_); - const half2 y16 = __halves2half2(y16_, y16_); - const half2 y64 = __halves2half2(y64_, y64_); - const half z1_ = __float2half_rn(-1024.0f - 2.0f); - const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f); - const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f); - const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f); - const half2 z1 = __halves2half2(z1_, z1_); - const half2 z4 = __halves2half2(z4_, z4_); - const half2 z16 = __halves2half2(z16_, z16_); - const half2 z64 = __halves2half2(z64_, z64_); - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 - half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 - half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 - qa >>= 8; - half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 - half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 - half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 - half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 - - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y4, z4); - dq[2] = __hfma2(q2.as_half2, y16, z16); - dq[3] = __hfma2(q3.as_half2, y64, z64); - dq[4] = __hadd2(q4.as_half2, z1); - dq[5] = __hfma2(q5.as_half2, y4, z4); - dq[6] = __hfma2(q6.as_half2, y16, z16); - dq[7] = __hfma2(q7.as_half2, y64, z64); -} - -#else - -__forceinline__ __device__ void shuffle_2bit_16 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_2bit_16 -( - const uint32_t q_0, - half2 (&dq)[8], - int stride -) -{ - half dqh[16]; - for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2); - - for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); -} - -#endif - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh deleted file mode 100644 index 10117376..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh +++ /dev/null @@ -1,169 +0,0 @@ -#ifndef _qdq_3_cuh -#define _qdq_3_cuh - -#include "qdq_util.cuh" -#include "../../config.h" - -#if QMODE_3BIT == 1 - -// Permutation: -// -// v9997775 55333111 u8886664 44222000 (u, v lsb) -// vjjjhhhf ffdddbbb uiiiggge eecccaaa -// vtttrrrp ppnnnlll usssqqqo oommmkkk - -__forceinline__ __device__ void shuffle_3bit_32 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0 * stride]; - uint32_t qb = q[1 * stride]; - uint32_t qc = q[2 * stride]; - - // qa: aa999888 77766655 54443332 22111000 - // qb: lkkkjjji iihhhggg fffeeedd dcccbbba - // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll - - uint32_t qd = qc >> 26; - qc <<= 4; - qc |= qb >> 28; - qb <<= 2; - qb |= qa >> 30; - - // qa: ..999888 77766655 54443332 22111000 - // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa - // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk - // qd: vvvuuu - - uint32_t za = 
0; - uint32_t zb = 0; - uint32_t zc = 0; - - for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } - - // za: 9997775 55333111 8886664 44222000 - // zb: jjjhhhf ffdddbbb iiiggge eecccaaa - // zc: tttrrrp ppnnnlll sssqqqo oommmkkk - // qd: vvvuuu - - za |= ((qd & 0x01) >> 0) << 15; - zb |= ((qd & 0x02) >> 1) << 15; - zc |= ((qd & 0x04) >> 2) << 15; - za |= ((qd & 0x08) >> 3) << 31; - zb |= ((qd & 0x10) >> 4) << 31; - zc |= ((qd & 0x20) >> 5) << 31; - - // za: v9997775 55333111 u8886664 44222000 (u, v lsb) - // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa - // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk - - q[0 * stride] = za; - q[1 * stride] = zb; - q[2 * stride] = zc; -} - -__forceinline__ __device__ void dequant_3bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - half2 (&dq)[16], - int stride -) -{ - const uint32_t c0 = 0x64006400; - const half y8_ = __float2half_rn(1.0f / 8.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y8 = __halves2half2(y8_, y8_); - const half2 y64 = __halves2half2(y64_, y64_); - const half z1_ = __float2half_rn(-1024.0f - 4.0f); - const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f); - const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f); - const half2 z1 = __halves2half2(z1_, z1_); - const half2 z8 = __halves2half2(z8_, z8_); - const half2 z64 = __halves2half2(z64_, z64_); - - uint32_t qa = q_0; - uint32_t qb = q_1; - uint32_t qc = q_2; - - half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 - qa >>= 6; - half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 - half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 - qa >>= 9; - qa &= 0x00010001; - half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 - half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 - qb >>= 6; - half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 - half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 - half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 - qb >>= 8; - qb &= 0x00020002; - half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 - half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 - qc >>= 6; - half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 - half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 - half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 - qc >>= 7; - qc &= 0x00040004; - half2_uint32 q15((qa | qb | qc) | c0); - - dq[ 0] = __hadd2( q0.as_half2, z1); - dq[ 1] = __hfma2( q1.as_half2, y8, z8); - dq[ 2] = __hadd2( q2.as_half2, z1); - dq[ 3] = __hfma2( q3.as_half2, y8, z8); - dq[ 4] = __hfma2( q4.as_half2, y64, z64); - dq[ 5] = __hadd2( q5.as_half2, z1); - dq[ 6] = __hfma2( q6.as_half2, y8, z8); - dq[ 7] = __hadd2( q7.as_half2, z1); - dq[ 8] = __hfma2( q8.as_half2, y8, z8); - dq[ 9] = __hfma2( q9.as_half2, y64, 
z64); - dq[10] = __hadd2(q10.as_half2, z1); - dq[11] = __hfma2(q11.as_half2, y8, z8); - dq[12] = __hadd2(q12.as_half2, z1); - dq[13] = __hfma2(q13.as_half2, y8, z8); - dq[14] = __hfma2(q14.as_half2, y64, z64); - dq[15] = __hadd2(q15.as_half2, z1); -} - -#else - -__forceinline__ __device__ void shuffle_3bit_32 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_3bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - half2 (&dq)[16], - int stride -) -{ - half dqh[32]; - for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4); - dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4); - for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4); - dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4); - for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4); - - for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); -} - -#endif - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh deleted file mode 100644 index ad95edb4..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh +++ /dev/null @@ -1,227 +0,0 @@ -#ifndef _qdq_4_cuh -#define _qdq_4_cuh - -#include "qdq_util.cuh" -#include "../../config.h" - -#if QMODE_4BIT == 1 - -// Permutation: -// -// 77775555 33331111 66664444 22220000 - -__forceinline__ __device__ void shuffle_4bit_8 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - uint32_t qa0 = qa & 0x0f; - uint32_t qa1 = (qa & 0xf0) >> 4; - qa >>= 8; - qb |= (qa1 << (i * 4 + 16)); - qb |= (qa0 << (i * 4)); - } - q[0] = qb; -} - -__forceinline__ __device__ void dequant_4bit_8 -( - const uint32_t q_0, - half2 (&dq)[4], - int stride -) -{ - const uint32_t c0 = 0x64006400; - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half2 y16 = __halves2half2(y16_, y16_); - const half z1_ = __float2half_rn(-1024.0f - 8.0f); - const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); - const half2 z1 = __halves2half2(z1_, z1_); - const half2 z16 = __halves2half2(z16_, z16_); - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 - - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y16, z16); - dq[2] = __hadd2(q2.as_half2, z1); - dq[3] = __hfma2(q3.as_half2, y16, z16); -} - -__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale -( - const uint32_t zero, - const half scale, - half2 (&z1z16)[2], - half2 (&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - - half2 scale2 = __half2half2(scale); - - z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); - z1z16[1] = __hmul2(scale2, __half2half2(z16)); - - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); - - y1y16[0] = __hmul2(scale2, __half2half2(y1)); - y1y16[1] = __hmul2(scale2, __half2half2(y16)); -} - -__forceinline__ __device__ void dequant_4bit_8_prep_zero -( - const uint32_t zero, - half2(&z1z16)[2], - half2(&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | 
zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - - z1z16[0] = __half2half2(z1.as_half); - z1z16[1] = __half2half2(z16); - - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); - - y1y16[0] = __half2half2(y1); - y1y16[1] = __half2half2(y16); -} - - -__forceinline__ __device__ void dequant_4bit_8_gptq -( - const uint32_t q_0, - half2 (&dq)[4], - half2 (&z1z16)[2], - half2 (&y1y16)[2], - int stride, - bool scaled -) -{ - const uint32_t c0 = 0x64006400; - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) - - if (scaled) - { - dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) - dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); - } - else - { - dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) - dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) - } -} - -#else - -__forceinline__ __device__ void shuffle_4bit_8 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_4bit_8 -( - const uint32_t q_0, - half2 (&dq)[4], - int stride -) -{ - half dqh[8]; - for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); - - for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); -} - -__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale -( - const uint32_t zero, - const half scale, - half2 (&z1)[2], - half2 (&y1)[2] -) -{ - half z = __int2half_rn(-((int)zero)); - z = __hmul(z, scale); - z1[0] = __half2half2(z); - y1[0] = __half2half2(scale); -} - -__forceinline__ __device__ void dequant_4bit_8_prep_zero -( - const uint32_t zero, - half2(&z1)[2], - half2(&y1)[2] -) -{ - half z = __int2half_rn(-((int)zero)); - z1[0] = __half2half2(z); -} - -__forceinline__ __device__ void dequant_4bit_8_gptq -( - const uint32_t q_0, - half2 (&dq)[4], - half2 (&z1)[2], - half2 (&y1)[2], - int stride, - bool scaled -) -{ - half2 dqh2[8]; - - uint32_t qa = q_0; - for (int i = 0; i < 4; i++) - { - half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; - half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; - dqh2[i] = __halves2half2(d0, d1); - } - - if (scaled) - { - dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); - dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); - dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); - dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); - } - else - { - dq[0] = __hadd2(dqh2[0], z1[0]); - dq[1] = __hadd2(dqh2[1], z1[0]); - dq[2] = __hadd2(dqh2[2], z1[0]); - dq[3] = __hadd2(dqh2[3], z1[0]); - } -} - -#endif - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh deleted file mode 100644 index 78d81f92..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh +++ /dev/null @@ -1,207 +0,0 @@ -#ifndef _qdq_5_cuh -#define _qdq_5_cuh 
- -#include "qdq_util.cuh" -#include "../../config.h" - -#if QMODE_5BIT == 1 - -// Permutation: -// -// v5555533 33311111 u4444422 22200000 (u, v lsb) -// vbbbbb99 99977777 uaaaaa88 88866666 -// vhhhhhff fffddddd ugggggee eeeccccc -// vnnnnnll llljjjjj ummmmmkk kkkiiiii -// vtttttrr rrrppppp usssssqq qqqooooo - -__forceinline__ __device__ void shuffle_5bit_32 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0 * stride]; - uint32_t qb = q[1 * stride]; - uint32_t qc = q[2 * stride]; - uint32_t qd = q[3 * stride]; - uint32_t qe = q[4 * stride]; - - // qa: 66555554 44443333 32222211 11100000 - // qb: ccccbbbb baaaaa99 99988888 77777666 - // qc: jiiiiihh hhhggggg fffffeee eedddddc - // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj - // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp - - uint32_t qf = qe >> 22; - qe <<= 8; - qe |= qd >> 24; - qd <<= 6; - qd |= qc >> 26; - qc <<= 4; - qc |= qb >> 28; - qb <<= 2; - qb |= qa >> 30; - - // qa: 555554 44443333 32222211 11100000 - // qb: bbbbba aaaa9999 98888877 77766666 - // qc: hhhhhg ggggffff feeeeedd dddccccc - // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii - // qe: ttttts ssssrrrr rqqqqqpp pppooooo - // qf: vv vvvuuuuu - - uint32_t za = 0; - uint32_t zb = 0; - uint32_t zc = 0; - uint32_t zd = 0; - uint32_t ze = 0; - - for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); } - for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); } - for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); } - for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); } - for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); } - - // za: 5555533 33311111 4444422 22200000 - // zb: bbbbb99 99977777 aaaaa88 88866666 - // zc: hhhhhff fffddddd gggggee eeeccccc - // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii - // ze: tttttrr rrrppppp sssssqq qqqooooo - // qf: vv vvvuuuuu - - za |= ((qf & 0x001) >> 0) << 15; - zb |= ((qf & 0x002) >> 1) << 15; - zc |= ((qf & 0x004) >> 2) << 15; - zd |= ((qf & 0x008) >> 3) << 15; - ze |= ((qf & 0x010) >> 4) << 15; - za |= ((qf & 0x020) >> 5) << 31; - zb |= ((qf & 0x040) >> 6) << 31; - zc |= ((qf & 0x080) >> 7) << 31; - zd |= ((qf & 0x100) >> 8) << 31; - ze |= ((qf & 0x200) >> 9) << 31; - - // za: v5555533 33311111 u4444422 22200000 (u, v lsb) - // zb: vbbbbb99 99977777 uaaaaa88 88866666 - // zc: vhhhhhff fffddddd ugggggee eeeccccc - // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii - // ze: vtttttrr rrrppppp usssssqq qqqooooo - - q[0 * stride] = za; - q[1 * stride] = zb; - q[2 * stride] = zc; - q[3 * stride] = zd; - q[4 * stride] = ze; -} - -__forceinline__ __device__ void dequant_5bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - const uint32_t q_3, - const uint32_t q_4, - half2 (&dq)[16], - int stride -) -{ - const uint32_t c0 = 0x64006400; - const half y32_ = __float2half_rn(1.0f / 32.0f); - const half2 y32 = __halves2half2(y32_, y32_); - const half z1_ = __float2half_rn(-1024.0f - 16.0f); - const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f); - const half2 z1 = __halves2half2(z1_, z1_); - const half2 z32 = __halves2half2(z32_, z32_); - - uint32_t qa = q_0; - 
uint32_t qb = q_1; - uint32_t qc = q_2; - uint32_t qd = q_3; - uint32_t qe = q_4; - - half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024 - qa >>= 10; - half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024 - qa >>= 5; - qa &= 0x00010001; - half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024 - half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024 - qb >>= 10; - half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024 - qb >>= 4; - qb &= 0x00020002; - half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024 - half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024 - qc >>= 10; - half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024 - qc >>= 3; - qc &= 0x00040004; - half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024 - half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024 - qd >>= 10; - half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024 - qd >>= 2; - qd &= 0x00080008; - half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024 - half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024 - qe >>= 10; - half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024 - qe >>= 1; - qe &= 0x00100010; - half2_uint32 q15((qa | qb | qc | qd | qe) | c0); - - dq[ 0] = __hadd2( q0.as_half2, z1); - dq[ 1] = __hfma2( q1.as_half2, y32, z32); - dq[ 2] = __hadd2( q2.as_half2, z1); - dq[ 3] = __hadd2( q3.as_half2, z1); - dq[ 4] = __hfma2( q4.as_half2, y32, z32); - dq[ 5] = __hadd2( q5.as_half2, z1); - dq[ 6] = __hadd2( q6.as_half2, z1); - dq[ 7] = __hfma2( q7.as_half2, y32, z32); - dq[ 8] = __hadd2( q8.as_half2, z1); - dq[ 9] = __hadd2( q9.as_half2, z1); - dq[10] = __hfma2(q10.as_half2, y32, z32); - dq[11] = __hadd2(q11.as_half2, z1); - dq[12] = __hadd2(q12.as_half2, z1); - dq[13] = __hfma2(q13.as_half2, y32, z32); - dq[14] = __hadd2(q14.as_half2, z1); - dq[15] = __hadd2(q15.as_half2, z1); -} - -#else - -__forceinline__ __device__ void shuffle_5bit_32 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_5bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - const uint32_t q_3, - const uint32_t q_4, - half2 (&dq)[16], - int stride -) -{ - half dqh[32]; - for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16); - dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16); - for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16); - dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16); - for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16); - dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16); - for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16); - dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16); - for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16); - - for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); -} - -#endif - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh deleted file mode 100644 index 562fe695..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _qdq_6_cuh -#define _qdq_6_cuh - 
-#include "qdq_util.cuh" -#include "../../config.h" - -#if QMODE_6BIT == 1 - - // Not implemented - -#else - -__forceinline__ __device__ void shuffle_6bit_16 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_6bit_16 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - half2 (&dq)[8], - int stride -) -{ - half dqh[16]; - for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32); - dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32); - for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32); - dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32); - for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32); - - for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); -} - -#endif - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh deleted file mode 100644 index 6e6bedbd..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef _qdq_8_cuh -#define _qdq_8_cuh - -#include "qdq_util.cuh" -#include "../../config.h" - -#if QMODE_8BIT == 1 - - // Not implemented - -#else - -__forceinline__ __device__ void shuffle_8bit_4 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_8bit_8 -( - const uint32_t q_0, - const uint32_t q_1, - half2 (&dq)[4], - int stride -) -{ - half dqh[8]; - for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128); - for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128); - - for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); -} - -#endif - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh deleted file mode 100644 index cac9df9c..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef _qdq_util_cuh -#define _qdq_util_cuh - -union half2_uint32 -{ - uint32_t as_uint32; - half2 as_half2; - __device__ half2_uint32(uint32_t val) : as_uint32(val) {} - __device__ half2_uint32(half2 val) : as_half2(val) {} - __device__ half2_uint32() : as_uint32(0) {} -}; - -union half_uint16 -{ - uint16_t as_uint16; - half as_half; - __device__ half_uint16(uint16_t val) : as_uint16(val) {} - __device__ half_uint16(half val) : as_half(val) {} - __device__ half_uint16() : as_uint16(0) {} -}; - -// Max_scale premultiplied by 1/256 - -__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) -{ - int qs_i = qs + 1; - half qs_h = __int2half_rn(qs_i * qs_i); - qs_h = __hmul(qs_h, max_scale); - return qs_h; -} - -__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) -{ - return __hmul(__int2half_rn(q - qzero), scale); -} - -__forceinline__ __device__ half dq_ns(const int q, const int qzero) -{ - //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); - return __int2half_rn(q - qzero); -} - -__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) -{ - return (int)((q >> shift) & mask); -} - -__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) -{ - return (int)(__funnelshift_rc(q0, q1, shift) & mask); -} - -#endif diff --git 
a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh deleted file mode 100644 index e167bc23..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef _util_cuh -#define _util_cuh - -#include -#include -#include -#include -#include - -#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) - -#define DBGS(__x) printf("%s\n", __x) -#define DBGI(__x) printf("%s: %i\n", #__x, __x) -#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) -#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) -#define DBGX(__x) printf("%s: %x\n", #__x, __x) -#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y) -#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) -#define DBGF(__x) printf("%s: %f\n", #__x, __x) -#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) -#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) -#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x)) -#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y)) -#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z)) - -#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y)) -#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z)) - -__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale) -{ - half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f)); - qs_h = __hmul(qs_h, qs_h); - qs_h = __hmul(qs_h, max_scale); - return qs_h; -} - -__forceinline__ __device__ float clamp(float x, float a, float b) -{ - return fmaxf(a, fminf(b, x)); -} - -#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); } -inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true) -{ - if (code != cudaSuccess) - { - fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line); - if (abort) exit(code); - } -} - -void print_global_mem(const half* ptr, int rows, int columns, int stride); - -#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/ext.cpp b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/ext.cpp deleted file mode 100644 index ff4e1851..00000000 --- a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/ext.cpp +++ /dev/null @@ -1,139 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "config.h" - -#include "cuda/q_matrix.cuh" -#include "cuda/q_gemm.cuh" - -#include "cpp/util.h" - -// Some decluttering macros - -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == 
(__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") - - -// Quant matrix - -uintptr_t make_q_matrix -( - torch::Tensor q_weight, - torch::Tensor q_perm, - torch::Tensor q_invperm, - torch::Tensor q_scale, - torch::Tensor q_scale_max, - torch::Tensor q_groups, - torch::Tensor q_group_map, - torch::Tensor gptq_qzeros, - torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx, - torch::Tensor temp_dq -) -{ - TORCH_CHECK_DTYPE(q_weight, kInt); - TORCH_CHECK_DTYPE_OPT(q_perm, kShort); - TORCH_CHECK_DTYPE_OPT(q_invperm, kShort); - TORCH_CHECK_DTYPE_OPT(q_scale, kInt); - TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); - TORCH_CHECK_DTYPE_OPT(q_groups, kShort); - TORCH_CHECK_DTYPE_OPT(q_group_map, kShort); - TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); - TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); - TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); - - TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); - - int device = q_weight.device().index(); - int width = q_weight.size(1); - int groups; - int height; - - if (!q_scale.device().is_meta()) - { - TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8); - TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1); - groups = q_scale.size(0); - height = q_invperm.size(0); - } - else - { - TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8); - TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1); - groups = gptq_qzeros.size(0); - height = q_weight.size(0) * 8; - } - - TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer") - - QMatrix* m = new QMatrix - ( - device, - height, - width, - groups, - (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), - q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), - q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), - q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), - q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), - q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(), - gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), - gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), - gptq_g_idx.device().is_meta() ? 
NULL : (uint32_t*) gptq_g_idx.data_ptr(), - (half*) temp_dq.data_ptr() - ); - - if (m->failed) throw std::runtime_error("CUDA out of memory"); - - return reinterpret_cast (m); -} - -void gemm_half_q_half -( - torch::Tensor a, - uintptr_t b, - torch::Tensor c, - bool force_cuda -) -{ - QMatrix* qm = reinterpret_cast (b); - - TORCH_CHECK_DTYPE(a, kHalf); - TORCH_CHECK_DTYPE(c, kHalf); - TORCH_CHECK_SHAPES(a, 0, c, 0, 1); - TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes") - TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - - gemm_half_q_half_cuda - ( - at::cuda::getCurrentCUDABlasHandle(), - (const half*) a.data_ptr(), - qm, - (half*) c.data_ptr(), - c.size(0), // m - c.size(1), // n - a.size(1), // k - true, - NULL, - force_cuda - ); -} - -// Bindings - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("make_q_matrix", &make_q_matrix, "make_q_matrix"); - m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half"); -} diff --git a/backends/gaudi/server/exllamav2_kernels/setup.py b/backends/gaudi/server/exllamav2_kernels/setup.py deleted file mode 100644 index 56ffa973..00000000 --- a/backends/gaudi/server/exllamav2_kernels/setup.py +++ /dev/null @@ -1,30 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -import torch - -extra_cuda_cflags = ["-lineinfo", "-O3"] -extra_cflags = [] -if torch.version.hip: - extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] - extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"] - -extra_compile_args = { - "cxx": extra_cflags, - "nvcc": extra_cuda_cflags, -} - -setup( - name="exllamav2_kernels", - ext_modules=[ - CUDAExtension( - name="exllamav2_kernels", - sources=[ - "exllamav2_kernels/ext.cpp", - "exllamav2_kernels/cuda/q_matrix.cu", - "exllamav2_kernels/cuda/q_gemm.cu", - ], - extra_compile_args=extra_compile_args, - ) - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/backends/gaudi/server/requirements_cuda.txt b/backends/gaudi/server/requirements_cuda.txt deleted file mode 100644 index 5de75b6b..00000000 --- a/backends/gaudi/server/requirements_cuda.txt +++ /dev/null @@ -1,54 +0,0 @@ -certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" -charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" -click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") -deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" -grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" -grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" -grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" -idna==3.10 ; python_version >= "3.9" and python_version < "3.13" -importlib-metadata==7.1.0 ; 
python_version >= "3.9" and python_version < "3.13" -loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" -mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" -numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" -prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" -protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13" -py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" -pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" -regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" -rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13" -safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" -scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" -sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" -setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13" -tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13" -typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" -win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" -wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" -zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/requirements_intel.txt b/backends/gaudi/server/requirements_intel.txt deleted file mode 100644 index 5de75b6b..00000000 --- a/backends/gaudi/server/requirements_intel.txt +++ /dev/null @@ -1,54 +0,0 @@ -certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" -charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" -click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") 
-deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
-einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
-grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
-importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
-loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
-mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
-numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
-prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
-py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
-pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
-pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
-rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
-sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
-typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
-win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
-wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
-zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/backends/gaudi/server/requirements_rocm.txt b/backends/gaudi/server/requirements_rocm.txt
deleted file mode 100644
index 5de75b6b..00000000
--- a/backends/gaudi/server/requirements_rocm.txt
+++ /dev/null
@@ -1,54 +0,0 @@
-certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
-charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
-click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
-colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
-deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
-einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
-grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
-importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
-loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
-mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
-numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
-prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
-py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
-pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
-pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
-rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
-sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
-typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
-win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
-wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
-zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
index 4f2b9807..4d83a11f 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
@@ -3,7 +3,7 @@ import os
 
 from .common import Seqlen
 
-if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
     from .cuda import (
diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py
index 21195d6a..8fda0517 100644
--- a/backends/gaudi/server/text_generation_server/models/causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py
@@ -728,7 +728,7 @@ class CausalLM(Model):
         self.enable_hpu_graph = (
             os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
         )
-        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
 
         if model.config.model_type not in [
             "gpt_bigcode"
@@ -790,9 +790,9 @@ class CausalLM(Model):
         if model.config.model_type not in ["gpt_bigcode"]:
             self.kwargs["trim_logits"] = True
 
-        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
+        if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true":
             self.kwargs["use_flash_attention"] = True
-        if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
+        if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true":
             self.kwargs["flash_attention_recompute"] = True
 
         self.speculate = get_speculate()
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
index d2fbff54..f98dab91 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -85,8 +85,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
-        use_flash_attention: Optional[bool] = False,
-        flash_attention_recompute: Optional[bool] = False,
+        use_flash_attention: Optional[bool] = True,
+        flash_attention_recompute: Optional[bool] = True,
     ):
 
         if token_idx is not None:
@@ -156,8 +156,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                 **kwargs,
             )
         else:
-            use_flash_attention = kwargs.get("use_flash_attention", False)
-            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
+            use_flash_attention = kwargs.get("use_flash_attention", True)
+            flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
 
             position_ids = kwargs.get("position_ids", None)
             labels = kwargs.get("labels", None)
diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
index 181bc51a..d4f4c1af 100644
--- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
@@ -621,7 +621,7 @@ class VlmCausalLM(Model):
         self.enable_hpu_graph = (
             os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
         )
-        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
         model = remove_kv_cache_from_output(model)
         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -668,9 +668,9 @@ class VlmCausalLM(Model):
             self.kwargs["attn_softmax_bf16"] = True
             self.kwargs["trim_logits"] = True
 
-        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
+        if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true":
             self.kwargs["use_flash_attention"] = True
-        if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
+        if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true":
             self.kwargs["flash_attention_recompute"] = True
 
         self.speculate = get_speculate()
diff --git a/backends/gaudi/server/text_generation_server/utils/__init__.py b/backends/gaudi/server/text_generation_server/utils/__init__.py
index ead0e1f2..cda3a4da 100644
--- a/backends/gaudi/server/text_generation_server/utils/__init__.py
+++ b/backends/gaudi/server/text_generation_server/utils/__init__.py
@@ -20,6 +20,9 @@ from text_generation_server.utils.tokens import (
     FinishReason,
     Sampling,
     Greedy,
+    make_tokenizer_optional,
+    is_tokenizer_transparent,
+    pad_next_token_chooser_parameters,
 )
 
 __all__ = [
@@ -41,4 +44,7 @@ __all__ = [
     "StopSequenceCriteria",
     "FinishReason",
     "Weights",
+    "make_tokenizer_optional",
+    "is_tokenizer_transparent",
+    "pad_next_token_chooser_parameters",
 ]
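With the re-exports above, the tokenizer helpers can be imported from the package root rather than from utils.tokens. A rough usage sketch (not part of the diff; it assumes a standard Hugging Face tokenizer and that is_tokenizer_transparent takes the tokenizer as its only argument):

from transformers import AutoTokenizer

from text_generation_server.utils import (
    is_tokenizer_transparent,
    make_tokenizer_optional,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer would do

# With SKIP_TOKENIZER_IN_TGI=true (see the tokens.py hunk below), this swaps in
# a pass-through tokenizer class and marks it as transparent.
make_tokenizer_optional(tokenizer)

if is_tokenizer_transparent(tokenizer):
    print("tokenization is skipped inside TGI")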
diff --git a/backends/gaudi/server/text_generation_server/utils/tokens.py b/backends/gaudi/server/text_generation_server/utils/tokens.py
index b0282c42..9c44ba15 100644
--- a/backends/gaudi/server/text_generation_server/utils/tokens.py
+++ b/backends/gaudi/server/text_generation_server/utils/tokens.py
@@ -1,6 +1,5 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 
-import os
 import re
 from typing import List, Optional, Tuple, Set, Union
 
@@ -22,6 +21,7 @@ from text_generation_server.utils.logits_process import (
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+import os
 
 
 class NextTokenChooser:
@@ -753,8 +753,6 @@ def make_tokenizer_optional(tokenizer):
             # I don't think this method is used anywhere and should be removed when doing refactoring
             return ",".join(str(i) for i in to_py_obj(token_ids))  # noqa: F821
 
-    import os
-
     if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
         tokenizer.__class__ = _
         tokenizer.is_transparent = True
diff --git a/backends/gaudi/server/text_generation_server/utils/version.py b/backends/gaudi/server/text_generation_server/utils/version.py
index 84c916bf..f54b6ae8 100644
--- a/backends/gaudi/server/text_generation_server/utils/version.py
+++ b/backends/gaudi/server/text_generation_server/utils/version.py
@@ -1,7 +1,7 @@
 from optimum.habana.utils import get_driver_version
 from packaging.version import Version
 
-MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")
+MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
 
 
 def is_driver_compatible():
diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh
index ea94dcd9..a5c3f5e1 100644
--- a/backends/gaudi/tgi-entrypoint.sh
+++ b/backends/gaudi/tgi-entrypoint.sh
@@ -2,4 +2,10 @@
 
 ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'
 
+# Check if --sharded argument is present in the command line arguments
+if [[ "$*" == *"--sharded true"* ]]; then
+    echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
+    export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
+fi
+
 text-generation-launcher $@
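The entrypoint change above is plain shell; for readers tracing the launch path, here is an equivalent Python paraphrase of the same check (illustration only, not code from the patch):

import os
import sys

def maybe_enable_lazy_collectives(argv: list) -> None:
    # Mirrors the shell test: only the literal "--sharded true" among the
    # launcher arguments triggers the export.
    if "--sharded true" in " ".join(argv):
        print("setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding")
        os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "1"

maybe_enable_lazy_collectives(sys.argv[1:])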