mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00
fix(gaudi): refactor server and implement requested changes
wip(gaudi): fix typos wip(gaudi): refactor version numbers for pytorch and habana software to make it more flexible wip(gaudi): debugging the refactored server wip(gaudi): delete useless files fix(gaudi): server working after refactoring fix(gaudi): refactor and implement requested changes
This commit is contained in:
parent
77dca4dfbe
commit
7cdbd694b3
2
.gitignore
vendored
2
.gitignore
vendored
@ -26,3 +26,5 @@ server/fbgemmm
|
||||
|
||||
# Gaudi auto-generated files
|
||||
hl-smi_log*.txt
|
||||
.graph_dumps
|
||||
out
|
||||
|
@ -1,3 +1,7 @@
|
||||
# Those arguments are required to build the image
|
||||
ARG HABANA_VERSION
|
||||
ARG PYTORCH_VERSION
|
||||
|
||||
# Rust builder
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
|
||||
WORKDIR /usr/src
|
||||
@ -48,7 +52,10 @@ COPY launcher launcher
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
# Text Generation Inference base image
|
||||
FROM vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest AS base
|
||||
ARG HABANA_VERSION
|
||||
ARG PYTORCH_VERSION
|
||||
|
||||
FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
|
||||
|
||||
ENV ATTENTION=default
|
||||
ENV PREFIX_CACHING=0
|
||||
@ -80,12 +87,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||
COPY proto proto
|
||||
COPY backends/gaudi/server server
|
||||
COPY backends/gaudi/server/Makefile server/Makefile
|
||||
ARG HABANA_VERSION
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install --no-deps -r requirements.txt && \
|
||||
bash ./dill-0.3.8-patch.sh && \
|
||||
pip install outlines~=0.0.34 && \
|
||||
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0 && \
|
||||
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
|
||||
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
|
||||
pip install . --no-cache-dir
|
||||
|
||||
@ -108,6 +116,10 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||
# Final image
|
||||
FROM base
|
||||
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
ENV HABANA_VISIBLE_DEVICES all
|
||||
ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
|
||||
|
||||
COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
|
@ -2,34 +2,33 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
|
||||
mkfile_dir := $(dir $(mkfile_path))
|
||||
root_dir := "${mkfile_dir}/../.."
|
||||
|
||||
HABANA_VERSION := 1.19.0
|
||||
PYTORCH_VERSION := 2.5.1
|
||||
|
||||
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
|
||||
|
||||
image:
|
||||
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir}
|
||||
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
|
||||
|
||||
run-local-dev-container:
|
||||
docker run -it \
|
||||
--runtime=habana \
|
||||
-e HABANA_VISIBLE_DEVICES=all \
|
||||
-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
|
||||
-e LOG_LEVEL=debug \
|
||||
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
|
||||
-e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
|
||||
-e ENABLE_HPU_GRAPH=true \
|
||||
-e LIMIT_HPU_GRAPH=true \
|
||||
-e USE_FLASH_ATTENTION=true \
|
||||
-e FLASH_ATTENTION_RECOMPUTE=true \
|
||||
-e PORT=8080 \
|
||||
--ipc=host \
|
||||
--cap-add=sys_nice \
|
||||
--net=host \
|
||||
--ipc=host \
|
||||
-e HABANA_VISIBLE_DEVICES=all \
|
||||
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
|
||||
-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
|
||||
-e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
|
||||
-e LOG_LEVEL=debug \
|
||||
-e PORT=8080 \
|
||||
-v /home/ubuntu/.cache/huggingface:/data \
|
||||
-v $(PWD):/text-generation-inference \
|
||||
-w /text-generation-inference \
|
||||
vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
|
||||
vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest
|
||||
|
||||
install-dependencies:
|
||||
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0
|
||||
pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
|
||||
pip install outlines~=0.0.34
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
|
||||
|
@ -6,7 +6,7 @@ This is the TGI backend for Intel Gaudi. This backend is composed of the tgi ser
|
||||
|
||||
## Build your own image
|
||||
|
||||
The simplest way to build TGI with the gaudi backend is to use the provided `Makefile`:
|
||||
The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:
|
||||
|
||||
Option 1: From the project root directory:
|
||||
```bash
|
||||
@ -20,25 +20,39 @@ make image
|
||||
```
|
||||
|
||||
You can now run the server with the following command:
|
||||
|
||||
Option 1: Sharded:
|
||||
```bash
|
||||
model=meta-llama/Llama-3.1-8B-Instruct
|
||||
hf_token=$(cat ${HOME}/.cache/huggingface/token)
|
||||
volume=${HOME}/.cache/huggingface
|
||||
|
||||
docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
|
||||
-e LOG_LEVEL=debug \
|
||||
-e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
|
||||
-e HF_TOKEN=$hf_token -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true \
|
||||
-e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice \
|
||||
--ipc=host tgi-gaudi --model-id $model --sharded true \
|
||||
--num-shard 8 --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048 --max-batch-total-tokens 8192
|
||||
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||
-p 8080:80 -v $volume:/data \
|
||||
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
|
||||
tgi-gaudi --model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048
|
||||
```
|
||||
|
||||
Option 2: Non-sharded:
|
||||
```bash
|
||||
model=meta-llama/Llama-3.1-8B-Instruct
|
||||
hf_token=$(cat ${HOME}/.cache/huggingface/token)
|
||||
volume=${HOME}/.cache/huggingface
|
||||
|
||||
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||
-p 8080:80 -v $volume:/data \
|
||||
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
|
||||
tgi-gaudi --model-id $model \
|
||||
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
### Local Development
|
||||
|
||||
This is useful if you want to run the server in locally for better debugging.
|
||||
This is useful if you want to run the server locally for better debugging.
|
||||
```bash
|
||||
make -C backends/gaudi run-local-dev-container
|
||||
```
|
||||
|
@ -1,12 +0,0 @@
|
||||
exllamav2_commit := v0.1.8
|
||||
|
||||
build-exllamav2:
|
||||
git clone https://github.com/turboderp/exllamav2.git exllamav2 && \
|
||||
cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \
|
||||
git submodule update --init --recursive && \
|
||||
pip install -r requirements.txt && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build
|
||||
|
||||
install-exllamav2: build-exllamav2
|
||||
cd exllamav2/ && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install
|
@ -1,2 +0,0 @@
|
||||
install-flashinfer:
|
||||
pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4
|
@ -1,12 +0,0 @@
|
||||
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
|
||||
|
||||
build-lorax-punica:
|
||||
if [ ! -d 'lorax-punica' ]; then \
|
||||
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
|
||||
fi
|
||||
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
|
||||
cd lorax-punica && git submodule update --init --recursive
|
||||
cd lorax-punica/server/punica_kernels && python setup.py build
|
||||
|
||||
install-lorax-punica: build-lorax-punica
|
||||
cd lorax-punica/server/punica_kernels && python setup.py install
|
@ -1,250 +0,0 @@
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
#include <optional>
|
||||
|
||||
/**
|
||||
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
|
||||
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
|
||||
**/
|
||||
|
||||
// Available in pytorch main
|
||||
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
|
||||
/*
|
||||
* Forward passes
|
||||
*/
|
||||
|
||||
/**
|
||||
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
|
||||
**/
|
||||
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
|
||||
__global__ void forward_masked_softmax_kernel(
|
||||
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
|
||||
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
|
||||
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
|
||||
const int64_t effective_kv_length,
|
||||
const dim3 blockDim,
|
||||
const int64_t rows_per_block,
|
||||
const int64_t kv_length,
|
||||
const int64_t batch_size
|
||||
) {
|
||||
const auto row_id = threadIdx.x / effective_kv_length;
|
||||
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
|
||||
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
|
||||
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
|
||||
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
|
||||
const auto kv_length_end = kv_length_end_;
|
||||
|
||||
const auto batch_id = blockIdx.x * rows_per_block + row_id;
|
||||
|
||||
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
|
||||
extern __shared__ float temp_storage[];
|
||||
const auto row_id_mem_offset = row_id * 2;
|
||||
if (effective_kv_length_id == 0) {
|
||||
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
|
||||
temp_storage[row_id_mem_offset + 1] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute mask and max
|
||||
if (batch_id < batch_size) {
|
||||
float thread_max = -std::numeric_limits<float>::infinity();
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
const float candidate = attention_scores[batch_id][kv_length_id];
|
||||
thread_max = (thread_max < candidate) ? candidate : thread_max;
|
||||
}
|
||||
}
|
||||
if (thread_max != -std::numeric_limits<float>::infinity()) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute exp(elt - max) masked
|
||||
float exponential[min_kv_length_shard_size_per_thread];
|
||||
if (batch_id < batch_size) {
|
||||
float thread_add = 0;
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
|
||||
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
|
||||
} else {
|
||||
exponential[kv_length_id - kv_length_start] = 0.;
|
||||
}
|
||||
}
|
||||
if (thread_add > 0) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute softmax
|
||||
if (batch_id < batch_size) {
|
||||
// If sum of all exponential is 0, we set the softmax values to 0
|
||||
if (temp_storage[row_id_mem_offset + 1] == 0.) {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = 0.;
|
||||
}
|
||||
} else {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
|
||||
const at::Tensor query,
|
||||
const at::Tensor key,
|
||||
const at::Tensor value,
|
||||
const std::optional<std::vector<at::Tensor>> layer_past,
|
||||
const at::Tensor attention_mask,
|
||||
const std::optional<at::Tensor> head_mask,
|
||||
const float inv_norm_factor,
|
||||
const int num_heads,
|
||||
const bool use_cache
|
||||
) {
|
||||
auto query_layer = query;
|
||||
auto key_layer = key;
|
||||
auto value_layer = value;
|
||||
|
||||
if (layer_past) {
|
||||
const auto past_key = (*layer_past).at(0);
|
||||
const auto past_value = (*layer_past).at(1);
|
||||
key_layer = at::cat({past_key, key_layer}, 2);
|
||||
value_layer = at::cat({past_value, value_layer}, 2);
|
||||
}
|
||||
|
||||
std::optional<std::vector<at::Tensor>> present;
|
||||
if (use_cache) {
|
||||
present = {key_layer, value_layer};
|
||||
} else {
|
||||
present = {};
|
||||
}
|
||||
|
||||
const auto batch_size = query_layer.size(0);
|
||||
const auto q_length = query_layer.size(2);
|
||||
const auto attn_head_size = query_layer.size(3);
|
||||
const auto batch_size_times_num_heads = batch_size * num_heads;
|
||||
const auto kv_length = key_layer.size(2);
|
||||
|
||||
const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size});
|
||||
auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2);
|
||||
auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size});
|
||||
|
||||
auto query_scaled = query_view * inv_norm_factor;
|
||||
auto attention_scores = at::bmm(query_scaled, key_view);
|
||||
|
||||
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
|
||||
at::Tensor attention_probs;
|
||||
if (true) {
|
||||
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
|
||||
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
|
||||
// Custom kernel
|
||||
attention_probs = at::empty_like(attention_scores_2d);
|
||||
|
||||
// Check that inputs and contiguous + cuda tensors
|
||||
CHECK_INPUT(attention_scores_2d);
|
||||
CHECK_INPUT(attention_mask_2d);
|
||||
|
||||
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
|
||||
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
/*
|
||||
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
|
||||
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
|
||||
* - SMs: 108
|
||||
* - TPCs: 56 (What's that?)
|
||||
* - Memory size: 40 GB
|
||||
* - L2 Cache size: 40960 KB (shared across all SMs)
|
||||
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
|
||||
* - Max Threads / SM: 2048
|
||||
* - Max Thread Blocks / SM: 32
|
||||
*/
|
||||
|
||||
/*
|
||||
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
|
||||
* with multiple threads as we need to `sync_threads` to run exponential sum.
|
||||
* We maximise the usage of threads within a single block
|
||||
*/
|
||||
// TODO @thomasw21 figure out everything warp related:
|
||||
// - why do they have to be power of 2
|
||||
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
|
||||
const auto MAX_THREADS_PER_SM = 1024;
|
||||
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
|
||||
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
|
||||
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
|
||||
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
|
||||
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
|
||||
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
|
||||
|
||||
const dim3 gridDim(num_blocks); // Number of blocks that run
|
||||
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
|
||||
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
|
||||
|
||||
// 192 * 2 ** 10
|
||||
// const auto MAX_L1_MEMORY = 196608;
|
||||
// const auto MAX_SMs = 108;
|
||||
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
|
||||
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
|
||||
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
|
||||
|
||||
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
|
||||
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
|
||||
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
effective_kv_length,
|
||||
blockDim,
|
||||
rows_per_block,
|
||||
kv_length,
|
||||
batch_size_times_num_heads * q_length
|
||||
);
|
||||
});
|
||||
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
|
||||
} else {
|
||||
// Pytorch C++ API
|
||||
auto input_dtype = attention_scores.scalar_type();
|
||||
if (input_dtype == at::ScalarType::Float) {
|
||||
attention_scores = attention_scores.to(at::ScalarType::Float);
|
||||
};
|
||||
// TODO @thomasw21 Figure out how to get minimum value
|
||||
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
|
||||
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
|
||||
}
|
||||
|
||||
auto context_layer = attention_probs.bmm(value_view);
|
||||
|
||||
// `_merge_heads`
|
||||
context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size});
|
||||
context_layer = context_layer.permute({0, 2, 1, 3});
|
||||
context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads});
|
||||
|
||||
return std::make_tuple(context_layer, present, attention_probs);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"forward",
|
||||
&forward,
|
||||
"GPT-Neox attention mechanism forward (CUDA)"
|
||||
);
|
||||
}
|
@ -1,250 +0,0 @@
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
#include <optional>
|
||||
|
||||
/**
|
||||
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
|
||||
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
|
||||
**/
|
||||
|
||||
// Available in pytorch main
|
||||
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
|
||||
/*
|
||||
* Forward passes
|
||||
*/
|
||||
|
||||
/**
|
||||
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
|
||||
**/
|
||||
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
|
||||
__global__ void forward_masked_softmax_kernel(
|
||||
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
|
||||
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
|
||||
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
|
||||
const int64_t effective_kv_length,
|
||||
const dim3 blockDim,
|
||||
const int64_t rows_per_block,
|
||||
const int64_t kv_length,
|
||||
const int64_t batch_size
|
||||
) {
|
||||
const auto row_id = threadIdx.x / effective_kv_length;
|
||||
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
|
||||
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
|
||||
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
|
||||
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
|
||||
const auto kv_length_end = kv_length_end_;
|
||||
|
||||
const auto batch_id = blockIdx.x * rows_per_block + row_id;
|
||||
|
||||
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
|
||||
extern __shared__ float temp_storage[];
|
||||
const auto row_id_mem_offset = row_id * 2;
|
||||
if (effective_kv_length_id == 0) {
|
||||
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
|
||||
temp_storage[row_id_mem_offset + 1] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute mask and max
|
||||
if (batch_id < batch_size) {
|
||||
float thread_max = -std::numeric_limits<float>::infinity();
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
const float candidate = attention_scores[batch_id][kv_length_id];
|
||||
thread_max = (thread_max < candidate) ? candidate : thread_max;
|
||||
}
|
||||
}
|
||||
if (thread_max != -std::numeric_limits<float>::infinity()) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute exp(elt - max) masked
|
||||
float exponential[min_kv_length_shard_size_per_thread];
|
||||
if (batch_id < batch_size) {
|
||||
float thread_add = 0;
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
|
||||
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
|
||||
} else {
|
||||
exponential[kv_length_id - kv_length_start] = 0.;
|
||||
}
|
||||
}
|
||||
if (thread_add > 0) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute softmax
|
||||
if (batch_id < batch_size) {
|
||||
// If sum of all exponential is 0, we set the softmax values to 0
|
||||
if (temp_storage[row_id_mem_offset + 1] == 0.) {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = 0.;
|
||||
}
|
||||
} else {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
|
||||
const at::Tensor fused_qkv,
|
||||
const std::optional<std::vector<at::Tensor>> layer_past,
|
||||
const at::Tensor alibi,
|
||||
const at::Tensor attention_mask,
|
||||
const std::optional<at::Tensor> head_mask,
|
||||
const float beta,
|
||||
const float inv_norm_factor,
|
||||
const int num_heads,
|
||||
const bool use_cache
|
||||
) {
|
||||
const auto batch_size = fused_qkv.size(0);
|
||||
const auto q_length = fused_qkv.size(1);
|
||||
const auto three_times_hidden_size = fused_qkv.size(2);
|
||||
const auto head_dim = three_times_hidden_size / (3 * num_heads);
|
||||
const auto batch_size_times_num_heads = batch_size * num_heads;
|
||||
|
||||
// `split_heads`
|
||||
const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
|
||||
const auto tensor_list = fused_qkv_view.split(head_dim, -1);
|
||||
const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
|
||||
auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
|
||||
auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
|
||||
|
||||
if (layer_past) {
|
||||
const auto past_key = (*layer_past).at(0);
|
||||
const auto past_value = (*layer_past).at(1);
|
||||
key_layer = at::cat({past_key, key_layer}, 2);
|
||||
value_layer = at::cat({past_value, value_layer}, 1);
|
||||
}
|
||||
|
||||
std::optional<std::vector<at::Tensor>> present;
|
||||
if (use_cache) {
|
||||
present = {key_layer, value_layer};
|
||||
} else {
|
||||
present = {};
|
||||
}
|
||||
|
||||
auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);
|
||||
|
||||
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
|
||||
at::Tensor attention_probs;
|
||||
if (true) {
|
||||
const auto kv_length = key_layer.size(2);
|
||||
|
||||
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
|
||||
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
|
||||
// Custom kernel
|
||||
attention_probs = at::empty_like(attention_scores_2d);
|
||||
|
||||
// Check that inputs and contiguous + cuda tensors
|
||||
CHECK_INPUT(attention_scores_2d);
|
||||
CHECK_INPUT(attention_mask_2d);
|
||||
|
||||
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
|
||||
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
/*
|
||||
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
|
||||
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
|
||||
* - SMs: 108
|
||||
* - TPCs: 56 (What's that?)
|
||||
* - Memory size: 40 GB
|
||||
* - L2 Cache size: 40960 KB (shared across all SMs)
|
||||
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
|
||||
* - Max Threads / SM: 2048
|
||||
* - Max Thread Blocks / SM: 32
|
||||
*/
|
||||
|
||||
/*
|
||||
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
|
||||
* with multiple threads as we need to `sync_threads` to run exponential sum.
|
||||
* We maximise the usage of threads within a single block
|
||||
*/
|
||||
// TODO @thomasw21 figure out everything warp related:
|
||||
// - why do they have to be power of 2
|
||||
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
|
||||
const auto MAX_THREADS_PER_SM = 1024;
|
||||
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
|
||||
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
|
||||
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
|
||||
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
|
||||
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
|
||||
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
|
||||
|
||||
const dim3 gridDim(num_blocks); // Number of blocks that run
|
||||
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
|
||||
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
|
||||
|
||||
// 192 * 2 ** 10
|
||||
// const auto MAX_L1_MEMORY = 196608;
|
||||
// const auto MAX_SMs = 108;
|
||||
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
|
||||
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
|
||||
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
|
||||
|
||||
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
|
||||
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
|
||||
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
effective_kv_length,
|
||||
blockDim,
|
||||
rows_per_block,
|
||||
kv_length,
|
||||
batch_size_times_num_heads * q_length
|
||||
);
|
||||
});
|
||||
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
|
||||
} else {
|
||||
// Pytorch C++ API
|
||||
auto input_dtype = attention_scores.scalar_type();
|
||||
if (input_dtype == at::ScalarType::Float) {
|
||||
attention_scores = attention_scores.to(at::ScalarType::Float);
|
||||
};
|
||||
// TODO @thomasw21 Figure out how to get minimum value
|
||||
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
|
||||
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
|
||||
}
|
||||
|
||||
auto context_layer = attention_probs.bmm(value_layer);
|
||||
|
||||
// `_merge_heads`
|
||||
context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
|
||||
context_layer = context_layer.permute({0, 2, 1, 3});
|
||||
context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});
|
||||
|
||||
return std::make_tuple(context_layer, present, attention_probs);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"forward",
|
||||
&forward,
|
||||
"Bloom attention mechanism forward (CUDA)"
|
||||
);
|
||||
}
|
@ -1,21 +0,0 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
extra_compile_args = ["-std=c++17"]
|
||||
|
||||
setup(
|
||||
name="custom_kernels",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="custom_kernels.fused_bloom_attention_cuda",
|
||||
sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
|
||||
extra_compile_args=extra_compile_args,
|
||||
),
|
||||
CUDAExtension(
|
||||
name="custom_kernels.fused_attention_cuda",
|
||||
sources=["custom_kernels/fused_attention_cuda.cu"],
|
||||
extra_compile_args=extra_compile_args,
|
||||
),
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
@ -1,58 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_compat_cuh
|
||||
#define _cuda_compat_cuh
|
||||
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,71 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#define _cuda_buffers_cu
|
||||
#include "cuda_buffers.cuh"
|
||||
|
||||
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
|
||||
// __constant__ half2 q4_table[16][256];
|
||||
// half2 q4_table_host[16][256];
|
||||
// bool q4_table_init = false;
|
||||
|
||||
CudaBuffers::CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
temp_state(_temp_state),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(_device);
|
||||
|
||||
cudaStreamCreate(&alt_stream_1);
|
||||
cudaStreamCreate(&alt_stream_2);
|
||||
cudaStreamCreate(&alt_stream_3);
|
||||
cudaEventCreate(&alt_stream_1_done);
|
||||
cudaEventCreate(&alt_stream_2_done);
|
||||
cudaEventCreate(&alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers::~CudaBuffers()
|
||||
{
|
||||
cudaStreamDestroy(alt_stream_1);
|
||||
cudaStreamDestroy(alt_stream_2);
|
||||
cudaStreamDestroy(alt_stream_3);
|
||||
cudaEventDestroy(alt_stream_1_done);
|
||||
cudaEventDestroy(alt_stream_2_done);
|
||||
cudaEventDestroy(alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index)
|
||||
{
|
||||
return g_buffers[device_index];
|
||||
}
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
)
|
||||
{
|
||||
CudaBuffers* buffers = new CudaBuffers
|
||||
(
|
||||
_device,
|
||||
_temp_state,
|
||||
_temp_dq
|
||||
);
|
||||
|
||||
g_buffers[_device] = buffers;
|
||||
}
|
||||
|
||||
void cleanup_buffers_cuda()
|
||||
{
|
||||
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
|
||||
{
|
||||
if (!g_buffers[i]) continue;
|
||||
delete g_buffers[i];
|
||||
g_buffers[i] = NULL;
|
||||
}
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_buffers_cuh
|
||||
#define _cuda_buffers_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
const int CUDA_MAX_DEVICES = 16;
|
||||
|
||||
// #ifndef _cuda_buffers_cu
|
||||
// extern __constant__ half2 q4_table[16][256];
|
||||
// #endif
|
||||
|
||||
class CudaBuffers
|
||||
{
|
||||
public:
|
||||
int device;
|
||||
|
||||
half* temp_state; // [max_hidden_rows * intermediate_size]
|
||||
half* temp_dq; // size of largest quant tensor * 8
|
||||
|
||||
cudaStream_t alt_stream_1;
|
||||
cudaStream_t alt_stream_2;
|
||||
cudaStream_t alt_stream_3;
|
||||
cudaEvent_t alt_stream_1_done;
|
||||
cudaEvent_t alt_stream_2_done;
|
||||
cudaEvent_t alt_stream_3_done;
|
||||
|
||||
CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
~CudaBuffers();
|
||||
};
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index);
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
void cleanup_buffers_cuda();
|
||||
|
||||
#endif
|
@ -1,61 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "column_remap.cuh"
|
||||
#include "../util.cuh"
|
||||
|
||||
const int SHUF_BLOCKSIZE_X = 256;
|
||||
const int SHUF_BLOCKSIZE_Y = 16;
|
||||
|
||||
__global__ void column_remap_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
half* __restrict__ x_new,
|
||||
const int x_width,
|
||||
const int x_height,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
|
||||
|
||||
int x_stride = x_width;
|
||||
int x_idx = x_row * x_stride + x_column;
|
||||
|
||||
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
|
||||
int x_idx_end = x_row_end * x_stride + x_column;
|
||||
|
||||
int s_column = x_map[x_column];
|
||||
int s_idx = x_row * x_stride + s_column;
|
||||
|
||||
while (x_idx < x_idx_end)
|
||||
{
|
||||
x_new[x_idx] = x[s_idx];
|
||||
x_idx += x_stride;
|
||||
s_idx += x_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Remap columns in x to correspond to sequential group index before matmul
|
||||
//
|
||||
// perform x -> seq_x such that seq_x @ seq_w == x @ w
|
||||
|
||||
void column_remap_cuda
|
||||
(
|
||||
const half* x,
|
||||
half* x_new,
|
||||
const int x_height,
|
||||
const int x_width,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
|
||||
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
|
||||
1
|
||||
);
|
||||
|
||||
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _column_remap_cuh
|
||||
#define _column_remap_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
|
||||
void column_remap_cuda
|
||||
(
|
||||
const half* x,
|
||||
half* x_new,
|
||||
const int x_height,
|
||||
const int x_width,
|
||||
const uint32_t* x_map
|
||||
);
|
||||
|
||||
#endif
|
@ -1,256 +0,0 @@
|
||||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "../util.cuh"
|
||||
#include "../matrix.cuh"
|
||||
#include "../cu_compat.cuh"
|
||||
#include "../cuda_buffers.cuh"
|
||||
#if defined(USE_ROCM)
|
||||
#include "../hip_compat.cuh"
|
||||
#endif
|
||||
|
||||
const int THREADS_X = 32; // Block size and thread count along columns in w and out
|
||||
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
|
||||
|
||||
typedef void (*fp_q4_matmul_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
half*,
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint32_t*,
|
||||
bool
|
||||
);
|
||||
|
||||
template<bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
__global__ void q4_matmul_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out,
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int dim,
|
||||
const int width,
|
||||
const int groupsize,
|
||||
const int block_size_z,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int x_column = block_size_z * blockIdx.z;
|
||||
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
|
||||
|
||||
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
|
||||
|
||||
int iterations = (x_column_end - x_column) / 8;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_half x_(x, height, dim);
|
||||
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
|
||||
MatrixView_q4_column w_(w, dim, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
|
||||
// Zero output
|
||||
|
||||
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Loop over part of x row (and w column)
|
||||
|
||||
half2 acc = {};
|
||||
half acc_h = {};
|
||||
|
||||
if constexpr (use_groupsize)
|
||||
{
|
||||
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
|
||||
// could be slightly faster
|
||||
|
||||
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
else
|
||||
{
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
|
||||
|
||||
for (int k = x_column; k < x_column + iterations * 8; k += 8)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add to block result
|
||||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half result = __hadd(__low2half(acc), __high2half(acc));
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), result);
|
||||
}
|
||||
else
|
||||
{
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
|
||||
}
|
||||
}
|
||||
|
||||
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
|
||||
{
|
||||
// <bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
if (tuningParams->matmul_no_half2) {
|
||||
if (block_size_z % groupsize == 0) {
|
||||
if (x_map) return q4_matmul_kernel<false, true, true >;
|
||||
else return q4_matmul_kernel<false, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<false, false, true >;
|
||||
else return q4_matmul_kernel<false, false, false>;
|
||||
}
|
||||
} else {
|
||||
if (block_size_z % groupsize == 0)
|
||||
{
|
||||
if (x_map) return q4_matmul_kernel<true, true, true >;
|
||||
else return q4_matmul_kernel<true, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<true, false, true >;
|
||||
else return q4_matmul_kernel<true, false, false>;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compute y = x @ w
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
cudaStream_t alt_stream
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
|
||||
uint32_t* x_map = w->cuda_x_map;
|
||||
const half* x_mapped = x;
|
||||
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
|
||||
{
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
x_map = NULL;
|
||||
}
|
||||
|
||||
int block_size_z;
|
||||
if (w->width == 4096) block_size_z = 384; // 7B
|
||||
else if (w->width == 11008) block_size_z = 256;
|
||||
else if (w->width == 5120) block_size_z = 384; // 13B
|
||||
else if (w->width == 13824) block_size_z = 256;
|
||||
else if (w->width == 6656) block_size_z = 256; // 33B
|
||||
else if (w->width == 17920) block_size_z = 128;
|
||||
else block_size_z = 256;
|
||||
|
||||
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
|
||||
|
||||
dim3 threads(THREADS_X, THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height + threads.y - 1) / threads.y,
|
||||
(dim + block_size_z - 1) / block_size_z
|
||||
);
|
||||
|
||||
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
|
||||
|
||||
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
|
||||
}
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
const cublasHandle_t handle
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
|
||||
const half* x_mapped = x;
|
||||
if (w->cuda_x_map)
|
||||
{
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
}
|
||||
|
||||
w->reconstruct(buffers->temp_dq);
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
|
||||
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
|
||||
|
||||
// const float alpha = 1.0f;
|
||||
// const float beta = no_zero ? 1.0f : 0.0f;
|
||||
// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
|
||||
// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _q4_matmul_cuh
|
||||
#define _q4_matmul_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include "../tuning.h"
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
cudaStream_t alt_stream
|
||||
);
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
const cublasHandle_t handle
|
||||
);
|
||||
|
||||
#endif
|
@ -1,220 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "q4_matrix.cuh"
|
||||
#include <vector>
|
||||
#include "../util.cuh"
|
||||
#include "../matrix.cuh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
const int UNSHUF_BLOCKSIZE_X = 64;
|
||||
|
||||
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
|
||||
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
|
||||
|
||||
vector<Q4Matrix*> g_q4_matrices;
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m)
|
||||
{
|
||||
g_q4_matrices.push_back(m);
|
||||
}
|
||||
|
||||
void g_q4_free_matrices()
|
||||
{
|
||||
for (const auto& m : g_q4_matrices) delete m;
|
||||
g_q4_matrices.clear();
|
||||
}
|
||||
|
||||
Q4Matrix::Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
) :
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
device(_device)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
cuda_qweight = _qweight;
|
||||
cuda_qzeros = _qzeros;
|
||||
cuda_scales = _scales;
|
||||
|
||||
groupsize = height / groups;
|
||||
|
||||
if (_g_idx) make_sequential(_g_idx);
|
||||
}
|
||||
|
||||
Q4Matrix::~Q4Matrix()
|
||||
{
|
||||
}
|
||||
|
||||
// Make sequential
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int x_map_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = x_map[x_map_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
|
||||
dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
}
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out, // (y)
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int width,
|
||||
const int groupsize
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_q4_column w_(w, height, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
MatrixView_half w_scales_(w_scales, height / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
|
||||
|
||||
// Groupsize version
|
||||
|
||||
int group = row / groupsize;
|
||||
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F;
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < 32; s += 4)
|
||||
{
|
||||
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
|
||||
*out_ptr = w_item; out_ptr += out_.width;
|
||||
}
|
||||
}
|
||||
|
||||
void Q4Matrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height / 8 + threads.y - 1) / threads.y,
|
||||
1
|
||||
);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
||||
}
|
@ -1,53 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _q4_matrix_cuh
|
||||
#define _q4_matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
|
||||
class Q4Matrix
|
||||
{
|
||||
public:
|
||||
|
||||
int device;
|
||||
|
||||
int height;
|
||||
int width;
|
||||
int groups;
|
||||
int groupsize;
|
||||
|
||||
uint32_t* cuda_qweight = NULL;
|
||||
uint32_t* cuda_qzeros = NULL;
|
||||
half* cuda_scales = NULL;
|
||||
uint32_t* cuda_x_map = NULL;
|
||||
|
||||
Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
);
|
||||
|
||||
~Q4Matrix();
|
||||
|
||||
void reconstruct(half* out);
|
||||
|
||||
private:
|
||||
|
||||
void make_sequential(const uint32_t* cpu_g_idx);
|
||||
|
||||
};
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m);
|
||||
void g_q4_free_matrices();
|
||||
|
||||
#endif
|
@ -1,253 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include "util.cuh"
|
||||
#include "tuning.h"
|
||||
#include "cuda_buffers.cuh"
|
||||
#include "cuda_func/q4_matrix.cuh"
|
||||
#include "cuda_func/q4_matmul.cuh"
|
||||
#include "cuda_func/column_remap.cuh"
|
||||
|
||||
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
|
||||
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
|
||||
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
|
||||
|
||||
void check_cuda(cudaError_t ret)
|
||||
{
|
||||
switch (ret)
|
||||
{
|
||||
case cudaSuccess:
|
||||
break;
|
||||
|
||||
case cudaUnspecified:
|
||||
printf(" **** Unspecified error\n");
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
|
||||
default:
|
||||
printf(" **** CUDA error\n"); \
|
||||
printf(" **** %s\n", cudaGetErrorString(ret)); \
|
||||
TORCH_CHECK(false, "CUDA error"); \
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define STRINGIFY_(__x) #__x
|
||||
#define STRINGIFY(__x) STRINGIFY_(__x)
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
|
||||
|
||||
#define TORCH_CHECK_DEVICE_INDEX(__index) \
|
||||
do { \
|
||||
TORCH_CHECK(__index >= 0, "no device index"); \
|
||||
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
|
||||
} while(0)
|
||||
|
||||
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
|
||||
do { \
|
||||
TORCH_CHECK_DTYPE(__w, kInt); \
|
||||
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
|
||||
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
|
||||
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
|
||||
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
|
||||
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
|
||||
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
|
||||
} while(0)
|
||||
|
||||
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
|
||||
{
|
||||
int groupsize = w.size(0) * 8 / w_zeros.size(0);
|
||||
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
|
||||
return groupsize;
|
||||
}
|
||||
|
||||
|
||||
// Tuning parameters
|
||||
|
||||
ExLlamaTuning tuningParams;
|
||||
|
||||
void set_tuning_params
|
||||
(
|
||||
int matmul_recons_thd,
|
||||
bool matmul_fused_remap,
|
||||
bool matmul_no_half2
|
||||
)
|
||||
{
|
||||
tuningParams.matmul_recons_thd = matmul_recons_thd;
|
||||
tuningParams.matmul_fused_remap = matmul_fused_remap;
|
||||
tuningParams.matmul_no_half2 = matmul_no_half2;
|
||||
}
|
||||
|
||||
|
||||
// Release all unmanaged objects allocated by the extension
|
||||
|
||||
void cleanup()
|
||||
{
|
||||
cleanup_buffers_cuda();
|
||||
g_q4_free_matrices();
|
||||
}
|
||||
|
||||
|
||||
// Prepare buffers for forward pass
|
||||
|
||||
void prepare_buffers
|
||||
(
|
||||
torch::Device device,
|
||||
torch::Tensor temp_state,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
int device_index = device.index();
|
||||
TORCH_CHECK_DEVICE_INDEX(device_index);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
|
||||
prepare_buffers_cuda
|
||||
(
|
||||
device_index,
|
||||
(half*) temp_state.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Create Q4Matrix, return handle
|
||||
|
||||
uintptr_t make_q4
|
||||
(
|
||||
torch::Tensor qweight,
|
||||
torch::Tensor qzeros,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor g_idx,
|
||||
int device
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(qweight, kInt);
|
||||
TORCH_CHECK_DTYPE(qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE(scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
|
||||
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
|
||||
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
|
||||
|
||||
int width = qweight.size(1);
|
||||
int height = qweight.size(0) * 8;
|
||||
int groups = qzeros.size(0);
|
||||
|
||||
Q4Matrix* m = new Q4Matrix
|
||||
(
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
|
||||
(uint32_t*) qweight.data_ptr(),
|
||||
(uint32_t*) qzeros.data_ptr(),
|
||||
(half*) scales.data_ptr(),
|
||||
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
|
||||
|
||||
device
|
||||
);
|
||||
|
||||
g_q4_keep_matrix(m);
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
|
||||
// Matmul half @ quant -> half
|
||||
|
||||
void q4_matmul
|
||||
(
|
||||
torch::Tensor x,
|
||||
uintptr_t w,
|
||||
torch::Tensor out
|
||||
)
|
||||
{
|
||||
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
|
||||
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(out, kHalf);
|
||||
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
|
||||
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
int x_height = x.size(0);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
||||
{
|
||||
q4_matmul_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr(),
|
||||
false,
|
||||
stream
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
q4_matmul_recons_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr(),
|
||||
false,
|
||||
at::cuda::getCurrentCUDABlasHandle()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Remap columns in half tensor
|
||||
|
||||
void column_remap
|
||||
(
|
||||
torch::Tensor x,
|
||||
torch::Tensor x_new,
|
||||
torch::Tensor x_map
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_new, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_map, kInt);
|
||||
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
|
||||
|
||||
int height = x.size(0);
|
||||
int width = x.size(1);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
column_remap_cuda
|
||||
(
|
||||
(half*) x.data_ptr(),
|
||||
(half*) x_new.data_ptr(),
|
||||
height,
|
||||
width,
|
||||
(uint32_t*) x_map.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
|
||||
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
|
||||
m.def("cleanup", &cleanup, "cleanup");
|
||||
m.def("make_q4", &make_q4, "make_q4");
|
||||
m.def("q4_matmul", &q4_matmul, "q4_matmul");
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _hip_compat_cuh
|
||||
#define _hip_compat_cuh
|
||||
|
||||
// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
|
||||
__device__ __forceinline__ __half __compat_hrcp(__half x) {
|
||||
return __half_raw{
|
||||
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
|
||||
}
|
||||
|
||||
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
|
||||
return _Float16_2{
|
||||
_Float16_2{static_cast<_Float16>(1.0f),
|
||||
static_cast<_Float16>(1.0f)} / x.data};
|
||||
}
|
||||
|
||||
#define hrcp __compat_hrcp
|
||||
#define h2rcp __compat_h2rcp
|
||||
|
||||
// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
|
||||
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||
hipblasOperation_t transA,
|
||||
hipblasOperation_t transB,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
const half* alpha,
|
||||
const half* AP,
|
||||
int lda,
|
||||
const half* BP,
|
||||
int ldb,
|
||||
const half* beta,
|
||||
half* CP,
|
||||
int ldc) {
|
||||
return hipblasHgemm(handle, transA, transB, m, n, k,
|
||||
reinterpret_cast<const hipblasHalf *>(alpha),
|
||||
reinterpret_cast<const hipblasHalf *>(AP), lda,
|
||||
reinterpret_cast<const hipblasHalf *>(BP), ldb,
|
||||
reinterpret_cast<const hipblasHalf *>(beta),
|
||||
reinterpret_cast<hipblasHalf *>(CP), ldc);
|
||||
}
|
||||
#define hipblasHgemm __compat_hipblasHgemm
|
||||
|
||||
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
|
||||
#define rocblas_handle hipblasHandle_t
|
||||
#define rocblas_operation_none HIPBLAS_OP_N
|
||||
#define rocblas_get_stream hipblasGetStream
|
||||
#define rocblas_set_stream hipblasSetStream
|
||||
#define rocblas_hgemm __compat_hipblasHgemm
|
||||
|
||||
#endif
|
@ -1,294 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _matrix_cuh
|
||||
#define _matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
};
|
||||
|
||||
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
|
||||
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
|
||||
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
|
||||
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
|
||||
|
||||
half2 tmp = __hmul2(*h_ptr++, v_01);
|
||||
tmp = __hfma2(*h_ptr++, v_23, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_45, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(*h_ptr++, v_0);
|
||||
tmp = __hfma(*h_ptr++, v_1, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_2, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_3, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_4, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_5, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_6, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8_x_map
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
half h_0 = h_ptr[*x_map_ptr++];
|
||||
half h_1 = h_ptr[*x_map_ptr++];
|
||||
half h_2 = h_ptr[*x_map_ptr++];
|
||||
half h_3 = h_ptr[*x_map_ptr++];
|
||||
half h_4 = h_ptr[*x_map_ptr++];
|
||||
half h_5 = h_ptr[*x_map_ptr++];
|
||||
half h_6 = h_ptr[*x_map_ptr++];
|
||||
half h_7 = h_ptr[*x_map_ptr++];
|
||||
|
||||
half2 h_01 = __halves2half2(h_0, h_1);
|
||||
half2 h_23 = __halves2half2(h_2, h_3);
|
||||
half2 h_45 = __halves2half2(h_4, h_5);
|
||||
half2 h_67 = __halves2half2(h_6, h_7);
|
||||
|
||||
half2 tmp = __hmul2(h_01, v_01);
|
||||
tmp = __hfma2(h_23, v_23, tmp);
|
||||
tmp = __hfma2(h_45, v_45, tmp);
|
||||
tmp = __hfma2(h_67, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_x_map_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
@ -1,13 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _tuning_h
|
||||
#define _tuning_h
|
||||
|
||||
struct ExLlamaTuning
|
||||
{
|
||||
int matmul_recons_thd;
|
||||
bool matmul_fused_remap;
|
||||
bool matmul_no_half2;
|
||||
};
|
||||
|
||||
#endif
|
@ -1,33 +0,0 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _util_cuh
|
||||
#define _util_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define cudaUnspecified hipErrorUnknown
|
||||
#else
|
||||
#define cudaUnspecified cudaErrorApiFailureBase
|
||||
#endif
|
||||
|
||||
// React to failure on return code != cudaSuccess
|
||||
|
||||
#define _cuda_check(fn) \
|
||||
do { \
|
||||
{_cuda_err = fn;} \
|
||||
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
|
||||
} while(false)
|
||||
|
||||
// React to failure on return code == 0
|
||||
|
||||
#define _alloc_check(fn) \
|
||||
do { \
|
||||
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
|
||||
else _cuda_err = cudaSuccess; \
|
||||
} while(false)
|
||||
|
||||
#endif
|
@ -1,32 +0,0 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
import torch
|
||||
|
||||
extra_cuda_cflags = []
|
||||
extra_cflags = []
|
||||
if torch.version.hip:
|
||||
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": extra_cflags,
|
||||
"nvcc": extra_cuda_cflags,
|
||||
}
|
||||
|
||||
setup(
|
||||
name="exllama_kernels",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="exllama_kernels",
|
||||
sources=[
|
||||
"exllama_kernels/exllama_ext.cpp",
|
||||
"exllama_kernels/cuda_buffers.cu",
|
||||
"exllama_kernels/cuda_func/column_remap.cu",
|
||||
"exllama_kernels/cuda_func/q4_matmul.cu",
|
||||
"exllama_kernels/cuda_func/q4_matrix.cu",
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
@ -1,15 +0,0 @@
|
||||
#ifndef _config_h
|
||||
#define _config_h
|
||||
|
||||
#define MAX_Q_GEMM_ROWS 50
|
||||
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
|
||||
|
||||
#define QMODE_2BIT 1
|
||||
#define QMODE_3BIT 1
|
||||
#define QMODE_4BIT 1
|
||||
#define QMODE_5BIT 1
|
||||
#define QMODE_6BIT 0
|
||||
#define QMODE_8BIT 0
|
||||
|
||||
|
||||
#endif
|
@ -1,12 +0,0 @@
|
||||
#ifndef _util_h
|
||||
#define _util_h
|
||||
|
||||
#define DBGS(__x) printf("%s\n", __x)
|
||||
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
|
||||
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
|
||||
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
|
||||
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
|
||||
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
|
||||
#endif
|
@ -1,56 +0,0 @@
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,121 +0,0 @@
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "quant/qdq_util.cuh"
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||
{
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
@ -1,220 +0,0 @@
|
||||
#include "q_gemm.cuh"
|
||||
#include "util.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "../config.h"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define GPTQ_BLOCK_KN_SIZE 128
|
||||
#define GPTQ_BLOCK_M_SIZE_MAX 8
|
||||
#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
|
||||
|
||||
#define EXL2_BLOCK_KN_SIZE 64
|
||||
#define EXL2_BLOCK_M_SIZE_MAX 8
|
||||
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
|
||||
|
||||
#define CLEAR_N_SIZE 256
|
||||
|
||||
#include "q_gemm_kernel.cuh"
|
||||
#include "q_gemm_kernel_gptq.cuh"
|
||||
|
||||
void gemm_half_q_half_cuda_part
|
||||
(
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
int m_count,
|
||||
bool clear,
|
||||
const half* r_weights,
|
||||
int r_weights_stride,
|
||||
bool mul_r_weights
|
||||
)
|
||||
{
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (!b->is_gptq)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = EXL2_BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_q_scale,
|
||||
b->cuda_q_scale_max,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->cuda_q_group_map,
|
||||
b->cuda_q_perm,
|
||||
b->rows_8,
|
||||
b->rows_6,
|
||||
b->rows_5,
|
||||
b->rows_4,
|
||||
b->rows_3,
|
||||
b->rows_2,
|
||||
clear,
|
||||
r_weights,
|
||||
r_weights_stride
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = GPTQ_BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
|
||||
|
||||
// DBGX((uint64_t) r_weights);
|
||||
// if (r_weights)
|
||||
// print_global_mem(r_weights, 1, 1, 1);
|
||||
// DBGI(r_weights_stride);
|
||||
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_gptq_qzeros,
|
||||
b->cuda_gptq_scales,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->gptq_groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_4,
|
||||
clear,
|
||||
r_weights,
|
||||
r_weights_stride
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
bool clear,
|
||||
half* temp_dq,
|
||||
bool force_cuda,
|
||||
const half* r_weights,
|
||||
const int r_weights_stride,
|
||||
bool mul_r_weights
|
||||
)
|
||||
{
|
||||
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||
{
|
||||
// Reconstruct FP16 matrix, then cuBLAS
|
||||
|
||||
if (!temp_dq) temp_dq = b->temp_dq;
|
||||
b->reconstruct(temp_dq);
|
||||
|
||||
//cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
|
||||
cublasHgemm(cublas_handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
size_n, size_m, size_k,
|
||||
&alpha, temp_dq, size_n,
|
||||
a, size_k,
|
||||
&beta, c, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasSgemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N,
|
||||
// CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasGemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n,
|
||||
// CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Quantized matmul
|
||||
|
||||
int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
|
||||
int max_chunks = size_m / block_m_size_max;
|
||||
int last_chunk = max_chunks * block_m_size_max;
|
||||
int last_chunk_size = size_m - last_chunk;
|
||||
|
||||
if (max_chunks)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
|
||||
}
|
||||
|
||||
if (last_chunk_size)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void clear_kernel
|
||||
(
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n
|
||||
)
|
||||
{
|
||||
int m = blockIdx.y;
|
||||
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
|
||||
if (n >= size_n) return;
|
||||
int4* c_ptr = (int4*)(c + m * size_n + n);
|
||||
*c_ptr = {};
|
||||
}
|
||||
|
||||
void clear_tensor_cuda
|
||||
(
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n
|
||||
)
|
||||
{
|
||||
// dim3 blockDim, gridDim;
|
||||
// blockDim.x = CLEAR_N_SIZE;
|
||||
// blockDim.y = 1;
|
||||
// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||
// gridDim.y = size_m;
|
||||
// clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||
}
|
@ -1,36 +0,0 @@
|
||||
#ifndef _q_gemm_cuh
|
||||
#define _q_gemm_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "q_matrix.cuh"
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
bool clear = false,
|
||||
half* reconstruct = NULL,
|
||||
bool force_cuda = false,
|
||||
const half* r_weights = NULL,
|
||||
const int r_weights_stride = 0,
|
||||
bool mul_r_weights = false
|
||||
);
|
||||
|
||||
void clear_tensor_cuda
|
||||
(
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n
|
||||
);
|
||||
|
||||
#endif
|
@ -1,580 +0,0 @@
|
||||
#include "compat.cuh"
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
|
||||
{
|
||||
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
|
||||
|
||||
float result = {};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
half2 w01 = dq[i];
|
||||
float w0 = __low2float(w01);
|
||||
float w1 = __high2float(w01);
|
||||
float x0 = __half2float(*a_ptr++);
|
||||
float x1 = __half2float(*a_ptr++);
|
||||
result = fma(w0, x0, result);
|
||||
result = fma(w1, x1, result);
|
||||
}
|
||||
float qs = __half2float(qs_h);
|
||||
result *= qs;
|
||||
half result_h = __float2half_rn(result);
|
||||
return __hadd(result_h, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||
return __hfma(result_h, qs_h, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
half result_h = __hadd(__low2half(result), __high2half(result));
|
||||
return __hfma(result_h, qs_h, g_result);
|
||||
}
|
||||
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const bool,
|
||||
const half*,
|
||||
const int
|
||||
);
|
||||
|
||||
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||
__global__ void gemm_half_q_half_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const uint16_t* __restrict__ b_q_group_map,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2,
|
||||
const bool clear,
|
||||
const half* r_weights,
|
||||
const int r_weights_stride
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Read weights
|
||||
|
||||
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||
if constexpr (use_r_weights)
|
||||
{
|
||||
uint16_t any_w = 0;
|
||||
const half* w_ptr = r_weights;
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
weights[m].as_half = *w_ptr;
|
||||
w_ptr += r_weights_stride;
|
||||
any_w |= weights[m].as_uint16;
|
||||
}
|
||||
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||
}
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
// half a0 = a_ptr[offset_k + t];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
//int group = offset_k / groupsize;
|
||||
int group = b_q_group_map[offset_k * 2];
|
||||
|
||||
// if (offset_m == 0 && t == 0)
|
||||
// DBGI2(offset_k, group);
|
||||
|
||||
// Preload scales
|
||||
|
||||
half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
|
||||
|
||||
//int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||
int temp_k = offset_k;
|
||||
for (int g = 0; temp_k < end_k; g++)
|
||||
{
|
||||
int qscales[4];
|
||||
b_q_scale_.item4(qscales, group + g, n);
|
||||
qscales[0]++;
|
||||
qscales[1]++;
|
||||
qscales[2]++;
|
||||
qscales[3]++;
|
||||
half maxscale = b_q_scale_max[group + g];
|
||||
scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
|
||||
scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
|
||||
scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
|
||||
scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
|
||||
temp_k += b_q_group_map[temp_k * 2 + 1];
|
||||
}
|
||||
|
||||
// a, b offset
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = EXL2_BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int scales_idx = 0;
|
||||
half qs_h0 = scales[scales_idx][0];
|
||||
half qs_h1 = scales[scales_idx][1];
|
||||
half qs_h2 = scales[scales_idx][2];
|
||||
half qs_h3 = scales[scales_idx][3];
|
||||
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||
|
||||
// Column result
|
||||
|
||||
half block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize groups
|
||||
|
||||
int k = offset_k;
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[2];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
|
||||
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
|
||||
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
|
||||
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[5];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
|
||||
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
|
||||
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
|
||||
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
|
||||
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
|
||||
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
|
||||
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_h0 = scales[scales_idx][0];
|
||||
qs_h1 = scales[scales_idx][1];
|
||||
qs_h2 = scales[scales_idx][2];
|
||||
qs_h3 = scales[scales_idx][3];
|
||||
nextgroup += b_q_group_map[k * 2 + 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
|
||||
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
|
||||
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
|
||||
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
|
||||
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
|
||||
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
|
||||
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
|
||||
}
|
||||
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 16;
|
||||
}
|
||||
|
||||
// Accumulate column sums in c
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
|
||||
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
|
||||
|
||||
if constexpr (mul_r_weights)
|
||||
{
|
||||
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||
result01 = __hmul2(result01, w_mul2);
|
||||
result23 = __hmul2(result23, w_mul2);
|
||||
}
|
||||
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
// *out = result01;
|
||||
// *(out + 1) = result23;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool use_r_weights, bool mul_r_weights>
|
||||
struct map_m_count_exl2 {
|
||||
static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
|
||||
{
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if EXL2_BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
||||
};
|
||||
|
||||
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||
{
|
||||
if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||
if (!r_weights && mul_r_weights) return map_m_count_exl2<false, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||
if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);
|
||||
if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count);
|
||||
return NULL;
|
||||
}
|
@ -1,273 +0,0 @@
|
||||
#include "compat.cuh"
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hadd2(result, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const bool,
|
||||
const half*,
|
||||
const int
|
||||
);
|
||||
|
||||
template <int m_count, bool use_r_weights, bool mul_r_weights>
|
||||
__global__ void gemm_half_q_half_gptq_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_4,
|
||||
const bool clear,
|
||||
const half* r_weights,
|
||||
const int r_weights_stride
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Read weights
|
||||
|
||||
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
|
||||
if constexpr (use_r_weights)
|
||||
{
|
||||
uint16_t any_w = 0;
|
||||
const half* w_ptr = r_weights;
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
weights[m].as_half = *w_ptr;
|
||||
w_ptr += r_weights_stride;
|
||||
any_w |= weights[m].as_uint16;
|
||||
}
|
||||
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
|
||||
}
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
|
||||
half a0;
|
||||
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
else a0 = a_ptr[offset_k + t];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Zero output
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// a, b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = GPTQ_BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int zeros[4];
|
||||
half2 scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
|
||||
// __syncthreads();
|
||||
|
||||
// Column result
|
||||
|
||||
half2 block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize and multiply
|
||||
|
||||
int k = offset_k;
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
|
||||
block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||
block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||
block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||
}
|
||||
|
||||
b_ptr += size_n;
|
||||
a_ptr += 8;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||
half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
|
||||
half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
|
||||
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
|
||||
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
|
||||
half2 result01 = __halves2half2(result0, result1);
|
||||
half2 result23 = __halves2half2(result2, result3);
|
||||
|
||||
if constexpr (mul_r_weights)
|
||||
{
|
||||
half2 w_mul2 = __half2half2(weights[m].as_half);
|
||||
result01 = __hmul2(result01, w_mul2);
|
||||
result23 = __hmul2(result23, w_mul2);
|
||||
}
|
||||
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool use_r_weights, bool mul_r_weights>
|
||||
struct map_m_count_gptq {
|
||||
static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
|
||||
{
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
#if GPTQ_BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
||||
};
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
|
||||
{
|
||||
if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
if (!r_weights && mul_r_weights) return map_m_count_gptq<false, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
|
||||
return NULL;
|
||||
}
|
@ -1,650 +0,0 @@
|
||||
#include "q_matrix.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "util.cuh"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
|
||||
#define THREADS_X 32
|
||||
#define THREADS_Y 32
|
||||
|
||||
// Shuffle quantized data on load
|
||||
|
||||
__global__ void shuffle_kernel
|
||||
(
|
||||
uint32_t* __restrict__ b_q_weight,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2
|
||||
)
|
||||
{
|
||||
int n = blockIdx.x * THREADS_X + threadIdx.x;
|
||||
if (n >= size_n) return;
|
||||
int k = 0;
|
||||
uint32_t* b_ptr = b_q_weight + n;
|
||||
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
|
||||
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
|
||||
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
|
||||
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
|
||||
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
|
||||
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
|
||||
}
|
||||
|
||||
|
||||
// QMatrix constructor
|
||||
|
||||
QMatrix::QMatrix
|
||||
(
|
||||
const int _device,
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _q_weight,
|
||||
uint16_t* _q_perm,
|
||||
uint16_t* _q_invperm,
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
uint16_t* _q_group_map,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
failed = false;
|
||||
|
||||
cuda_q_weight = _q_weight;
|
||||
cuda_q_perm = _q_perm;
|
||||
cuda_q_invperm = _q_invperm;
|
||||
cuda_q_scale = _q_scale;
|
||||
cuda_q_scale_max = _q_scale_max;
|
||||
cuda_q_groups = _q_groups;
|
||||
cuda_q_group_map = _q_group_map;
|
||||
cuda_gptq_qzeros = _gptq_qzeros;
|
||||
cuda_gptq_scales = _gptq_scales;
|
||||
|
||||
is_gptq = (_gptq_qzeros != NULL);
|
||||
|
||||
if (is_gptq)
|
||||
{
|
||||
gptq_groupsize = 1;
|
||||
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
|
||||
}
|
||||
|
||||
// Create group map
|
||||
|
||||
rows_8 = 0;
|
||||
rows_6 = 0;
|
||||
rows_5 = 0;
|
||||
rows_4 = 0;
|
||||
rows_3 = 0;
|
||||
rows_2 = 0;
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
int row = 0;
|
||||
for (int i = 0; i < groups; i++)
|
||||
{
|
||||
int bits = cpu_q_groups[i * 2];
|
||||
|
||||
int rows;
|
||||
if (i < groups - 1)
|
||||
{
|
||||
int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
|
||||
rows = qrows * 32 / bits;
|
||||
}
|
||||
else rows = height - row;
|
||||
|
||||
if (bits == 8) rows_8 += rows;
|
||||
if (bits == 6) rows_6 += rows;
|
||||
if (bits == 5) rows_5 += rows;
|
||||
if (bits == 4) rows_4 += rows;
|
||||
if (bits == 3) rows_3 += rows;
|
||||
if (bits == 2) rows_2 += rows;
|
||||
row += rows;
|
||||
}
|
||||
|
||||
free(cpu_q_groups);
|
||||
|
||||
rows_6 += rows_8;
|
||||
rows_5 += rows_6;
|
||||
rows_4 += rows_5;
|
||||
rows_3 += rows_4;
|
||||
rows_2 += rows_3;
|
||||
}
|
||||
else
|
||||
{
|
||||
rows_4 = height;
|
||||
rows_3 = height;
|
||||
rows_2 = height;
|
||||
|
||||
if (_gptq_g_idx)
|
||||
{
|
||||
if (!make_sequential(_gptq_g_idx))
|
||||
{
|
||||
failed = true;
|
||||
//printf("FAIL\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DBGI(rows_8);
|
||||
// DBGI(rows_6);
|
||||
// DBGI(rows_5);
|
||||
// DBGI(rows_4);
|
||||
// DBGI(rows_3);
|
||||
// DBGI(rows_2);
|
||||
|
||||
// Shuffle quantized data
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = 1;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
||||
}
|
||||
|
||||
QMatrix::~QMatrix()
|
||||
{
|
||||
}
|
||||
|
||||
// Reconstruct b[k,n] (GPTQ)
|
||||
|
||||
__global__ void reconstruct_gptq_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_4
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
int t = threadIdx.x;
|
||||
|
||||
if (b_q_perm)
|
||||
{
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
}
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
// Initial zeros/scale
|
||||
|
||||
int zeros[4];
|
||||
half2 scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4][4];
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
b_ptr += size_n;
|
||||
//half* dqh = (half*)dq;
|
||||
if (b_q_perm)
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct b[k,n]
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
const uint16_t* __restrict__ b_q_group_map,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
//const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
int t = threadIdx.x;
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
// int group = offset_k / groupsize;
|
||||
int group = b_q_group_map[offset_k * 2];
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||
int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
dequant_8bit_8(q_0, q_1, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_3 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_4 = *b_ptr; b_ptr += size_n;
|
||||
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_4bit_8(q_0, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_2bit_16(q_0, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 16;
|
||||
}
|
||||
}
|
||||
|
||||
void QMatrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_q_scale,
|
||||
cuda_q_scale_max,
|
||||
cuda_q_group_map,
|
||||
height,
|
||||
width,
|
||||
//groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_8,
|
||||
rows_6,
|
||||
rows_5,
|
||||
rows_4,
|
||||
rows_3,
|
||||
rows_2
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
|
||||
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_gptq_qzeros,
|
||||
cuda_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
height,
|
||||
width,
|
||||
gptq_groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_4
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint16_t* __restrict__ q_perm,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
if (w2_column >= w2_stride) return;
|
||||
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int q_perm_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = q_perm[q_perm_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
if (err != cudaSuccess) {
|
||||
cudaError_t cuda_status = cudaGetLastError(); // Clear error
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Reduce to uint16_t
|
||||
|
||||
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
|
||||
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
|
||||
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
|
||||
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = height / 8;
|
||||
|
||||
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_new_qweight,
|
||||
cuda_q_perm,
|
||||
height / 8,
|
||||
width
|
||||
);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
|
||||
return true;
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
#ifndef _q_matrix_cuh
|
||||
#define _q_matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#define MAX_SUPERGROUPS 16
|
||||
|
||||
class QMatrix
|
||||
{
|
||||
public:
|
||||
|
||||
int device;
|
||||
bool is_gptq;
|
||||
|
||||
int height;
|
||||
int width;
|
||||
int groups;
|
||||
int gptq_groupsize;
|
||||
|
||||
int rows_8;
|
||||
int rows_6;
|
||||
int rows_5;
|
||||
int rows_4;
|
||||
int rows_3;
|
||||
int rows_2;
|
||||
|
||||
uint32_t* cuda_q_weight = NULL;
|
||||
uint16_t* cuda_q_perm = NULL;
|
||||
uint16_t* cuda_q_invperm = NULL;
|
||||
uint32_t* cuda_q_scale = NULL;
|
||||
half* cuda_q_scale_max = NULL;
|
||||
uint16_t* cuda_q_groups = NULL;
|
||||
uint16_t* cuda_q_group_map = NULL;
|
||||
uint32_t* cuda_gptq_qzeros = NULL;
|
||||
half* cuda_gptq_scales = NULL;
|
||||
|
||||
half* temp_dq;
|
||||
|
||||
bool failed;
|
||||
|
||||
QMatrix
|
||||
(
|
||||
const int _device,
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _q_weight,
|
||||
uint16_t* _q_perm,
|
||||
uint16_t* _q_invperm,
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
uint16_t* _q_group_map,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
~QMatrix();
|
||||
|
||||
void reconstruct(half* out);
|
||||
bool make_sequential(const uint32_t* cpu_g_idx);
|
||||
|
||||
private:
|
||||
|
||||
};
|
||||
|
||||
#endif
|
@ -1,103 +0,0 @@
|
||||
#ifndef _qdq_2_cuh
|
||||
#define _qdq_2_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_2BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// ffddbb99 77553311 eeccaa88 66442200
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x03;
|
||||
uint32_t qa1 = (qa & 0x0c) >> 2;
|
||||
qa >>= 4;
|
||||
qb |= (qa1 << (i * 2 + 16));
|
||||
qb |= (qa0 << (i * 2));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[8],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y4_ = __float2half_rn(1.0f / 4.0f);
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y4 = __halves2half2(y4_, y4_);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 2.0f);
|
||||
const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
|
||||
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
|
||||
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z4 = __halves2half2(z4_, z4_);
|
||||
const half2 z16 = __halves2half2(z16_, z16_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
|
||||
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
|
||||
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
|
||||
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
|
||||
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
|
||||
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y4, z4);
|
||||
dq[2] = __hfma2(q2.as_half2, y16, z16);
|
||||
dq[3] = __hfma2(q3.as_half2, y64, z64);
|
||||
dq[4] = __hadd2(q4.as_half2, z1);
|
||||
dq[5] = __hfma2(q5.as_half2, y4, z4);
|
||||
dq[6] = __hfma2(q6.as_half2, y16, z16);
|
||||
dq[7] = __hfma2(q7.as_half2, y64, z64);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[8],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[16];
|
||||
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
|
||||
|
||||
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,169 +0,0 @@
|
||||
#ifndef _qdq_3_cuh
|
||||
#define _qdq_3_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_3BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
|
||||
// qa: aa999888 77766655 54443332 22111000
|
||||
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||
|
||||
uint32_t qd = qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: ..999888 77766655 54443332 22111000
|
||||
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
|
||||
|
||||
// za: 9997775 55333111 8886664 44222000
|
||||
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
za |= ((qd & 0x01) >> 0) << 15;
|
||||
zb |= ((qd & 0x02) >> 1) << 15;
|
||||
zc |= ((qd & 0x04) >> 2) << 15;
|
||||
za |= ((qd & 0x08) >> 3) << 31;
|
||||
zb |= ((qd & 0x10) >> 4) << 31;
|
||||
zc |= ((qd & 0x20) >> 5) << 31;
|
||||
|
||||
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y8 = __halves2half2(y8_, y8_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
|
||||
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
|
||||
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z8 = __halves2half2(z8_, z8_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
|
||||
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||
qa >>= 6;
|
||||
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||
qa >>= 9;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||
qb >>= 6;
|
||||
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||
qb >>= 8;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||
qc >>= 6;
|
||||
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||
qc >>= 7;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q15((qa | qb | qc) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
|
||||
dq[ 7] = __hadd2( q7.as_half2, z1);
|
||||
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
|
||||
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
|
||||
dq[10] = __hadd2(q10.as_half2, z1);
|
||||
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[32];
|
||||
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
|
||||
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
|
||||
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
|
||||
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
|
||||
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
|
||||
|
||||
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,227 +0,0 @@
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_4BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
qb |= (qa1 << (i * 4 + 16));
|
||||
qb |= (qa0 << (i * 4));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
|
||||
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z16 = __halves2half2(z16_, z16_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1z16)[2],
|
||||
half2(&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
z1z16[0] = __half2half2(z1.as_half);
|
||||
z1z16[1] = __half2half2(z16);
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __half2half2(y1);
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
|
||||
|
||||
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z = __hmul(z, scale);
|
||||
z1[0] = __half2half2(z);
|
||||
y1[0] = __half2half2(scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1)[2],
|
||||
half2(&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z1[0] = __half2half2(z);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
half2 dqh2[8];
|
||||
|
||||
uint32_t qa = q_0;
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
dqh2[i] = __halves2half2(d0, d1);
|
||||
}
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
|
||||
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
|
||||
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
|
||||
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(dqh2[0], z1[0]);
|
||||
dq[1] = __hadd2(dqh2[1], z1[0]);
|
||||
dq[2] = __hadd2(dqh2[2], z1[0]);
|
||||
dq[3] = __hadd2(dqh2[3], z1[0]);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,207 +0,0 @@
|
||||
#ifndef _qdq_5_cuh
|
||||
#define _qdq_5_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_5BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
__forceinline__ __device__ void shuffle_5bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
uint32_t qd = q[3 * stride];
|
||||
uint32_t qe = q[4 * stride];
|
||||
|
||||
// qa: 66555554 44443333 32222211 11100000
|
||||
// qb: ccccbbbb baaaaa99 99988888 77777666
|
||||
// qc: jiiiiihh hhhggggg fffffeee eedddddc
|
||||
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
|
||||
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
|
||||
|
||||
uint32_t qf = qe >> 22;
|
||||
qe <<= 8;
|
||||
qe |= qd >> 24;
|
||||
qd <<= 6;
|
||||
qd |= qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: 555554 44443333 32222211 11100000
|
||||
// qb: bbbbba aaaa9999 98888877 77766666
|
||||
// qc: hhhhhg ggggffff feeeeedd dddccccc
|
||||
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
|
||||
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
uint32_t zd = 0;
|
||||
uint32_t ze = 0;
|
||||
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
|
||||
|
||||
// za: 5555533 33311111 4444422 22200000
|
||||
// zb: bbbbb99 99977777 aaaaa88 88866666
|
||||
// zc: hhhhhff fffddddd gggggee eeeccccc
|
||||
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
|
||||
// ze: tttttrr rrrppppp sssssqq qqqooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
za |= ((qf & 0x001) >> 0) << 15;
|
||||
zb |= ((qf & 0x002) >> 1) << 15;
|
||||
zc |= ((qf & 0x004) >> 2) << 15;
|
||||
zd |= ((qf & 0x008) >> 3) << 15;
|
||||
ze |= ((qf & 0x010) >> 4) << 15;
|
||||
za |= ((qf & 0x020) >> 5) << 31;
|
||||
zb |= ((qf & 0x040) >> 6) << 31;
|
||||
zc |= ((qf & 0x080) >> 7) << 31;
|
||||
zd |= ((qf & 0x100) >> 8) << 31;
|
||||
ze |= ((qf & 0x200) >> 9) << 31;
|
||||
|
||||
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// zb: vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// zc: vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// ze: vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
q[3 * stride] = zd;
|
||||
q[4 * stride] = ze;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_5bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
const uint32_t q_3,
|
||||
const uint32_t q_4,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y32_ = __float2half_rn(1.0f / 32.0f);
|
||||
const half2 y32 = __halves2half2(y32_, y32_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
|
||||
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z32 = __halves2half2(z32_, z32_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
uint32_t qd = q_3;
|
||||
uint32_t qe = q_4;
|
||||
|
||||
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
|
||||
qa >>= 10;
|
||||
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
qa >>= 5;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
|
||||
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
|
||||
qb >>= 10;
|
||||
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
|
||||
qb >>= 4;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
|
||||
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
|
||||
qc >>= 10;
|
||||
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
|
||||
qc >>= 3;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
|
||||
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
|
||||
qd >>= 10;
|
||||
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
|
||||
qd >>= 2;
|
||||
qd &= 0x00080008;
|
||||
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
|
||||
qe >>= 10;
|
||||
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
|
||||
qe >>= 1;
|
||||
qe &= 0x00100010;
|
||||
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hadd2( q3.as_half2, z1);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hadd2( q6.as_half2, z1);
|
||||
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
|
||||
dq[ 8] = __hadd2( q8.as_half2, z1);
|
||||
dq[ 9] = __hadd2( q9.as_half2, z1);
|
||||
dq[10] = __hfma2(q10.as_half2, y32, z32);
|
||||
dq[11] = __hadd2(q11.as_half2, z1);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y32, z32);
|
||||
dq[14] = __hadd2(q14.as_half2, z1);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_5bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_5bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
const uint32_t q_3,
|
||||
const uint32_t q_4,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[32];
|
||||
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
|
||||
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
|
||||
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
|
||||
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
|
||||
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
|
||||
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
|
||||
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
|
||||
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
|
||||
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
|
||||
|
||||
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,42 +0,0 @@
|
||||
#ifndef _qdq_6_cuh
|
||||
#define _qdq_6_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_6BIT == 1
|
||||
|
||||
// Not implemented
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_6bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_6bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[8],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[16];
|
||||
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
|
||||
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
|
||||
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
|
||||
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
|
||||
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
|
||||
|
||||
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,38 +0,0 @@
|
||||
#ifndef _qdq_8_cuh
|
||||
#define _qdq_8_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_8BIT == 1
|
||||
|
||||
// Not implemented
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_8bit_4
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_8bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
|
||||
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
|
||||
|
||||
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
@ -1,53 +0,0 @@
|
||||
#ifndef _qdq_util_cuh
|
||||
#define _qdq_util_cuh
|
||||
|
||||
union half2_uint32
|
||||
{
|
||||
uint32_t as_uint32;
|
||||
half2 as_half2;
|
||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||
__device__ half2_uint32() : as_uint32(0) {}
|
||||
};
|
||||
|
||||
union half_uint16
|
||||
{
|
||||
uint16_t as_uint16;
|
||||
half as_half;
|
||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||
__device__ half_uint16(half val) : as_half(val) {}
|
||||
__device__ half_uint16() : as_uint16(0) {}
|
||||
};
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
|
||||
{
|
||||
int qs_i = qs + 1;
|
||||
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
|
||||
{
|
||||
return __hmul(__int2half_rn(q - qzero), scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
|
||||
{
|
||||
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||
return __int2half_rn(q - qzero);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
|
||||
{
|
||||
return (int)((q >> shift) & mask);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
|
||||
{
|
||||
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||
}
|
||||
|
||||
#endif
|
@ -1,54 +0,0 @@
|
||||
#ifndef _util_cuh
|
||||
#define _util_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
#define DBGS(__x) printf("%s\n", __x)
|
||||
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
|
||||
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
|
||||
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
|
||||
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
|
||||
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
|
||||
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
|
||||
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
|
||||
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
|
||||
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
|
||||
|
||||
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
|
||||
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
|
||||
|
||||
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
|
||||
{
|
||||
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
|
||||
qs_h = __hmul(qs_h, qs_h);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float clamp(float x, float a, float b)
|
||||
{
|
||||
return fmaxf(a, fminf(b, x));
|
||||
}
|
||||
|
||||
#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
|
||||
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
|
||||
{
|
||||
if (code != cudaSuccess)
|
||||
{
|
||||
fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
|
||||
if (abort) exit(code);
|
||||
}
|
||||
}
|
||||
|
||||
void print_global_mem(const half* ptr, int rows, int columns, int stride);
|
||||
|
||||
#endif
|
@ -1,139 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
#include "cuda/q_matrix.cuh"
|
||||
#include "cuda/q_gemm.cuh"
|
||||
|
||||
#include "cpp/util.h"
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
|
||||
|
||||
// Quant matrix
|
||||
|
||||
uintptr_t make_q_matrix
|
||||
(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm,
|
||||
torch::Tensor q_invperm,
|
||||
torch::Tensor q_scale,
|
||||
torch::Tensor q_scale_max,
|
||||
torch::Tensor q_groups,
|
||||
torch::Tensor q_group_map,
|
||||
torch::Tensor gptq_qzeros,
|
||||
torch::Tensor gptq_scales,
|
||||
torch::Tensor gptq_g_idx,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(q_weight, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
||||
|
||||
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
|
||||
|
||||
int device = q_weight.device().index();
|
||||
int width = q_weight.size(1);
|
||||
int groups;
|
||||
int height;
|
||||
|
||||
if (!q_scale.device().is_meta())
|
||||
{
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
|
||||
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
|
||||
groups = q_scale.size(0);
|
||||
height = q_invperm.size(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
|
||||
groups = gptq_qzeros.size(0);
|
||||
height = q_weight.size(0) * 8;
|
||||
}
|
||||
|
||||
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
|
||||
|
||||
QMatrix* m = new QMatrix
|
||||
(
|
||||
device,
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
(uint32_t*) q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
|
||||
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
|
||||
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
|
||||
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
|
||||
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
|
||||
q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
|
||||
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
||||
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
||||
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
|
||||
if (m->failed) throw std::runtime_error("CUDA out of memory");
|
||||
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
void gemm_half_q_half
|
||||
(
|
||||
torch::Tensor a,
|
||||
uintptr_t b,
|
||||
torch::Tensor c,
|
||||
bool force_cuda
|
||||
)
|
||||
{
|
||||
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
|
||||
|
||||
TORCH_CHECK_DTYPE(a, kHalf);
|
||||
TORCH_CHECK_DTYPE(c, kHalf);
|
||||
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
|
||||
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
|
||||
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
|
||||
gemm_half_q_half_cuda
|
||||
(
|
||||
at::cuda::getCurrentCUDABlasHandle(),
|
||||
(const half*) a.data_ptr(),
|
||||
qm,
|
||||
(half*) c.data_ptr(),
|
||||
c.size(0), // m
|
||||
c.size(1), // n
|
||||
a.size(1), // k
|
||||
true,
|
||||
NULL,
|
||||
force_cuda
|
||||
);
|
||||
}
|
||||
|
||||
// Bindings
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
|
||||
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
|
||||
}
|
@ -1,30 +0,0 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
import torch
|
||||
|
||||
extra_cuda_cflags = ["-lineinfo", "-O3"]
|
||||
extra_cflags = []
|
||||
if torch.version.hip:
|
||||
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": extra_cflags,
|
||||
"nvcc": extra_cuda_cflags,
|
||||
}
|
||||
|
||||
setup(
|
||||
name="exllamav2_kernels",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="exllamav2_kernels",
|
||||
sources=[
|
||||
"exllamav2_kernels/ext.cpp",
|
||||
"exllamav2_kernels/cuda/q_matrix.cu",
|
||||
"exllamav2_kernels/cuda/q_gemm.cu",
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
@ -1,54 +0,0 @@
|
||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
@ -1,54 +0,0 @@
|
||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
@ -1,54 +0,0 @@
|
||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
@ -3,7 +3,7 @@ import os
|
||||
|
||||
from .common import Seqlen
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
from .cuda import (
|
||||
|
@ -728,7 +728,7 @@ class CausalLM(Model):
|
||||
self.enable_hpu_graph = (
|
||||
os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
||||
)
|
||||
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
|
||||
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
|
||||
|
||||
if model.config.model_type not in [
|
||||
"gpt_bigcode"
|
||||
@ -790,9 +790,9 @@ class CausalLM(Model):
|
||||
if model.config.model_type not in ["gpt_bigcode"]:
|
||||
self.kwargs["trim_logits"] = True
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
|
||||
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true":
|
||||
self.kwargs["use_flash_attention"] = True
|
||||
if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
|
||||
if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true":
|
||||
self.kwargs["flash_attention_recompute"] = True
|
||||
|
||||
self.speculate = get_speculate()
|
||||
|
@ -85,8 +85,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
token_idx: Optional[torch.Tensor] = None,
|
||||
use_flash_attention: Optional[bool] = False,
|
||||
flash_attention_recompute: Optional[bool] = False,
|
||||
use_flash_attention: Optional[bool] = True,
|
||||
flash_attention_recompute: Optional[bool] = True,
|
||||
):
|
||||
|
||||
if token_idx is not None:
|
||||
@ -156,8 +156,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
use_flash_attention = kwargs.get("use_flash_attention", False)
|
||||
flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
|
||||
use_flash_attention = kwargs.get("use_flash_attention", True)
|
||||
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
labels = kwargs.get("labels", None)
|
||||
|
@ -621,7 +621,7 @@ class VlmCausalLM(Model):
|
||||
self.enable_hpu_graph = (
|
||||
os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
||||
)
|
||||
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
|
||||
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
|
||||
model = remove_kv_cache_from_output(model)
|
||||
if self.enable_hpu_graph:
|
||||
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
||||
@ -668,9 +668,9 @@ class VlmCausalLM(Model):
|
||||
self.kwargs["attn_softmax_bf16"] = True
|
||||
self.kwargs["trim_logits"] = True
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
|
||||
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true":
|
||||
self.kwargs["use_flash_attention"] = True
|
||||
if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
|
||||
if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true":
|
||||
self.kwargs["flash_attention_recompute"] = True
|
||||
|
||||
self.speculate = get_speculate()
|
||||
|
@ -20,6 +20,9 @@ from text_generation_server.utils.tokens import (
|
||||
FinishReason,
|
||||
Sampling,
|
||||
Greedy,
|
||||
make_tokenizer_optional,
|
||||
is_tokenizer_transparent,
|
||||
pad_next_token_chooser_parameters,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -41,4 +44,7 @@ __all__ = [
|
||||
"StopSequenceCriteria",
|
||||
"FinishReason",
|
||||
"Weights",
|
||||
"make_tokenizer_optional",
|
||||
"is_tokenizer_transparent",
|
||||
"pad_next_token_chooser_parameters",
|
||||
]
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Set, Union
|
||||
|
||||
@ -22,6 +21,7 @@ from text_generation_server.utils.logits_process import (
|
||||
)
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
||||
import os
|
||||
|
||||
|
||||
class NextTokenChooser:
|
||||
@ -753,8 +753,6 @@ def make_tokenizer_optional(tokenizer):
|
||||
# I don't think this method is used anywhere and should be removed when doing refactoring
|
||||
return ",".join(str(i) for i in to_py_obj(token_ids)) # noqa: F821
|
||||
|
||||
import os
|
||||
|
||||
if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
|
||||
tokenizer.__class__ = _
|
||||
tokenizer.is_transparent = True
|
||||
|
@ -1,7 +1,7 @@
|
||||
from optimum.habana.utils import get_driver_version
|
||||
from packaging.version import Version
|
||||
|
||||
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")
|
||||
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
|
||||
|
||||
|
||||
def is_driver_compatible():
|
||||
|
@ -2,4 +2,10 @@
|
||||
|
||||
ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'
|
||||
|
||||
# Check if --sharded argument is present in the command line arguments
|
||||
if [[ "$*" == *"--sharded true"* ]]; then
|
||||
echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
|
||||
export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
|
||||
fi
|
||||
|
||||
text-generation-launcher $@
|
||||
|
Loading…
Reference in New Issue
Block a user