Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

Commit e5f124b077: Merge tag 'v1.2.0' into v1.2-release

.github/workflows/build.yaml (vendored, 93 lines changed)
@@ -59,7 +59,7 @@ jobs:
   build-and-push-image:
     concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+      group: ${{ github.workflow }}-build-and-push-image-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     needs: start-runner # required to start the main job when the runner is ready
     runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
@@ -146,6 +146,95 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min

+  build-and-push-image-rocm:
+    concurrency:
+      group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs: start-runner # required to start the main job when the runner is ready
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    permissions:
+      contents: write
+      packages: write
+      # This is used to complete the identity challenge
+      # with sigstore/fulcio when running outside of PRs.
+      id-token: write
+      security-events: write
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+      - name: Initialize Docker Buildx
+        uses: docker/setup-buildx-action@v2.0.0
+        with:
+          install: true
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Tailscale
+        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
+        with:
+          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
+      - name: Login to GitHub Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Login to internal Container Registry
+        uses: docker/login-action@v2.1.0
+        with:
+          username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
+          password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
+          registry: registry.internal.huggingface.tech
+      - name: Login to Azure Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2.1.0
+        with:
+          username: ${{ secrets.AZURE_DOCKER_USERNAME }}
+          password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
+          registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
+      # If pull request
+      - name: Extract metadata (tags, labels) for Docker
+        if: ${{ github.event_name == 'pull_request' }}
+        id: meta-pr
+        uses: docker/metadata-action@v4.3.0
+        with:
+          images: |
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+          tags: |
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
+      # If main, release or tag
+      - name: Extract metadata (tags, labels) for Docker
+        if: ${{ github.event_name != 'pull_request' }}
+        id: meta
+        uses: docker/metadata-action@v4.3.0
+        with:
+          flavor: |
+            latest=false
+          images: |
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+            ghcr.io/huggingface/text-generation-inference
+            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
+          tags: |
+            type=semver,pattern={{version}}-rocm
+            type=semver,pattern={{major}}.{{minor}}-rocm
+            type=raw,value=latest-rocm,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
+      - name: Build and push Docker image
+        id: build-and-push
+        uses: docker/build-push-action@v4
+        with:
+          context: .
+          file: Dockerfile_amd
+          push: true
+          platforms: 'linux/amd64'
+          build-args: |
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
+          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
+          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
+
   integration-tests:
     concurrency:
       group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
@@ -153,6 +242,7 @@ jobs:
     needs:
       - start-runner
       - build-and-push-image # Wait for the docker image to be built
+      - build-and-push-image-rocm
     runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
     env:
       DOCKER_VOLUME: /cache
@@ -187,6 +277,7 @@ jobs:
     needs:
       - start-runner
       - build-and-push-image
+      - build-and-push-image-rocm
      - integration-tests
     runs-on: ubuntu-latest
     env:
Cargo.lock (generated, 888 lines changed): file diff suppressed because it is too large.

@@ -8,7 +8,7 @@ members = [
 ]

 [workspace.package]
-version = "1.1.1"
+version = "1.2.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

Dockerfile_amd (new file, 153 lines)
@@ -0,0 +1,153 @@
|
||||
# Rust builder
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY launcher launcher
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
|
||||
ARG GIT_SHA
|
||||
ARG DOCKER_LABEL
|
||||
|
||||
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
|
||||
rm -f $PROTOC_ZIP
|
||||
|
||||
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
|
||||
# Text Generation Inference base image for RoCm
|
||||
FROM rocm/dev-ubuntu-20.04:5.7 as base
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ca-certificates \
|
||||
ccache \
|
||||
curl \
|
||||
git \
|
||||
make \
|
||||
libssl-dev \
|
||||
g++ \
|
||||
# Needed to build vLLM & flash attention.
|
||||
rocthrust-dev \
|
||||
hipsparse-dev \
|
||||
hipblas-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Keep in sync with `server/pyproject.toml`
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTORCH_VERSION='2.2.0.dev0'
|
||||
ARG ROCM_VERSION='5.7'
|
||||
ARG PYTHON_VERSION='3.10.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
# Install mamba
|
||||
# translating Docker's TARGETPLATFORM into mamba arches
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
|
||||
*) MAMBA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
||||
RUN chmod +x ~/mambaforge.sh && \
|
||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||
mamba init && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as vLLM cannot be compiled with RoCm 5.6.
|
||||
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
|
||||
|
||||
FROM base AS kernel-builder
|
||||
|
||||
# Build vllm kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-vllm Makefile
|
||||
|
||||
# Build specific version of vllm
|
||||
RUN make build-vllm-rocm
|
||||
|
||||
# Build Flash Attention v2 kernels
|
||||
FROM kernel-builder AS flash-att-v2-builder
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-flash-att-v2 Makefile
|
||||
|
||||
# Build specific version of flash attention v2
|
||||
RUN make build-flash-attention-v2-rocm
|
||||
|
||||
# Build Transformers CUDA kernels (gpt-neox and bloom)
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/custom_kernels/ .
|
||||
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
|
||||
|
||||
FROM base as base-copy
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy build artifacts from flash attention v2 builder
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_rocm.txt && \
|
||||
pip install ".[accelerate, peft]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base-copy as sagemaker
|
||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
RUN chmod +x entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["./entrypoint.sh"]
|
||||
|
||||
# Final image
|
||||
FROM base-copy
|
||||
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
@@ -58,11 +58,11 @@ For more information and documentation about Text Generation Inference, checkout
 Not all features of TGI are currently supported as this is still a work in progress.

 New changes are added for the current release:
-- Sharded feature with support for DeepSpeed-inference auto tensor parallism. Also use HPU graph for performance improvement.
+- Sharded feature with support for DeepSpeed-inference auto tensor parallelism. Also, use HPU graphs for performance improvement.
 - Torch profile.


-Enviroment Variables Added:
+Environment Variables Added:

 <div align="center">
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "1.1.1"
+    "version": "1.2.0"
   },
   "paths": {
     "/": {
@@ -19,6 +19,6 @@ docker run --gpus all \
     --shm-size 1g \
     -e HUGGING_FACE_HUB_TOKEN=$token \
     -p 8080:80 \
-    -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.1 \
+    -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.2 \
     --model-id $model
 ```
@@ -8,13 +8,15 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
 model=tiiuae/falcon-7b-instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.1 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.2 --model-id $model
 ```

 <Tip warning={true}>

 To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.

+To use TGI on RoCm-enabled AMD GPUs (only MI210 and MI250 are tested), please use the image `ghcr.io/huggingface/text-generation-inference:1.2+rocm` instead. For details about the usage on RoCm, please refer to the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html).
+
 </Tip>

 Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
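For illustration, a minimal Python sketch of such a request, assuming the container from the quick tour is listening on localhost port 8080 and that the `requests` package is available; the payload shape follows the `/generate` endpoint described here:

```python
import requests

# Query a running TGI container; adjust the host/port to your `docker run -p` mapping.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
)
response.raise_for_status()
print(response.json()["generated_text"])
```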
@@ -85,7 +87,7 @@ curl 127.0.0.1:8080/generate \
 To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

 ```bash
-docker run ghcr.io/huggingface/text-generation-inference:1.1.1 --help
+docker run ghcr.io/huggingface/text-generation-inference:1.2 --help
 ```

 </Tip>
@@ -39,9 +39,9 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>

 ## Supported Hardware

-TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other hardware, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
+TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.

 TGI also has support of RoCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are missing from the RoCm version of TGI: quantization and flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm).

 TGI is also supported on the following AI hardware accelerators:
 - *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
@ -24,6 +24,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
|
||||
|
||||
|
||||
class ResponseComparator(JSONSnapshotExtension):
|
||||
rtol = 0.2
|
||||
def serialize(
|
||||
self,
|
||||
data,
|
||||
@ -58,7 +59,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
return (
|
||||
token.id == other.id
|
||||
and token.text == other.text
|
||||
and math.isclose(token.logprob, other.logprob, rel_tol=0.2)
|
||||
and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
|
||||
and token.special == other.special
|
||||
)
|
||||
|
||||
@ -68,7 +69,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
prefill_token.id == other.id
|
||||
and prefill_token.text == other.text
|
||||
and (
|
||||
math.isclose(prefill_token.logprob, other.logprob, rel_tol=0.2)
|
||||
math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
|
||||
if prefill_token.logprob is not None
|
||||
else prefill_token.logprob == other.logprob
|
||||
)
|
||||
@ -148,6 +149,10 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
)
|
||||
|
||||
|
||||
class GenerousResponseComparator(ResponseComparator):
|
||||
# Needed for GPTQ with exllama, which shows large numerical fluctuations.
|
||||
rtol = 0.75
|
||||
|
||||
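As a rough sketch of what the looser tolerance buys (the 0.2 and 0.75 values are the `rtol` settings above; the logprob values below are made up for illustration):

```python
import math

reference, rerun = -1.0, -1.4  # hypothetical logprobs for the same token across two runs

# Default ResponseComparator tolerance: 20% relative difference allowed.
print(math.isclose(reference, rerun, rel_tol=0.2))   # False: the runs would not match

# GenerousResponseComparator tolerance: 75%, absorbing the larger numerical
# drift observed with GPTQ + exllama kernels.
print(math.isclose(reference, rerun, rel_tol=0.75))  # True
```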
class LauncherHandle:
|
||||
def __init__(self, port: int):
|
||||
self.client = AsyncClient(f"http://localhost:{port}")
|
||||
@ -193,6 +198,10 @@ class ProcessLauncherHandle(LauncherHandle):
|
||||
def response_snapshot(snapshot):
|
||||
return snapshot.use_extension(ResponseComparator)
|
||||
|
||||
@pytest.fixture
|
||||
def generous_response_snapshot(snapshot):
|
||||
return snapshot.use_extension(GenerousResponseComparator)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
@ -210,6 +219,7 @@ def launcher(event_loop):
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_flash_attention: bool = True,
|
||||
dtype: Optional[str] = None
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
master_port = random.randint(10_000, 20_000)
|
||||
@ -237,6 +247,9 @@ def launcher(event_loop):
|
||||
if quantize is not None:
|
||||
args.append("--quantize")
|
||||
args.append(quantize)
|
||||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
@ -269,6 +282,7 @@ def launcher(event_loop):
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_flash_attention: bool = True,
|
||||
dtype: Optional[str] = None
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
||||
@ -279,6 +293,9 @@ def launcher(event_loop):
|
||||
if quantize is not None:
|
||||
args.append("--quantize")
|
||||
args.append(quantize)
|
||||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
@ -318,6 +335,7 @@ def launcher(event_loop):
|
||||
],
|
||||
volumes=volumes,
|
||||
ports={"80/tcp": port},
|
||||
shm_size="1G"
|
||||
)
|
||||
|
||||
yield ContainerLauncherHandle(client, container.name, port)
|
||||
|
@@ -15,20 +15,20 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):

 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
+async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
         max_new_tokens=20,
         decoder_input_details=True,
     )
     assert response.details.generated_tokens == 20
-    assert response == response_snapshot
+    assert response == generous_response_snapshot


 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq_default_params(
-    flash_starcoder_gptq, response_snapshot
+    flash_starcoder_gptq, generous_response_snapshot
 ):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
@@ -39,13 +39,13 @@ async def test_flash_starcoder_gptq_default_params(
         seed=0,
     )
     assert response.details.generated_tokens == 20
-    assert response == response_snapshot
+    assert response == generous_response_snapshot


 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq_load(
-    flash_starcoder_gptq, generate_load, response_snapshot
+    flash_starcoder_gptq, generate_load, generous_response_snapshot
 ):
     responses = await generate_load(
         flash_starcoder_gptq,
@@ -57,4 +57,4 @@ async def test_flash_starcoder_gptq_load(
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])

-    assert responses == response_snapshot
+    assert responses == generous_response_snapshot
@@ -3,7 +3,7 @@ import pytest

 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle:
+    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle:
         yield handle


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.1.1"
+version = "1.2.0"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]
@@ -82,7 +82,7 @@ impl Infer {
     }

     /// Add a new request to the queue and return a stream of InferStreamResponse
-    #[instrument(skip(self))]
+    #[instrument(skip_all)]
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
@@ -133,7 +133,7 @@ impl Infer {
     }

     /// Add a new request to the queue and return a InferResponse
-    #[instrument(skip(self))]
+    #[instrument(skip_all)]
     pub(crate) async fn generate(
         &self,
         request: GenerateRequest,
@@ -214,7 +214,7 @@ impl Infer {
     }
     /// Add best_of new requests to the queue and return a InferResponse of the sequence with
     /// the highest log probability per token
-    #[instrument(skip(self))]
+    #[instrument(skip(self, request))]
     pub(crate) async fn generate_best_of(
         &self,
         request: GenerateRequest,
@@ -300,15 +300,21 @@ mod tests {
     use tokenizers::Tokenizer;

     pub(crate) async fn get_tokenizer() -> Tokenizer {
-        if !std::path::Path::new("tokenizer.json").exists() {
+        let filename = std::path::Path::new("tokenizer.json");
+        if !filename.exists() {
             let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
                 .await
                 .unwrap()
                 .bytes()
                 .await
                 .unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+            let tmp_filename = "tokenizer.json.temp";
+            let mut file = std::fs::File::create(tmp_filename).unwrap();
             file.write_all(&content).unwrap();
+            // Re-check if another process has written this file maybe.
+            if !filename.exists() {
+                std::fs::rename(tmp_filename, filename).unwrap()
+            }
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }
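The same write-to-a-temporary-file-then-rename idea, sketched in Python for clarity (the URL is the one used in the test above; `fetch_tokenizer` is a hypothetical helper, and `os.replace` plays the role of `std::fs::rename`):

```python
import os
import requests

def fetch_tokenizer(path: str = "tokenizer.json") -> None:
    """Download the GPT-2 tokenizer once, tolerating concurrent callers."""
    if os.path.exists(path):
        return
    content = requests.get("https://huggingface.co/gpt2/raw/main/tokenizer.json").content
    tmp_path = path + ".temp"
    with open(tmp_path, "wb") as f:
        f.write(content)
    # Re-check: another process may have finished the download while we were working.
    if not os.path.exists(path):
        os.replace(tmp_path, path)
```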
@@ -69,7 +69,7 @@ impl Validation {
         }
     }

-    #[instrument(skip_all)]
+    #[instrument(skip(self, inputs))]
     async fn validate_input(
         &self,
         inputs: String,
@@ -18,8 +18,8 @@ gen-server:

 install: gen-server
 	pip install pip --upgrade
-	pip install -r requirements.txt
-	pip install -e ".[bnb, accelerate]"
+	pip install -r requirements_cuda.txt
+	pip install -e ".[bnb, accelerate, quantize, peft]"

 run-dev:
 	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
@@ -2,7 +2,7 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec

 flash-attention:
 	# Clone flash attention
-	pip install packaging
+	pip install -U packaging ninja --no-cache-dir
 	git clone https://github.com/HazyResearch/flash-attention.git

 build-flash-attention: flash-attention
@@ -1,13 +1,29 @@
-flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69

-flash-attention-v2:
+
+flash-attention-v2-cuda:
 	# Clone flash attention
-	pip install packaging
+	pip install -U packaging ninja --no-cache-dir
 	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2

-build-flash-attention-v2: flash-attention-v2
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
+build-flash-attention-v2-cuda: flash-attention-v2-cuda
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)
+	cd flash-attention-v2 && git submodule update --init --recursive
 	cd flash-attention-v2 && python setup.py build

-install-flash-attention-v2: build-flash-attention-v2
-	cd flash-attention-v2 && python setup.py install
+install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
+	cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
+
+flash-attention-v2-rocm:
+	# Clone flash attention
+	pip install -U packaging ninja --no-cache-dir
+	git clone https://github.com/fxmarty/flash-attention-rocm flash-attention-v2
+
+build-flash-attention-v2-rocm: flash-attention-v2-rocm
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
+	cd flash-attention-v2 && git submodule update --init --recursive
+	cd flash-attention-v2 && PYTORCH_ROCM_ARCH=gfx90a python setup.py build
+
+install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
+	cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
@@ -1,11 +1,20 @@
-vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
+build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
+build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
+build-vllm-cuda: BRANCH=main
+build-vllm-cuda: build-vllm
+
+build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
+build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
+build-vllm-rocm: build-vllm

 vllm:
 	# Clone vllm
-	git clone https://github.com/vllm-project/vllm.git
+	pip install -U ninja packaging --no-cache-dir
+	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm

 build-vllm: vllm
-	cd vllm && git fetch && git checkout $(vllm_commit)
+	cd vllm && git fetch && git checkout $(VLLM_COMMIT)
 	cd vllm && python setup.py build

 install-vllm: build-vllm
@@ -1,5 +1,10 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+
+extra_compile_args = ["-std=c++17"]
+if not torch.version.hip:
+    extra_compile_args.append("-arch=compute_80")

 setup(
     name="custom_kernels",
@@ -7,12 +12,12 @@ setup(
         CUDAExtension(
             name="custom_kernels.fused_bloom_attention_cuda",
             sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
-            extra_compile_args=["-arch=compute_80", "-std=c++17"],
+            extra_compile_args=extra_compile_args,
         ),
         CUDAExtension(
             name="custom_kernels.fused_attention_cuda",
             sources=["custom_kernels/fused_attention_cuda.cu"],
-            extra_compile_args=["-arch=compute_80", "-std=c++17"],
+            extra_compile_args=extra_compile_args,
         ),
     ],
     cmdclass={"build_ext": BuildExtension},
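A small sketch of why the check above works: ROCm builds of PyTorch expose a version string in `torch.version.hip` while CUDA builds leave it as `None`, so the NVCC-only `-arch=compute_80` flag is only appended for CUDA builds (the print line is just for inspection):

```python
import torch

extra_compile_args = ["-std=c++17"]
if not torch.version.hip:          # None on CUDA builds, a version string on ROCm builds
    extra_compile_args.append("-arch=compute_80")  # hipcc would reject this NVCC flag
print(torch.version.hip, torch.version.cuda, extra_compile_args)
```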
server/exllamav2_kernels/exllamav2_kernels/config.h (new file, 13 lines)
@@ -0,0 +1,13 @@
#ifndef _config_h
#define _config_h

#define MAX_Q_GEMM_ROWS 50

#define QMODE_2BIT 1
#define QMODE_3BIT 1
#define QMODE_4BIT 1
#define QMODE_5BIT 1
#define QMODE_6BIT 0
#define QMODE_8BIT 0

#endif

server/exllamav2_kernels/exllamav2_kernels/cpp/util.h (new file, 12 lines)
@@ -0,0 +1,12 @@
#ifndef _util_h
#define _util_h

#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)

#endif
server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
@ -0,0 +1,38 @@
|
||||
#ifndef _compat_gemm_cuh
|
||||
#define _compat_gemm_cuh
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
|
||||
// For some reason this include is not present anywhere in the exllama_v2 codebase, but it is required
// for symbols such as hipblasHalf.
|
||||
#include <hipblas/hipblas.h>
|
||||
|
||||
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||
hipblasOperation_t transA,
|
||||
hipblasOperation_t transB,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
const half* alpha,
|
||||
const half* AP,
|
||||
int lda,
|
||||
const half* BP,
|
||||
int ldb,
|
||||
const half* beta,
|
||||
half* CP,
|
||||
int ldc) {
|
||||
return hipblasHgemm(handle, transA, transB, m, n, k,
|
||||
reinterpret_cast<const hipblasHalf *>(alpha),
|
||||
reinterpret_cast<const hipblasHalf *>(AP), lda,
|
||||
reinterpret_cast<const hipblasHalf *>(BP), ldb,
|
||||
reinterpret_cast<const hipblasHalf *>(beta),
|
||||
reinterpret_cast<hipblasHalf *>(CP), ldc);
|
||||
}
|
||||
#define hipblasHgemm __compat_hipblasHgemm
|
||||
|
||||
// Previous versions of PyTorch were converting to rocBLAS instead of hipBLAS.
|
||||
#define rocblas_operation_none HIPBLAS_OP_N
|
||||
#define rocblas_hgemm __compat_hipblasHgemm
|
||||
#endif
|
||||
|
||||
#endif
|
server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "quant/qdq_util.cuh"
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||
{
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
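As a plain-Python sketch of the packed layout that `MatrixView_q4_row::item()` above decodes, eight 4-bit values share one 32-bit word, least-significant nibble first (the helper name and sample data here are made up for illustration):

```python
def unpack_q4(words, row, column, width):
    """Mirror MatrixView_q4_row::item(): each row of `width` elements
    occupies width // 8 packed 32-bit words."""
    word = words[row * (width // 8) + column // 8]
    shift = (column & 0x07) * 4
    return (word >> shift) & 0x0F

# One packed word storing the values 0..7 (value i in nibble i).
packed = [sum(i << (4 * i) for i in range(8))]
print([unpack_q4(packed, row=0, column=c, width=8) for c in range(8)])  # [0, 1, ..., 7]
```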
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu (new file, 211 lines)
@@ -0,0 +1,211 @@
|
||||
#include "q_gemm.cuh"
|
||||
#include "util.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "../config.h"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
#define BLOCK_M_SIZE_MAX 8
|
||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
||||
#define CLEAR_N_SIZE 256
|
||||
|
||||
#include "q_gemm_kernel.cuh"
|
||||
#include "q_gemm_kernel_gptq.cuh"
|
||||
|
||||
#include "compat_gemm.cuh"
|
||||
|
||||
void gemm_half_q_half_cuda_part
|
||||
(
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
int m_count,
|
||||
bool clear
|
||||
)
|
||||
{
|
||||
if (!b->is_gptq)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_q_scale,
|
||||
b->cuda_q_scale_max,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_8,
|
||||
b->rows_6,
|
||||
b->rows_5,
|
||||
b->rows_4,
|
||||
b->rows_3,
|
||||
b->rows_2,
|
||||
clear
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
||||
|
||||
// DBGX((uint64_t) b->cuda_q_perm);
|
||||
// DBGI(b->rows_4);
|
||||
// DBGI(b->height);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_gptq_qzeros,
|
||||
b->cuda_gptq_scales,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_4,
|
||||
clear
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
bool clear,
|
||||
half* temp_dq,
|
||||
bool force_cuda
|
||||
)
|
||||
{
|
||||
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||
{
|
||||
//printf("cublas\n");
|
||||
|
||||
// Reconstruct FP16 matrix, then cuBLAS
|
||||
|
||||
if (!temp_dq) temp_dq = b->temp_dq;
|
||||
b->reconstruct(temp_dq);
|
||||
|
||||
//cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
|
||||
cublasHgemm(cublas_handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
size_n, size_m, size_k,
|
||||
&alpha, temp_dq, size_n,
|
||||
a, size_k,
|
||||
&beta, c, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasSgemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N,
|
||||
// CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasGemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n,
|
||||
// CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
|
||||
}
|
||||
else
|
||||
{
|
||||
//printf("cuda\n");
|
||||
|
||||
// Quantized matmul
|
||||
|
||||
//if (clear) clear_tensor_cuda(c, size_m, size_n);
|
||||
|
||||
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
||||
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
||||
int last_chunk_size = size_m - last_chunk;
|
||||
|
||||
if (max_chunks)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
|
||||
}
|
||||
|
||||
if (last_chunk_size)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void clear_kernel
|
||||
(
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n
|
||||
)
|
||||
{
|
||||
int m = blockIdx.y;
|
||||
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
|
||||
if (n >= size_n) return;
|
||||
int4* c_ptr = (int4*)(c + m * size_n + n);
|
||||
*c_ptr = {};
|
||||
}
|
||||
|
||||
void clear_tensor_cuda
|
||||
(
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n
|
||||
)
|
||||
{
|
||||
return;
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = CLEAR_N_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||
gridDim.y = size_m;
|
||||
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||
}
|
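The row-chunking arithmetic used by `gemm_half_q_half_cuda` above for the quantized path, restated as a small Python sketch (`BLOCK_M_SIZE_MAX` is the constant defined at the top of the file; `split_rows` is just an illustrative name):

```python
BLOCK_M_SIZE_MAX = 8  # maximum rows of `a` processed per kernel launch

def split_rows(size_m: int):
    """Return (rows covered by full-size launches, leftover rows)."""
    max_chunks = size_m // BLOCK_M_SIZE_MAX
    last_chunk = max_chunks * BLOCK_M_SIZE_MAX
    last_chunk_size = size_m - last_chunk
    return last_chunk, last_chunk_size

# 20 rows: one launch covers 16 rows (two y-blocks of 8), a second launch handles the remaining 4.
print(split_rows(20))  # (16, 4)
```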
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh (new file, 33 lines)
@@ -0,0 +1,33 @@
|
||||
#ifndef _q_gemm_cuh
|
||||
#define _q_gemm_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "q_matrix.cuh"
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
bool clear = false,
|
||||
half* reconstruct = NULL,
|
||||
bool force_cuda = false
|
||||
);
|
||||
|
||||
void clear_tensor_cuda
|
||||
(
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n
|
||||
);
|
||||
|
||||
#endif
|
@ -0,0 +1,487 @@
|
||||
#include "compat.cuh"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const bool
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
__global__ void gemm_half_q_half_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2,
|
||||
const bool clear
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
|
||||
// Preload scales
|
||||
|
||||
float scales[MAX_GROUPS_IN_BLOCK][4];
|
||||
|
||||
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||
for (int g = 0; g < groups_in_block; g++)
|
||||
{
|
||||
int qscales[4];
|
||||
b_q_scale_.item4(qscales, group + g, n);
|
||||
qscales[0]++;
|
||||
qscales[1]++;
|
||||
qscales[2]++;
|
||||
qscales[3]++;
|
||||
float maxscale = __half2float(b_q_scale_max[group + g]);
|
||||
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
||||
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
||||
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
||||
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
||||
}
|
||||
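A back-of-the-envelope sketch of the scale decode in the preload loop above: each stored 4-bit scale code `q` becomes `(q + 1)^2 * maxscale`, i.e. a squared, non-uniform grid of sixteen possible scales per group (the `maxscale` value below is made up):

```python
def decode_scale(q4: int, maxscale: float) -> float:
    """Reproduce the preload: 4-bit code -> (code + 1)^2 * per-group max scale."""
    q = q4 + 1
    return float(q * q) * maxscale

maxscale = 0.001  # illustrative per-group maximum scale
print([decode_scale(q, maxscale) for q in (0, 3, 7, 15)])
# [0.001, 0.016, 0.064, 0.256]
```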
|
||||
// a, b offset
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int scales_idx = 0;
|
||||
float qs_f0 = scales[scales_idx][0];
|
||||
float qs_f1 = scales[scales_idx][1];
|
||||
float qs_f2 = scales[scales_idx][2];
|
||||
float qs_f3 = scales[scales_idx][3];
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize groups
|
||||
|
||||
int k = offset_k;
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[2];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
|
||||
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
|
||||
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
|
||||
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[5];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
|
||||
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
|
||||
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
|
||||
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
|
||||
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
|
||||
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
|
||||
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
|
||||
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
|
||||
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
|
||||
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
// Accumulate column sums in c

for (int m = 0; m < m_count; m++)
{
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out    , result01);
atomicAdd(out + 1, result23);
}
}

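// Note on this epilogue: the per-thread fp32 accumulators in block_c are packed back into
// half2 pairs and combined into c with atomicAdd, since the k dimension is split across
// blocks and several blocks can contribute partial sums to the same output element
// (the GPTQ kernel below shows the matching clear pass for blockIdx.z == 0).
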
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
return NULL;
}
@ -0,0 +1,219 @@

#include "compat.cuh"

__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hadd2(result, g_result);
}

__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}

typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const bool
);

template <bool first_block, int m_count>
|
||||
__global__ void gemm_half_q_half_gptq_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_4,
|
||||
const bool clear
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
|
||||
half a0;
|
||||
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
else a0 = a_ptr[offset_k + t];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Zero output
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// a, b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int zeros[4];
|
||||
float scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
// __syncthreads();
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize and multiply
|
||||
|
||||
int k = offset_k;
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||
}
|
||||
|
||||
b_ptr += size_n;
|
||||
a_ptr += 8;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
||||
{
|
||||
#if BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
623
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu
Normal file
@ -0,0 +1,623 @@
|
||||
#include "q_matrix.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "util.cuh"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
|
||||
#define THREADS_X 32
|
||||
#define THREADS_Y 32
|
||||
|
||||
// Shuffle quantized data on load

__global__ void shuffle_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  4; }
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  8; }
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
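
// A reading of shuffle_kernel: rows_8 .. rows_2 are cumulative row boundaries, i.e. the
// first rows_8 rows of b hold 8-bit weights, rows up to rows_6 hold 6-bit weights, and so
// on down to 2-bit. Each helper rewrites one packed group of a single column in place, and
// the pointer advance per iteration is the number of 32-bit words that group occupies:
//   rows * bits / 32 words, e.g. 4*8/32 = 1, 16*6/32 = 3, 32*5/32 = 5, 8*4/32 = 1,
//   32*3/32 = 3, 16*2/32 = 1, matching the b_ptr increments above.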
|
||||
|
||||
|
||||
// QMatrix constructor
|
||||
|
||||
QMatrix::QMatrix
|
||||
(
|
||||
const int _device,
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _q_weight,
|
||||
uint16_t* _q_perm,
|
||||
uint16_t* _q_invperm,
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
failed = false;
|
||||
|
||||
cuda_q_weight = _q_weight;
|
||||
cuda_q_perm = _q_perm;
|
||||
cuda_q_invperm = _q_invperm;
|
||||
cuda_q_scale = _q_scale;
|
||||
cuda_q_scale_max = _q_scale_max;
|
||||
cuda_q_groups = _q_groups;
|
||||
cuda_gptq_qzeros = _gptq_qzeros;
|
||||
cuda_gptq_scales = _gptq_scales;
|
||||
|
||||
is_gptq = (_gptq_qzeros != NULL);
|
||||
|
||||
groupsize = 1;
|
||||
while (groupsize * groups < height) groupsize *= 2;
|
||||
|
||||
// Create group map
|
||||
|
||||
rows_8 = 0;
|
||||
rows_6 = 0;
|
||||
rows_5 = 0;
|
||||
rows_4 = 0;
|
||||
rows_3 = 0;
|
||||
rows_2 = 0;
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
for (int i = 0; i < groups; i++)
|
||||
{
|
||||
int bits = cpu_q_groups[i * 2];
|
||||
if (bits == 8) rows_8 += groupsize;
|
||||
if (bits == 6) rows_6 += groupsize;
|
||||
if (bits == 5) rows_5 += groupsize;
|
||||
if (bits == 4) rows_4 += groupsize;
|
||||
if (bits == 3) rows_3 += groupsize;
|
||||
if (bits == 2) rows_2 += groupsize;
|
||||
}
|
||||
|
||||
free(cpu_q_groups);
|
||||
|
||||
rows_6 += rows_8;
|
||||
rows_5 += rows_6;
|
||||
rows_4 += rows_5;
|
||||
rows_3 += rows_4;
|
||||
rows_2 += rows_3;
|
||||
}
|
||||
else
|
||||
{
|
||||
rows_4 = height;
|
||||
rows_3 = height;
|
||||
rows_2 = height;
|
||||
|
||||
if (_gptq_g_idx)
|
||||
{
|
||||
if (!make_sequential(_gptq_g_idx))
|
||||
{
|
||||
failed = true;
|
||||
//printf("FAIL\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shuffle quantized data
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = 1;
|
||||
|
||||
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
||||
}
|
||||
|
||||
QMatrix::~QMatrix()
|
||||
{
|
||||
}
|
||||
|
||||
// Reconstruct b[k,n] (GPTQ)
|
||||
|
||||
__global__ void reconstruct_gptq_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_4
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
int t = threadIdx.x;
|
||||
|
||||
if (b_q_perm)
|
||||
{
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
}
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
// Initial zeros/scale
|
||||
|
||||
int zeros[4];
|
||||
half2 scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4][4];
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
b_ptr += size_n;
|
||||
//half* dqh = (half*)dq;
|
||||
if (b_q_perm)
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct b[k,n]
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
int t = threadIdx.x;
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
dequant_8bit_8(q_0, q_1, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_3 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_4 = *b_ptr; b_ptr += size_n;
|
||||
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_4bit_8(q_0, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_2bit_16(q_0, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
void QMatrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
reconstruct_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_q_scale,
|
||||
cuda_q_scale_max,
|
||||
//cuda_q_groups,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_8,
|
||||
rows_6,
|
||||
rows_5,
|
||||
rows_4,
|
||||
rows_3,
|
||||
rows_2
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
|
||||
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_gptq_qzeros,
|
||||
cuda_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_4
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint16_t* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;

int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;

int w_new2_row = blockIdx.y;

int q_perm_idx = w_new2_row << 3;

uint64_t dst = 0;

#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = q_perm[q_perm_idx++];

int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;

uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}

w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
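
// A reading of make_sequential_kernel: the 4-bit GPTQ weight matrix is viewed as 64-bit
// words, each holding one nibble from 8 consecutive rows for two adjacent columns. Every
// thread assembles one destination word: for each of its 8 logical rows it looks up the
// source row via q_perm, locates that row's packed word (source_row >> 3) and nibble slot
// (source_row & 7, shift = slot * 4), masks out the two nibbles with 0x0000000f0000000f,
// and re-packs them at the destination slot (i * 4).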
|
||||
|
||||
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
if (err != cudaSuccess) {
|
||||
cudaError_t cuda_status = cudaGetLastError(); // Clear error
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Reduce to uint16_t
|
||||
|
||||
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
|
||||
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
|
||||
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
|
||||
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = height / 8;
|
||||
|
||||
make_sequential_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_new_qweight,
|
||||
cuda_q_perm,
|
||||
height / 8,
|
||||
width
|
||||
);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
|
||||
return true;
|
||||
}
|
73
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh
Normal file
@ -0,0 +1,73 @@
|
||||
#ifndef _q_matrix_cuh
|
||||
#define _q_matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#define MAX_SUPERGROUPS 16
|
||||
|
||||
class QMatrix
|
||||
{
|
||||
public:
|
||||
|
||||
int device;
|
||||
bool is_gptq;
|
||||
|
||||
int height;
|
||||
int width;
|
||||
int groups;
|
||||
int groupsize;
|
||||
|
||||
int rows_8;
|
||||
int rows_6;
|
||||
int rows_5;
|
||||
int rows_4;
|
||||
int rows_3;
|
||||
int rows_2;
|
||||
|
||||
uint32_t* cuda_q_weight = NULL;
|
||||
uint16_t* cuda_q_perm = NULL;
|
||||
uint16_t* cuda_q_invperm = NULL;
|
||||
uint32_t* cuda_q_scale = NULL;
|
||||
half* cuda_q_scale_max = NULL;
|
||||
uint16_t* cuda_q_groups = NULL;
|
||||
uint32_t* cuda_gptq_qzeros = NULL;
|
||||
half* cuda_gptq_scales = NULL;
|
||||
|
||||
half* temp_dq;
|
||||
|
||||
bool failed;
|
||||
|
||||
QMatrix
|
||||
(
|
||||
const int _device,
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _q_weight,
|
||||
uint16_t* _q_perm,
|
||||
uint16_t* _q_invperm,
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
~QMatrix();
|
||||
|
||||
void reconstruct(half* out);
|
||||
bool make_sequential(const uint32_t* cpu_g_idx);
|
||||
|
||||
private:
|
||||
|
||||
};
|
||||
|
||||
#endif
|
103
server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh
Normal file
@ -0,0 +1,103 @@
|
||||
#ifndef _qdq_2_cuh
|
||||
#define _qdq_2_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_2BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// ffddbb99 77553311 eeccaa88 66442200
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x03;
|
||||
uint32_t qa1 = (qa & 0x0c) >> 2;
|
||||
qa >>= 4;
|
||||
qb |= (qa1 << (i * 2 + 16));
|
||||
qb |= (qa0 << (i * 2));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y4_  = __float2half_rn(1.0f /  4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4  = __halves2half2(y4_,  y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_  = __float2half_rn(-1024.0f         - 2.0f);
const half z4_  = __float2half_rn(-1024.0f /  4.0f - 2.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
const half2 z1  = __halves2half2(z1_,  z1_);
const half2 z4  = __halves2half2(z4_,  z4_);
const half2 z16 = __halves2half2(z16_, z16_);
const half2 z64 = __halves2half2(z64_, z64_);

uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9])      + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024

dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4,  z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4,  z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
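
// How the lane trick above works (derived from the constants in this function):
// 0x6400 is the fp16 bit pattern of 1024.0, so OR-ing a small field into the low mantissa
// bits of 0x6400 yields the half value 1024 + field. A 2-bit value q sitting at bit
// position 2*p inside the 16-bit lane therefore becomes 1024 + q * 4^p; the z*/y*
// constants then remove the bias and the positional weight in a single __hadd2/__hfma2.
// Worked example for q1 (bits 2..3), q = 3: stored half = 1024 + 3*4 = 1036;
// 1036 * (1/4) + (-1024/4 - 2) = 259 - 256 - 2 = 1 = q - 2, the centred 2-bit value.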
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[8],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[16];
|
||||
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
|
||||
|
||||
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
169
server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh
Normal file
@ -0,0 +1,169 @@
|
||||
#ifndef _qdq_3_cuh
|
||||
#define _qdq_3_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_3BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
|
||||
// qa: aa999888 77766655 54443332 22111000
|
||||
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||
|
||||
uint32_t qd = qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: ..999888 77766655 54443332 22111000
|
||||
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
|
||||
|
||||
// za: 9997775 55333111 8886664 44222000
|
||||
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
za |= ((qd & 0x01) >> 0) << 15;
|
||||
zb |= ((qd & 0x02) >> 1) << 15;
|
||||
zc |= ((qd & 0x04) >> 2) << 15;
|
||||
za |= ((qd & 0x08) >> 3) << 31;
|
||||
zb |= ((qd & 0x10) >> 4) << 31;
|
||||
zc |= ((qd & 0x20) >> 5) << 31;
|
||||
|
||||
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y8 = __halves2half2(y8_, y8_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
|
||||
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
|
||||
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z8 = __halves2half2(z8_, z8_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
|
||||
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||
qa >>= 6;
|
||||
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||
qa >>= 9;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||
qb >>= 6;
|
||||
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||
qb >>= 8;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||
qc >>= 6;
|
||||
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||
qc >>= 7;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q15((qa | qb | qc) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
|
||||
dq[ 7] = __hadd2( q7.as_half2, z1);
|
||||
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
|
||||
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
|
||||
dq[10] = __hadd2(q10.as_half2, z1);
|
||||
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[32];
|
||||
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
|
||||
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
|
||||
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
|
||||
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
|
||||
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
|
||||
|
||||
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
227
server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh
Normal file
@ -0,0 +1,227 @@
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_4BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
qb |= (qa1 << (i * 4 + 16));
|
||||
qb |= (qa0 << (i * 4));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_  = __float2half_rn(-1024.0f         - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1  = __halves2half2(z1_,  z1_);
const half2 z16 = __halves2half2(z16_, z16_);

uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
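
// Same bias trick as in qdq_2.cuh: OR-ing a 4-bit field into the mantissa of 0x6400
// (fp16 for 1024.0) gives 1024 + q for the low nibble and 1024 + 16*q for the nibble at
// bits 4..7. Worked example, q = 5 in bits 4..7: stored half = 1024 + 80 = 1104;
// 1104 * (1/16) + (-1024/16 - 8) = 69 - 64 - 8 = -3 = q - 8, i.e. the value re-centred
// around the 4-bit midpoint of 8.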
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1z16)[2],
|
||||
half2(&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
z1z16[0] = __half2half2(z1.as_half);
|
||||
z1z16[1] = __half2half2(z16);
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __half2half2(y1);
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
|
||||
|
||||
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z = __hmul(z, scale);
|
||||
z1[0] = __half2half2(z);
|
||||
y1[0] = __half2half2(scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1)[2],
|
||||
half2(&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z1[0] = __half2half2(z);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
half2 dqh2[8];
|
||||
|
||||
uint32_t qa = q_0;
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
dqh2[i] = __halves2half2(d0, d1);
|
||||
}
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
|
||||
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
|
||||
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
|
||||
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(dqh2[0], z1[0]);
|
||||
dq[1] = __hadd2(dqh2[1], z1[0]);
|
||||
dq[2] = __hadd2(dqh2[2], z1[0]);
|
||||
dq[3] = __hadd2(dqh2[3], z1[0]);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
207
server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh
Normal file
@ -0,0 +1,207 @@
|
||||
#ifndef _qdq_5_cuh
|
||||
#define _qdq_5_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_5BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
__forceinline__ __device__ void shuffle_5bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
uint32_t qd = q[3 * stride];
|
||||
uint32_t qe = q[4 * stride];
|
||||
|
||||
// qa: 66555554 44443333 32222211 11100000
|
||||
// qb: ccccbbbb baaaaa99 99988888 77777666
|
||||
// qc: jiiiiihh hhhggggg fffffeee eedddddc
|
||||
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
|
||||
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
|
||||
|
||||
uint32_t qf = qe >> 22;
|
||||
qe <<= 8;
|
||||
qe |= qd >> 24;
|
||||
qd <<= 6;
|
||||
qd |= qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: 555554 44443333 32222211 11100000
|
||||
// qb: bbbbba aaaa9999 98888877 77766666
|
||||
// qc: hhhhhg ggggffff feeeeedd dddccccc
|
||||
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
|
||||
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
uint32_t zd = 0;
|
||||
uint32_t ze = 0;
|
||||
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
|
||||
|
||||
// za: 5555533 33311111 4444422 22200000
|
||||
// zb: bbbbb99 99977777 aaaaa88 88866666
|
||||
// zc: hhhhhff fffddddd gggggee eeeccccc
|
||||
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
|
||||
// ze: tttttrr rrrppppp sssssqq qqqooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
za |= ((qf & 0x001) >> 0) << 15;
|
||||
zb |= ((qf & 0x002) >> 1) << 15;
|
||||
zc |= ((qf & 0x004) >> 2) << 15;
|
||||
zd |= ((qf & 0x008) >> 3) << 15;
|
||||
ze |= ((qf & 0x010) >> 4) << 15;
|
||||
za |= ((qf & 0x020) >> 5) << 31;
|
||||
zb |= ((qf & 0x040) >> 6) << 31;
|
||||
zc |= ((qf & 0x080) >> 7) << 31;
|
||||
zd |= ((qf & 0x100) >> 8) << 31;
|
||||
ze |= ((qf & 0x200) >> 9) << 31;
|
||||
|
||||
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// zb: vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// zc: vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// ze: vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
q[3 * stride] = zd;
|
||||
q[4 * stride] = ze;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_5bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
const uint32_t q_3,
|
||||
const uint32_t q_4,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y32_ = __float2half_rn(1.0f / 32.0f);
|
||||
const half2 y32 = __halves2half2(y32_, y32_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
|
||||
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z32 = __halves2half2(z32_, z32_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
uint32_t qd = q_3;
|
||||
uint32_t qe = q_4;
|
||||
|
||||
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
|
||||
qa >>= 10;
|
||||
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
qa >>= 5;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
|
||||
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
|
||||
qb >>= 10;
|
||||
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
|
||||
qb >>= 4;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
|
||||
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
|
||||
qc >>= 10;
|
||||
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
|
||||
qc >>= 3;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
|
||||
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
|
||||
qd >>= 10;
|
||||
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
|
||||
qd >>= 2;
|
||||
qd &= 0x00080008;
|
||||
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
|
||||
qe >>= 10;
|
||||
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
|
||||
qe >>= 1;
|
||||
qe &= 0x00100010;
|
||||
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hadd2( q3.as_half2, z1);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hadd2( q6.as_half2, z1);
|
||||
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
|
||||
dq[ 8] = __hadd2( q8.as_half2, z1);
|
||||
dq[ 9] = __hadd2( q9.as_half2, z1);
|
||||
dq[10] = __hfma2(q10.as_half2, y32, z32);
|
||||
dq[11] = __hadd2(q11.as_half2, z1);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y32, z32);
|
||||
dq[14] = __hadd2(q14.as_half2, z1);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
#else

__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
}

__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);

for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
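
For reference, a hedged Python sketch of the generic 5-bit layout that the fallback dequant_5bit_32 above decodes: 32 values packed across five 32-bit words, with values 6, 12, 19 and 25 straddling word boundaries (illustration only, not part of the kernels):

# Illustration only, mirroring the loops in the fallback path above.
def unpack_5bit_32(q):                       # q: list of five 32-bit ints (q_0..q_4)
    def exb(word, shift, mask):              # single-word bit-field extract
        return (word >> shift) & mask
    def exb2(hi, lo, shift, mask):           # extract across a word boundary
        return (((hi << 32) | lo) >> shift) & mask

    v = []
    v += [exb(q[0], i * 5, 0x1f) for i in range(6)]
    v += [exb2(q[1], q[0], 30, 0x1f)]
    v += [exb(q[1], i * 5 + 3, 0x1f) for i in range(5)]
    v += [exb2(q[2], q[1], 28, 0x1f)]
    v += [exb(q[2], i * 5 + 1, 0x1f) for i in range(6)]
    v += [exb2(q[3], q[2], 31, 0x1f)]
    v += [exb(q[3], i * 5 + 4, 0x1f) for i in range(5)]
    v += [exb2(q[4], q[3], 29, 0x1f)]
    v += [exb(q[4], i * 5 + 2, 0x1f) for i in range(6)]
    return [x - 16 for x in v]               # dq_ns(..., 16): subtract the zero-point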
|
@ -0,0 +1,44 @@
#ifndef _qdq_6_cuh
#define _qdq_6_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_6BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_6bit_16
(
uint32_t* q,
int stride
)
{
}

__forceinline__ __device__ void dequant_6bit_16
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);

for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif

@ -0,0 +1,38 @@
#ifndef _qdq_8_cuh
#define _qdq_8_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_8BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}

__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);

for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
@ -0,0 +1,51 @@
#ifndef _qdq_util_cuh
#define _qdq_util_cuh

union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

#endif
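
For reference, a hedged Python restatement of the scale helpers above (dq_scale assumes max_scale arrives premultiplied by 1/256, per the comment; not part of the kernels):

# Illustration only: effective scale is ((qs + 1) / 16)^2 * max_scale.
def dq_scale(qs, max_scale_over_256):
    return (qs + 1) * (qs + 1) * max_scale_over_256

def dq(q, qzero, scale):
    # conventional zero-point dequantization
    return (q - qzero) * scale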
42
server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh
Normal file
@ -0,0 +1,42 @@

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))

#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))

__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
qs_h = __hmul(qs_h, qs_h);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}

__forceinline__ __device__ float clamp(float x, float a, float b)
{
return fmaxf(a, fminf(b, x));
}

#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
134
server/exllamav2_kernels/exllamav2_kernels/ext.cpp
Normal file
134
server/exllamav2_kernels/exllamav2_kernels/ext.cpp
Normal file
@ -0,0 +1,134 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
#include "cuda/q_matrix.cuh"
|
||||
#include "cuda/q_gemm.cuh"
|
||||
|
||||
#include "cpp/util.h"
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
|
||||
|
||||
// Quant matrix
|
||||
|
||||
uintptr_t make_q_matrix
|
||||
(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm,
|
||||
torch::Tensor q_invperm,
|
||||
torch::Tensor q_scale,
|
||||
torch::Tensor q_scale_max,
|
||||
torch::Tensor q_groups,
|
||||
torch::Tensor gptq_qzeros,
|
||||
torch::Tensor gptq_scales,
|
||||
torch::Tensor gptq_g_idx,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(q_weight, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
||||
|
||||
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
|
||||
|
||||
int device = q_weight.device().index();
|
||||
int width = q_weight.size(1);
|
||||
int groups;
|
||||
int height;
|
||||
|
||||
if (!q_scale.device().is_meta())
|
||||
{
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
|
||||
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
|
||||
groups = q_scale.size(0);
|
||||
height = q_invperm.size(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
|
||||
groups = gptq_qzeros.size(0);
|
||||
height = q_weight.size(0) * 8;
|
||||
}
|
||||
|
||||
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
|
||||
|
||||
QMatrix* m = new QMatrix
|
||||
(
|
||||
device,
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
(uint32_t*) q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
|
||||
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
|
||||
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
|
||||
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
|
||||
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
|
||||
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
||||
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
||||
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
void gemm_half_q_half
|
||||
(
|
||||
torch::Tensor a,
|
||||
uintptr_t b,
|
||||
torch::Tensor c,
|
||||
bool force_cuda
|
||||
)
|
||||
{
|
||||
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
|
||||
|
||||
TORCH_CHECK_DTYPE(a, kHalf);
|
||||
TORCH_CHECK_DTYPE(c, kHalf);
|
||||
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
|
||||
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
|
||||
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
|
||||
gemm_half_q_half_cuda
|
||||
(
|
||||
at::cuda::getCurrentCUDABlasHandle(),
|
||||
(const half*) a.data_ptr(),
|
||||
qm,
|
||||
(half*) c.data_ptr(),
|
||||
c.size(0), // m
|
||||
c.size(1), // n
|
||||
a.size(1), // k
|
||||
true,
|
||||
NULL,
|
||||
force_cuda
|
||||
);
|
||||
}
|
||||
|
||||
// Bindings
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
|
||||
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
|
||||
}
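
For reference, a hedged Python sketch of how these bindings are driven (the server's real wrapper is in exllamav2.py further below; the tensors and the handle `b` here are placeholders):

# Sketch only; assumes the extension was built from the setup.py below.
import torch
import exllamav2_kernels

# b is the opaque handle returned by exllamav2_kernels.make_q_matrix(...).
def gemm_example(a: torch.Tensor, b: int, n: int) -> torch.Tensor:
    c = torch.empty((a.shape[0], n), dtype=torch.float16, device=a.device)
    exllamav2_kernels.gemm_half_q_half(a, b, c, False)  # c = a @ dequant(b)
    return c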
17
server/exllamav2_kernels/setup.py
Normal file
@ -0,0 +1,17 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
name="exllamav2_kernels",
ext_modules=[
CUDAExtension(
name="exllamav2_kernels",
sources=[
"exllamav2_kernels/ext.cpp",
"exllamav2_kernels/cuda/q_matrix.cu",
"exllamav2_kernels/cuda/q_gemm.cu",
],
)
],
cmdclass={"build_ext": BuildExtension},
)
1930
server/poetry.lock
generated
File diff suppressed because it is too large
@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "1.1.1"
version = "1.2.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

46
server/requirements_common.txt
Normal file
46
server/requirements_common.txt
Normal file
@ -0,0 +1,46 @@
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
47
server/requirements_cuda.txt
Normal file
47
server/requirements_cuda.txt
Normal file
@ -0,0 +1,47 @@
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
bitsandbytes==0.41.2.post2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
46
server/requirements_rocm.txt
Normal file
46
server/requirements_rocm.txt
Normal file
@ -0,0 +1,46 @@
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
@ -137,6 +137,17 @@ def download_weights(
|
||||
if not extension == ".safetensors" or not auto_convert:
|
||||
raise e
|
||||
|
||||
else:
|
||||
# Try to load as a local PEFT model
|
||||
try:
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
# Try to see if there are local pytorch weights
|
||||
try:
|
||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||
|
@ -26,9 +26,6 @@ from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import dropout_layer_norm
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -38,6 +35,12 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import layernorm_ops
|
||||
|
||||
|
||||
class LlamaConfig(PretrainedConfig):
|
||||
@ -120,7 +123,7 @@ class LlamaRMSNorm(nn.Module):
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states, residual
else:
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
@ -143,6 +146,22 @@ class LlamaRMSNorm(nn.Module):
res = hidden_states

return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use the vLLM RMSNorm kernel, which can be compiled for RoCm, instead of the Flash Attention one, which cannot.
if residual is not None:
hidden_states += residual
residual = hidden_states

out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")

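For reference, a minimal eager-mode RMSNorm equivalent to what both fused branches above compute (residual handling omitted; a sketch, not the shipped code):

import torch

def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps) * weight, with the variance accumulated in float32
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normed.to(weight.dtype)
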
def load_attention(config, prefix, weights):
|
||||
@ -204,9 +223,6 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
# self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# )
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
@ -262,8 +278,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
|
@ -26,11 +26,8 @@ from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import dropout_layer_norm
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
|
||||
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -39,8 +36,14 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
|
||||
if not HAS_FLASH_ATTN_V2:
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import layernorm_ops
|
||||
|
||||
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
|
||||
raise ImportError("Mistral model requires flash attn v2")
|
||||
|
||||
|
||||
@ -126,7 +129,7 @@ class MistralRMSNorm(nn.Module):
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
else:
|
||||
elif IS_CUDA_SYSTEM:
|
||||
# faster post attention rms norm
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
@ -149,6 +152,22 @@ class MistralRMSNorm(nn.Module):
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
elif IS_ROCM_SYSTEM:
|
||||
# We use the vLLM RMSNorm kernel, which can be compiled for RoCm, instead of the Flash Attention one, which cannot.
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
out = torch.empty_like(hidden_states)
|
||||
layernorm_ops.rms_norm(
|
||||
out,
|
||||
hidden_states,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out, residual
|
||||
else:
|
||||
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
@ -261,8 +280,7 @@ class MistralAttention(torch.nn.Module):
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
|
@ -135,8 +135,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
|
||||
|
@ -185,8 +185,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
@ -301,8 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
|
||||
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, :, 0].contiguous(),
|
||||
|
@ -55,8 +55,12 @@ from text_generation_server.utils.layers import (
|
||||
PositionRotaryEmbedding,
|
||||
FastLinear,
|
||||
)
|
||||
import dropout_layer_norm
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import layernorm_ops
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
|
||||
@ -370,7 +374,7 @@ class IdeficsRMSNorm(nn.Module):
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
else:
|
||||
elif IS_CUDA_SYSTEM:
|
||||
# faster post attention rms norm
|
||||
unwrap = False
|
||||
if len(hidden_states.shape) > 2:
|
||||
@ -402,6 +406,32 @@ class IdeficsRMSNorm(nn.Module):
|
||||
normed_hidden_states = normed_hidden_states.view(*shape)
|
||||
|
||||
return normed_hidden_states
|
||||
elif IS_ROCM_SYSTEM:
|
||||
# We use the vLLM RMSNorm kernel, which can be compiled for RoCm, instead of the Flash Attention one, which cannot.
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
unwrap = False
|
||||
if len(hidden_states.shape) > 2:
|
||||
unwrap = True
|
||||
shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
||||
|
||||
out = torch.empty_like(hidden_states)
|
||||
layernorm_ops.rms_norm(
|
||||
out,
|
||||
hidden_states,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
|
||||
if unwrap:
|
||||
out = out.view(*shape)
|
||||
|
||||
return out
|
||||
else:
|
||||
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
||||
|
||||
|
||||
# this was adapted from LlamaMLP
|
||||
@ -581,15 +611,12 @@ class IdeficsAttention(nn.Module):
|
||||
position_ids.view(-1), max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
shape = query_states.shape
|
||||
query_states = self.rotary_emb(
|
||||
query_states.view(-1, *shape[2:]), cos, sin
|
||||
).view(shape)
|
||||
query_shape = query_states.shape
|
||||
key_shape = key_states.shape
|
||||
self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin)
|
||||
|
||||
shape = key_states.shape
|
||||
key_states = self.rotary_emb(
|
||||
key_states.reshape(-1, *shape[2:]), cos, sin
|
||||
).view(shape)
|
||||
query_states = query_states.view(query_shape)
|
||||
key_states = key_states.view(key_shape)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
@ -583,7 +583,7 @@ class IdeficsCausalLM(Model):
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
@ -3,6 +3,8 @@ import torch
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
|
||||
@ -15,7 +17,8 @@ is_sm8x = major == 8 and minor >= 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
|
||||
HAS_FLASH_ATTN = False
|
||||
HAS_FLASH_ATTN_V2 = False
|
||||
HAS_FLASH_ATTN_V2_CUDA = False
|
||||
HAS_FLASH_ATTN_V2_ROCM = False
|
||||
try:
|
||||
try:
|
||||
import flash_attn_2_cuda
|
||||
@ -30,7 +33,8 @@ try:
|
||||
f"GPU with CUDA capability {major} {minor} is not supported for "
|
||||
"Flash Attention V2"
|
||||
)
|
||||
HAS_FLASH_ATTN_V2 = True
|
||||
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
|
||||
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
|
||||
except ImportError as e:
|
||||
try:
|
||||
import flash_attn_cuda
|
||||
@ -41,10 +45,17 @@ except ImportError as e:
|
||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||
) from e
|
||||
|
||||
if not (is_sm75 or is_sm8x or is_sm90):
|
||||
if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
|
||||
raise ImportError(
|
||||
f"GPU with CUDA capability {major} {minor} is not supported"
|
||||
) from e
|
||||
elif IS_ROCM_SYSTEM:
|
||||
for idx in range(torch.cuda.device_count()):
|
||||
if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
|
||||
raise ImportError(
|
||||
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
||||
)
|
||||
|
||||
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
||||
HAS_FLASH_ATTN = True
|
||||
|
||||
@ -59,7 +70,7 @@ def attention(
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
):
|
||||
if HAS_FLASH_ATTN_V2:
|
||||
if HAS_FLASH_ATTN_V2_CUDA:
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
@ -78,8 +89,28 @@ def attention(
|
||||
False,
|
||||
None,
|
||||
)
|
||||
elif HAS_FLASH_ATTN_V2_ROCM:
|
||||
if window_size_left != -1:
|
||||
raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
# RoCm flash API does not take the window_size_left and window_size_right arguments.
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
elif HAS_FLASH_ATTN:
|
||||
if window_size_left != -1:
|
||||
raise NotImplementedError(
|
||||
"window_size_left is only available with flash attn v2"
|
||||
|
191
server/text_generation_server/utils/gptq/exllamav2.py
Normal file
191
server/text_generation_server/utils/gptq/exllamav2.py
Normal file
@ -0,0 +1,191 @@
|
||||
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||
except ImportError:
|
||||
logger.error('exllamav2_kernels not installed.')
|
||||
raise
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device="meta")
|
||||
|
||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
output_shape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
|
||||
gemm_half_q_half(x, q_handle, output, force_cuda)
|
||||
return output.view(output_shape)
|
||||
|
||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||
"""
|
||||
Create Q matrix
|
||||
"""
|
||||
# EXL2
|
||||
# won't work at the moment because the tensors are not the same.
|
||||
if "q_weight" in w:
|
||||
w["q_scale_max"] /= 256
|
||||
w["q_perm"] = w["q_perm"].short()
|
||||
w["q_invperm"] = w["q_invperm"].short()
|
||||
return make_q_matrix(w["q_weight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
w["q_scale"],
|
||||
w["q_scale_max"],
|
||||
w["q_groups"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
temp_dq)
|
||||
# GPTQ
|
||||
elif "qweight" in w:
|
||||
if w["scales"].dtype == torch.float:
|
||||
w["scales"] = w["scales"].half()
|
||||
|
||||
# GPTQ with g_idx (act_order)
|
||||
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
|
||||
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
|
||||
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||
return make_q_matrix(w["qweight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
w["g_idx"].cpu(),
|
||||
temp_dq)
|
||||
# GPTQ without g_idx
|
||||
else:
|
||||
return make_q_matrix(w["qweight"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
none_tensor,
|
||||
temp_dq)
|
||||
|
||||
DEVICE = None
|
||||
FIXED_BYTES = 0
|
||||
LAYERS = []
|
||||
|
||||
|
||||
def set_device(device):
|
||||
global DEVICE
|
||||
DEVICE = device
|
||||
|
||||
|
||||
def create_exllama_buffers():
|
||||
global FIXED_BYTES, LAYERS, DEVICE
|
||||
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
|
||||
|
||||
for layer in LAYERS:
|
||||
layer.post_init(temp_dq)
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "exllamav2"
|
||||
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
# def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
super().__init__()
|
||||
if bits != 4:
|
||||
raise ValueError(
|
||||
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
|
||||
self.q_handle = None
|
||||
self.q_tensors = None
|
||||
self.bits = bits
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
self.infeatures = qweight.shape[0] // self.bits * 32
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.padding = - self.outfeatures % 32
|
||||
self.outfeatures = self.outfeatures + self.padding
|
||||
|
||||
self.device = qweight.device
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx
|
||||
self.bias = bias if bias is not None else None
|
||||
self.group_size = groupsize
|
||||
|
||||
infeatures = self.infeatures
|
||||
outfeatures = self.outfeatures
|
||||
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
|
||||
assert infeatures % self.group_size == 0
|
||||
assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits)
|
||||
assert scales.shape == (infeatures // self.group_size, outfeatures)
|
||||
assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}"
|
||||
|
||||
global FIXED_BYTES, LAYERS
|
||||
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
|
||||
LAYERS.append(self)
|
||||
|
||||
def post_init(self, temp_dq):
|
||||
assert self.qweight.device.type == "cuda"
|
||||
assert self.qweight.device.index is not None
|
||||
self.q_tensors = {
|
||||
"qweight":self.qweight,
|
||||
"qzeros":self.qzeros,
|
||||
"scales":self.scales,
|
||||
"g_idx":self.g_idx
|
||||
}
|
||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||
self.q_handle = ext_make_q_matrix(
|
||||
self.q_tensors, temp_dq
|
||||
)
|
||||
|
||||
def forward(self, x, force_cuda = False):
|
||||
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||
|
||||
if self.bias is not None:
|
||||
output.add_(self.bias)
|
||||
return output
|
||||
|
||||
def temp_dq_size(self):
|
||||
return self.infeatures * self.outfeatures * 2 + 128
|
||||
|
||||
def temp_fwd_size(self, max_input_len, max_batch_size):
|
||||
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
||||
|
||||
def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
|
||||
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
||||
|
||||
|
||||
class ExLlamaV2DeviceTensors:
|
||||
|
||||
device_idx: int
|
||||
scratch_bytes: int
|
||||
scratch_idx: int
|
||||
scratch: torch.tensor = None
|
||||
|
||||
def __init__(self, device, scratch_bytes):
|
||||
self.device = device
|
||||
self.scratch_bytes = scratch_bytes
|
||||
|
||||
def prepare(self):
|
||||
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device)
|
||||
|
||||
def get_scratch_slice(self, size_bytes):
|
||||
|
||||
if self.scratch is None: self.prepare()
|
||||
|
||||
size_bytes = ((size_bytes + 127) // 128) * 128
|
||||
size_half = size_bytes // 2
|
||||
scratch_slice = self.scratch.narrow(0, 0, size_half)
|
||||
return scratch_slice
|
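For reference, a hedged sketch of the intended lifecycle of the classes above (argument values such as groupsize=128 are placeholders; the real wiring happens in layers.py):

# Sketch only: qweight/qzeros/scales/g_idx/bias are assumed to come from a GPTQ checkpoint.
def build_layer(qweight, qzeros, scales, g_idx, bias):
    layer = QuantLinear(qweight, qzeros, scales, g_idx, bias, bits=4, groupsize=128)
    set_device(qweight.device)      # remember which GPU owns the shared scratch
    create_exllama_buffers()        # allocate the scratch and run post_init() on registered layers
    return layer

# y = build_layer(...).forward(x)  # x: (tokens, infeatures) float16 on the same GPU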
4
server/text_generation_server/utils/import_utils.py
Normal file
@ -0,0 +1,4 @@
import torch

IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
@ -12,14 +12,13 @@ HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params, Params4bit
|
||||
|
||||
except ImportError:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
|
||||
HAS_AWQ = True
|
||||
try:
|
||||
@ -31,15 +30,31 @@ try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
except Exception:
|
||||
major = 1
|
||||
|
||||
HAS_EXLLAMA = False
|
||||
CAN_EXLLAMA = major >= 8
|
||||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
||||
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
||||
logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding")
|
||||
V2 = False
|
||||
|
||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||
HAS_EXLLAMA = False
|
||||
elif CAN_EXLLAMA:
|
||||
try:
|
||||
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
|
||||
if V2:
|
||||
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
HAS_EXLLAMA = "2"
|
||||
else:
|
||||
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
HAS_EXLLAMA = "1"
|
||||
|
||||
HAS_EXLLAMA = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -308,7 +323,7 @@ def get_linear(weight, bias, quantize):
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
else:
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
@ -509,11 +524,14 @@ class TensorParallelEmbedding(nn.Module):
|
||||
|
||||
|
||||
try:
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
else:
|
||||
dropout_layer_norm = None
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
@ -545,14 +563,16 @@ try:
|
||||
residual = hidden_states
|
||||
|
||||
return normed_hidden_states, residual
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
if IS_CUDA_SYSTEM:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
import rotary_emb
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
def _create_inv_freq(dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
@ -581,6 +601,37 @@ try:
self.scaling_factor = scaling_factor
self.dynamic_args = None

def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
# Such control flow may add some overhead.
if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim : 2 * rotary_dim]

rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim : 2 * rotary_dim]

rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
# NOTE: On RoCm systems, we use a RoPE implementation adapted from vLLM, which launches a single kernel for both query and key, contrary to the flash-attn implementation used on NVIDIA systems.
# When compiling the flash-attn rotary kernel on RoCm, hipcc appears unable to unroll loops, resulting in even slower inference than eager mode: https://github.com/pytorch/pytorch/issues/113773

head_size = query.shape[-1]

# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(
query,
key,
head_size,
cos,
sin,
True
)
else:
raise ValueError("Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")

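For reference, a minimal eager-mode rotation equivalent to what both kernel paths above apply to the first 2 * rotary_dim features of query and key (sketch only; neox-style non-interleaved halves assumed):

import torch

def apply_rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    rotary_dim = cos.shape[-1]
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim : 2 * rotary_dim]
    rotated1 = x1 * cos - x2 * sin
    rotated2 = x1 * sin + x2 * cos
    x[..., :rotary_dim] = rotated1
    x[..., rotary_dim : 2 * rotary_dim] = rotated2
    return x
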
@classmethod
|
||||
def static(cls, config, dim, base, device):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
@ -683,21 +734,19 @@ try:
|
||||
"""
|
||||
Return cos and sin for the asked position ids
|
||||
"""
|
||||
if IS_ROCM_SYSTEM:
|
||||
# For RoCm, we always use float cos/sin to avoid a cast.
|
||||
# For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of the same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
|
||||
# But it later casts cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
|
||||
dtype = torch.float32
|
||||
|
||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
||||
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
rotary_dim = cos.shape[-1]
|
||||
x1 = x[..., :rotary_dim]
|
||||
x2 = x[..., rotary_dim : 2 * rotary_dim]
|
||||
|
||||
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
|
||||
return x
|
||||
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
|
@ -38,7 +38,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
|
||||
os.makedirs(model_id, exist_ok=True)
|
||||
cache_dir = model_id
|
||||
logger.info(f"Saving the newly created merged model to {cache_dir}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code)
|
||||
model.save_pretrained(cache_dir, safe_serialization=True)
|
||||
model.config.save_pretrained(cache_dir)
|
||||
tokenizer.save_pretrained(cache_dir)
|
||||
|
@ -278,23 +278,13 @@ class Weights:
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
logger.info("Using exllama kernels")
|
||||
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
if use_exllama:
|
||||
if groupsize >= 0:
|
||||
# Exllama reorders the weights in advance and the activations on the fly, thus
|
||||
# the scales and zero-points do not need to be reordered.
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
# For tp > 1, at this point we know we do not use act-order
|
||||
if self.process_group.size() == 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
else:
|
||||
g_idx = None
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0)
|
||||
g_idx = g_idx - g_idx[0]
|
||||
else:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.