diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi new file mode 100644 index 000000000..e59814c30 --- /dev/null +++ b/Dockerfile_gaudi @@ -0,0 +1,105 @@ +# Rust builder +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef AS planner +COPY Cargo.lock Cargo.lock +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY backends backends +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3.11-dev +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --profile release-opt --recipe-path recipe.json + +ARG GIT_SHA +ARG DOCKER_LABEL + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY backends backends +COPY launcher launcher +RUN cargo build --profile release-opt + +# Text Generation Inference base image +FROM vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest AS base + +ENV ATTENTION=default +ENV PREFIX_CACHING=0 +ENV PREFILL_CHUNKING=0 + +# Text Generation Inference base env +ENV HF_HOME=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it +RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ + dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb + +WORKDIR /usr/src + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + libssl-dev \ + ca-certificates \ + make \ + curl \ + git \ + python3.11-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install --no-deps -r requirements.txt && \ + bash ./dill-0.3.8-patch.sh && \ + pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0 && \ + BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ + pip install . --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher + + +# AWS Sagemaker compatible image +FROM base AS sagemaker + +COPY sagemaker-entrypoint.sh entrypoint.sh +RUN chmod +x entrypoint.sh + +ENTRYPOINT ["./entrypoint.sh"] + +# Final image +FROM base + +COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh +RUN chmod +x /tgi-entrypoint.sh + +ENTRYPOINT ["/tgi-entrypoint.sh"] +CMD ["--json-output"] \ No newline at end of file diff --git a/backends/gaudi/server/.gitignore b/backends/gaudi/server/.gitignore new file mode 100644 index 000000000..576746eec --- /dev/null +++ b/backends/gaudi/server/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +text_generation_server/__pycache__/ +text_generation_server/pb/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +transformers +safetensors +flash-attention/ +flash-attention-v2/ +vllm/ +llm-awq/ +eetq/ +mamba/ diff --git a/backends/gaudi/server/Makefile b/backends/gaudi/server/Makefile new file mode 100644 index 000000000..f01897e59 --- /dev/null +++ b/backends/gaudi/server/Makefile @@ -0,0 +1,36 @@ +include Makefile-flash-att +include Makefile-flash-att-v2 +include Makefile-vllm +include Makefile-awq +include Makefile-eetq +include Makefile-selective-scan + +unit-tests: + pytest -s -vv -m "not private" tests + +gen-server: + # Compile protos + pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir + mkdir text_generation_server/pb || true + python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \ + --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto + find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; + touch text_generation_server/pb/__init__.py + +install: gen-server + pip install pip --upgrade + pip install --no-deps -r requirements.txt + pip install -e "." + +run-dev: + SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded + +install-poetry: + curl -sSL https://install.python-poetry.org | python3 - + +update-lock: + rm poetry.lock + poetry lock --no-update + +export-requirements: + poetry export -o requirements.txt --without-hashes diff --git a/backends/gaudi/server/Makefile-awq b/backends/gaudi/server/Makefile-awq new file mode 100644 index 000000000..4e074a133 --- /dev/null +++ b/backends/gaudi/server/Makefile-awq @@ -0,0 +1,15 @@ +# Fork that adds only the correct stream to this kernel in order +# to make cuda graphs work. +awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4 + +awq: + rm -rf llm-awq + git clone https://github.com/huggingface/llm-awq + +build-awq: awq + cd llm-awq/ && git fetch && git checkout $(awq_commit) + cd llm-awq/awq/kernels && python setup.py build + +install-awq: build-awq + pip uninstall awq_inference_engine -y || true + cd llm-awq/awq/kernels && python setup.py install diff --git a/backends/gaudi/server/Makefile-eetq b/backends/gaudi/server/Makefile-eetq new file mode 100644 index 000000000..726e47b57 --- /dev/null +++ b/backends/gaudi/server/Makefile-eetq @@ -0,0 +1,13 @@ +eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0 + +eetq: + # Clone eetq + pip install packaging + git clone https://github.com/NetEase-FuXi/EETQ.git eetq + +build-eetq: eetq + cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive + cd eetq && python setup.py build + +install-eetq: build-eetq + cd eetq && python setup.py install diff --git a/backends/gaudi/server/Makefile-exllamav2 b/backends/gaudi/server/Makefile-exllamav2 new file mode 100644 index 000000000..0d4cc385a --- /dev/null +++ b/backends/gaudi/server/Makefile-exllamav2 @@ -0,0 +1,12 @@ +exllamav2_commit := v0.1.8 + +build-exllamav2: + git clone https://github.com/turboderp/exllamav2.git exllamav2 && \ + cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build + +install-exllamav2: build-exllamav2 + cd exllamav2/ && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install diff --git a/backends/gaudi/server/Makefile-fbgemm b/backends/gaudi/server/Makefile-fbgemm new file mode 100644 index 000000000..3b8061a1f --- /dev/null +++ b/backends/gaudi/server/Makefile-fbgemm @@ -0,0 +1,15 @@ +fbgemm_commit := v0.8.0 + +build-fbgemm: + @if [ ! -d "fbgemm" ]; then \ + git clone https://github.com/pytorch/FBGEMM.git fbgemm; \ + fi + cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \ + git submodule update --init --recursive && \ + cd fbgemm_gpu && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build + +install-fbgemm: build-fbgemm + cd fbgemm/fbgemm_gpu && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install diff --git a/backends/gaudi/server/Makefile-flash-att b/backends/gaudi/server/Makefile-flash-att new file mode 100644 index 000000000..29e75bc48 --- /dev/null +++ b/backends/gaudi/server/Makefile-flash-att @@ -0,0 +1,12 @@ +flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec + +build-flash-attention: + if [ ! -d 'flash-attention' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/HazyResearch/flash-attention.git; \ + fi + cd flash-attention && git fetch && git checkout $(flash_att_commit) && \ + MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build + +install-flash-attention: build-flash-attention + cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install diff --git a/backends/gaudi/server/Makefile-flash-att-v2 b/backends/gaudi/server/Makefile-flash-att-v2 new file mode 100644 index 000000000..a9cdf7822 --- /dev/null +++ b/backends/gaudi/server/Makefile-flash-att-v2 @@ -0,0 +1,21 @@ +flash_att_v2_commit_cuda := v2.6.1 +flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 + +build-flash-attention-v2-cuda: + pip install -U packaging wheel + pip install flash-attn==$(flash_att_v2_commit_cuda) + +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + echo "Flash v2 installed" + +build-flash-attention-v2-rocm: + if [ ! -d 'flash-attention-v2' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \ + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \ + git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + fi + +install-flash-attention-v2-rocm: build-flash-attention-v2-rocm + cd flash-attention-v2 && \ + GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install diff --git a/backends/gaudi/server/Makefile-flashinfer b/backends/gaudi/server/Makefile-flashinfer new file mode 100644 index 000000000..f0a27622a --- /dev/null +++ b/backends/gaudi/server/Makefile-flashinfer @@ -0,0 +1,2 @@ +install-flashinfer: + pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 diff --git a/backends/gaudi/server/Makefile-lorax-punica b/backends/gaudi/server/Makefile-lorax-punica new file mode 100644 index 000000000..72f06f763 --- /dev/null +++ b/backends/gaudi/server/Makefile-lorax-punica @@ -0,0 +1,12 @@ +lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc + +build-lorax-punica: + if [ ! -d 'lorax-punica' ]; then \ + git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ + fi + cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) + cd lorax-punica && git submodule update --init --recursive + cd lorax-punica/server/punica_kernels && python setup.py build + +install-lorax-punica: build-lorax-punica + cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/backends/gaudi/server/Makefile-selective-scan b/backends/gaudi/server/Makefile-selective-scan new file mode 100644 index 000000000..b93b517d6 --- /dev/null +++ b/backends/gaudi/server/Makefile-selective-scan @@ -0,0 +1,28 @@ +selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 + +causal-conv1d: + rm -rf causal-conv1d + git clone https://github.com/Dao-AILab/causal-conv1d.git + +build-causal-conv1d: causal-conv1d + cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag + cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build + +install-causal-conv1d: build-causal-conv1d + pip uninstall causal-conv1d -y || true + cd causal-conv1d/ && pip install . + +# selective-scan dependends on causal-conv1d +selective-scan: + rm -rf mamba + git clone https://github.com/state-spaces/mamba.git mamba + +build-selective-scan: selective-scan + cd mamba/ && git fetch && git checkout $(selective_scan_commit) + cd mamba && python setup.py build + +install-selective-scan: install-causal-conv1d build-selective-scan + pip uninstall selective-scan-cuda -y || true + cd mamba && pip install . + +build-all: build-causal-conv1d build-selective-scan diff --git a/backends/gaudi/server/Makefile-vllm b/backends/gaudi/server/Makefile-vllm new file mode 100644 index 000000000..18dcc4a0c --- /dev/null +++ b/backends/gaudi/server/Makefile-vllm @@ -0,0 +1,23 @@ +commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b +commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247 +build-vllm-cuda: + if [ ! -d 'vllm' ]; then \ + pip install -U ninja packaging --no-cache-dir && \ + git clone https://github.com/Narsil/vllm.git vllm; \ + fi + cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build + +install-vllm-cuda: build-vllm-cuda + cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e . + +build-vllm-rocm: + if [ ! -d 'vllm' ]; then \ + pip install -U ninja packaging --no-cache-dir && \ + git clone https://github.com/mht-sharma/vllm.git vllm; \ + fi + cd vllm && git fetch && git checkout $(commit_rocm) && \ + PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build + +install-vllm-rocm: build-vllm-rocm + cd vllm && git fetch && git checkout $(commit_rocm) && \ + PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e . diff --git a/backends/gaudi/server/README.md b/backends/gaudi/server/README.md new file mode 100644 index 000000000..b8208f9ea --- /dev/null +++ b/backends/gaudi/server/README.md @@ -0,0 +1,15 @@ +# Text Generation Inference Python gRPC Server + +A Python gRPC server for Text Generation Inference + +## Install + +```shell +make install +``` + +## Run + +```shell +make run-dev +``` diff --git a/backends/gaudi/server/custom_kernels/custom_kernels/fused_attention_cuda.cu b/backends/gaudi/server/custom_kernels/custom_kernels/fused_attention_cuda.cu new file mode 100644 index 000000000..60f9f0286 --- /dev/null +++ b/backends/gaudi/server/custom_kernels/custom_kernels/fused_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor query, + const at::Tensor key, + const at::Tensor value, + const std::optional> layer_past, + const at::Tensor attention_mask, + const std::optional head_mask, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + auto query_layer = query; + auto key_layer = key; + auto value_layer = value; + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 2); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + const auto batch_size = query_layer.size(0); + const auto q_length = query_layer.size(2); + const auto attn_head_size = query_layer.size(3); + const auto batch_size_times_num_heads = batch_size * num_heads; + const auto kv_length = key_layer.size(2); + + const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size}); + auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2); + auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}); + + auto query_scaled = query_view * inv_norm_factor; + auto attention_scores = at::bmm(query_scaled, key_view); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) + * - Memory size: 40 GB + * - L2 Cache size: 40960 KB (shared across all SMs) + * - L1/Shared memory size: 192 KB (shared across all threads within a SM) + * - Max Threads / SM: 2048 + * - Max Thread Blocks / SM: 32 + */ + + /* + * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block + * with multiple threads as we need to `sync_threads` to run exponential sum. + * We maximise the usage of threads within a single block + */ + // TODO @thomasw21 figure out everything warp related: + // - why do they have to be power of 2 + // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 + const auto MAX_THREADS_PER_SM = 1024; + // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` + const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; + // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` + const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; + const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; + const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; + + const dim3 gridDim(num_blocks); // Number of blocks that run + const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block + const int shared_mem_forward = rows_per_block * 2 * sizeof(float); + + // 192 * 2 ** 10 + // const auto MAX_L1_MEMORY = 196608; + // const auto MAX_SMs = 108; + // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); + // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); + // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher."); + + forward_masked_softmax_kernel<<>>( + attention_scores_2d.packed_accessor32(), + attention_mask_2d.packed_accessor32(), + attention_probs.packed_accessor32(), + effective_kv_length, + blockDim, + rows_per_block, + kv_length, + batch_size_times_num_heads * q_length + ); + }); + attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); + } else { + // Pytorch C++ API + auto input_dtype = attention_scores.scalar_type(); + if (input_dtype == at::ScalarType::Float) { + attention_scores = attention_scores.to(at::ScalarType::Float); + }; + // TODO @thomasw21 Figure out how to get minimum value + auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); + attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); + } + + auto context_layer = attention_probs.bmm(value_view); + + // `_merge_heads` + context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size}); + context_layer = context_layer.permute({0, 2, 1, 3}); + context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads}); + + return std::make_tuple(context_layer, present, attention_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &forward, + "GPT-Neox attention mechanism forward (CUDA)" + ); +} diff --git a/backends/gaudi/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu b/backends/gaudi/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu new file mode 100644 index 000000000..8206c3e03 --- /dev/null +++ b/backends/gaudi/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor fused_qkv, + const std::optional> layer_past, + const at::Tensor alibi, + const at::Tensor attention_mask, + const std::optional head_mask, + const float beta, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + const auto batch_size = fused_qkv.size(0); + const auto q_length = fused_qkv.size(1); + const auto three_times_hidden_size = fused_qkv.size(2); + const auto head_dim = three_times_hidden_size / (3 * num_heads); + const auto batch_size_times_num_heads = batch_size * num_heads; + + // `split_heads` + const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim}); + const auto tensor_list = fused_qkv_view.split(head_dim, -1); + const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length}); + auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 1); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + const auto kv_length = key_layer.size(2); + + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) + * - Memory size: 40 GB + * - L2 Cache size: 40960 KB (shared across all SMs) + * - L1/Shared memory size: 192 KB (shared across all threads within a SM) + * - Max Threads / SM: 2048 + * - Max Thread Blocks / SM: 32 + */ + + /* + * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block + * with multiple threads as we need to `sync_threads` to run exponential sum. + * We maximise the usage of threads within a single block + */ + // TODO @thomasw21 figure out everything warp related: + // - why do they have to be power of 2 + // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 + const auto MAX_THREADS_PER_SM = 1024; + // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` + const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; + // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` + const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; + const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; + const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; + + const dim3 gridDim(num_blocks); // Number of blocks that run + const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block + const int shared_mem_forward = rows_per_block * 2 * sizeof(float); + + // 192 * 2 ** 10 + // const auto MAX_L1_MEMORY = 196608; + // const auto MAX_SMs = 108; + // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); + // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); + // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher."); + + forward_masked_softmax_kernel<<>>( + attention_scores_2d.packed_accessor32(), + attention_mask_2d.packed_accessor32(), + attention_probs.packed_accessor32(), + effective_kv_length, + blockDim, + rows_per_block, + kv_length, + batch_size_times_num_heads * q_length + ); + }); + attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); + } else { + // Pytorch C++ API + auto input_dtype = attention_scores.scalar_type(); + if (input_dtype == at::ScalarType::Float) { + attention_scores = attention_scores.to(at::ScalarType::Float); + }; + // TODO @thomasw21 Figure out how to get minimum value + auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); + attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); + } + + auto context_layer = attention_probs.bmm(value_layer); + + // `_merge_heads` + context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim}); + context_layer = context_layer.permute({0, 2, 1, 3}); + context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3}); + + return std::make_tuple(context_layer, present, attention_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &forward, + "Bloom attention mechanism forward (CUDA)" + ); +} diff --git a/backends/gaudi/server/custom_kernels/setup.py b/backends/gaudi/server/custom_kernels/setup.py new file mode 100644 index 000000000..e0b83987c --- /dev/null +++ b/backends/gaudi/server/custom_kernels/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +extra_compile_args = ["-std=c++17"] + +setup( + name="custom_kernels", + ext_modules=[ + CUDAExtension( + name="custom_kernels.fused_bloom_attention_cuda", + sources=["custom_kernels/fused_bloom_attention_cuda.cu"], + extra_compile_args=extra_compile_args, + ), + CUDAExtension( + name="custom_kernels.fused_attention_cuda", + sources=["custom_kernels/fused_attention_cuda.cu"], + extra_compile_args=extra_compile_args, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/backends/gaudi/server/dill-0.3.7-patch.sh b/backends/gaudi/server/dill-0.3.7-patch.sh new file mode 100644 index 000000000..ad8c8be58 --- /dev/null +++ b/backends/gaudi/server/dill-0.3.7-patch.sh @@ -0,0 +1,91 @@ +#!/bin/bash +git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git +pushd dill +cat < dill-0.3.7.patch +diff --git a/dill/_dill.py b/dill/_dill.py +index d0cf543..f6eb662 100644 +--- a/dill/_dill.py ++++ b/dill/_dill.py +@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered + XRangeType = range + from types import MappingProxyType as DictProxyType, new_class + from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError +-import __main__ as _main_module ++class _LazyMainModule(object): ++ _module = None ++ @property ++ def module(self): ++ if self._module is None: ++ import __main__ as _m_module ++ self._module = _m_module ++ return self._module ++_main_module = _LazyMainModule() + import marshal + import gc + # import zlib +@@ -353,7 +361,7 @@ class Pickler(StockPickler): + _fmode = kwds.pop('fmode', None) + _recurse = kwds.pop('recurse', None) + StockPickler.__init__(self, file, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._diff_cache = {} + self._byref = settings['byref'] if _byref is None else _byref + self._strictio = False #_strictio +@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler): + settings = Pickler.settings + _ignore = kwds.pop('ignore', None) + StockUnpickler.__init__(self, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._ignore = settings['ignore'] if _ignore is None else _ignore + + def load(self): #NOTE: if settings change, need to update attributes + obj = StockUnpickler.load(self) +- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): ++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'): + if not self._ignore: + # point obj class to main + try: obj.__class__ = getattr(self._main, type(obj).__name__) +@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj): + logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) + logger.trace(pickler, "# D1") +- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): ++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__): + logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? + logger.trace(pickler, "# D3") +- elif '__name__' in obj and obj != _main_module.__dict__ \\ ++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\ + and type(obj['__name__']) is str \\ + and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): + logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj +diff --git a/dill/session.py b/dill/session.py +index 74234ab..1be8d89 100644 +--- a/dill/session.py ++++ b/dill/session.py +@@ -233,7 +233,7 @@ def dump_module( + protocol = settings['protocol'] + main = module + if main is None: +- main = _main_module ++ main = _main_module.module + elif isinstance(main, str): + main = _import_module(main) + if not isinstance(main, ModuleType): +@@ -501,7 +501,7 @@ def load_module( + pass + assert loaded is main + _restore_modules(unpickler, main) +- if main is _main_module or main is module: ++ if main is _main_module.module or main is module: + return None + else: + return main + +EOF +git apply dill-0.3.7.patch +python -m pip install . +popd +rm -fr dill diff --git a/backends/gaudi/server/dill-0.3.8-patch.sh b/backends/gaudi/server/dill-0.3.8-patch.sh new file mode 100644 index 000000000..da263960f --- /dev/null +++ b/backends/gaudi/server/dill-0.3.8-patch.sh @@ -0,0 +1,91 @@ +#!/bin/bash +git clone -b 0.3.8 https://github.com/uqfoundation/dill.git +pushd dill +cat < dill-0.3.8.patch +diff --git a/dill/_dill.py b/dill/_dill.py +index d42432f..1d251e6 100644 +--- a/dill/_dill.py ++++ b/dill/_dill.py +@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered + XRangeType = range + from types import MappingProxyType as DictProxyType, new_class + from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError +-import __main__ as _main_module ++class _LazyMainModule(object): ++ _module = None ++ @property ++ def module(self): ++ if self._module is None: ++ import __main__ as _m_module ++ self._module = _m_module ++ return self._module ++_main_module = _LazyMainModule() + import marshal + import gc + # import zlib +@@ -355,7 +363,7 @@ class Pickler(StockPickler): + _fmode = kwds.pop('fmode', None) + _recurse = kwds.pop('recurse', None) + StockPickler.__init__(self, file, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._diff_cache = {} + self._byref = settings['byref'] if _byref is None else _byref + self._strictio = False #_strictio +@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler): + settings = Pickler.settings + _ignore = kwds.pop('ignore', None) + StockUnpickler.__init__(self, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._ignore = settings['ignore'] if _ignore is None else _ignore + + def load(self): #NOTE: if settings change, need to update attributes + obj = StockUnpickler.load(self) +- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): ++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'): + if not self._ignore: + # point obj class to main + try: obj.__class__ = getattr(self._main, type(obj).__name__) +@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj): + logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) + logger.trace(pickler, "# D1") +- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): ++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__): + logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? + logger.trace(pickler, "# D3") +- elif '__name__' in obj and obj != _main_module.__dict__ \\ ++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\ + and type(obj['__name__']) is str \\ + and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): + logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj +diff --git a/dill/session.py b/dill/session.py +index e91068a..a921b43 100644 +--- a/dill/session.py ++++ b/dill/session.py +@@ -233,7 +233,7 @@ def dump_module( + protocol = settings['protocol'] + main = module + if main is None: +- main = _main_module ++ main = _main_module.module + elif isinstance(main, str): + main = _import_module(main) + if not isinstance(main, ModuleType): +@@ -501,7 +501,7 @@ def load_module( + pass + assert loaded is main + _restore_modules(unpickler, main) +- if main is _main_module or main is module: ++ if main is _main_module.module or main is module: + return None + else: + return main + +EOF +git apply dill-0.3.8.patch +python -m pip install . +popd +rm -fr dill \ No newline at end of file diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cu_compat.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cu_compat.cuh new file mode 100644 index 000000000..c5258813e --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cu_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cu new file mode 100644 index 000000000..ee2cbee23 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cu @@ -0,0 +1,71 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh new file mode 100644 index 000000000..afb60a012 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh @@ -0,0 +1,52 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu new file mode 100644 index 000000000..c25b0206b --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu @@ -0,0 +1,61 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "../util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh new file mode 100644 index 000000000..0364e38c4 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu new file mode 100644 index 000000000..1b0f79564 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu @@ -0,0 +1,256 @@ +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" +#include "../cu_compat.cuh" +#include "../cuda_buffers.cuh" +#if defined(USE_ROCM) +#include "../hip_compat.cuh" +#endif + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(__low2half(acc), __high2half(acc)); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + bool no_zero, + const cublasHandle_t handle +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); + +// const float alpha = 1.0f; +// const float beta = no_zero ? 1.0f : 0.0f; +// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, +// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh new file mode 100644 index 000000000..4c7a66692 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh @@ -0,0 +1,37 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "../tuning.h" + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + bool no_zero, + const cublasHandle_t handle +); + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu new file mode 100644 index 000000000..1f32e6b83 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu @@ -0,0 +1,220 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include "q4_matrix.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh new file mode 100644 index 000000000..49431dc95 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/exllama_ext.cpp b/backends/gaudi/server/exllama_kernels/exllama_kernels/exllama_ext.cpp new file mode 100644 index 000000000..f2df80e8f --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/exllama_ext.cpp @@ -0,0 +1,253 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "cuda_func/q4_matrix.cuh" +#include "cuda_func/q4_matmul.cuh" +#include "cuda_func/column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + false, + stream + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + false, + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/hip_compat.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/hip_compat.cuh new file mode 100644 index 000000000..f2a3dcada --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/hip_compat.cuh @@ -0,0 +1,52 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _hip_compat_cuh +#define _hip_compat_cuh + +// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6. +__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{ + _Float16_2{static_cast<_Float16>(1.0f), + static_cast<_Float16>(1.0f)} / x.data}; +} + +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf. +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_handle hipblasHandle_t +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream +#define rocblas_hgemm __compat_hipblasHgemm + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/matrix.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/matrix.cuh new file mode 100644 index 000000000..2fd5ab0b3 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/tuning.h b/backends/gaudi/server/exllama_kernels/exllama_kernels/tuning.h new file mode 100644 index 000000000..770ca46aa --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/tuning.h @@ -0,0 +1,13 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning +{ + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/backends/gaudi/server/exllama_kernels/exllama_kernels/util.cuh b/backends/gaudi/server/exllama_kernels/exllama_kernels/util.cuh new file mode 100644 index 000000000..7b3975732 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/exllama_kernels/util.cuh @@ -0,0 +1,33 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#if defined(USE_ROCM) +#define cudaUnspecified hipErrorUnknown +#else +#define cudaUnspecified cudaErrorApiFailureBase +#endif + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/backends/gaudi/server/exllama_kernels/setup.py b/backends/gaudi/server/exllama_kernels/setup.py new file mode 100644 index 000000000..cc307bf06 --- /dev/null +++ b/backends/gaudi/server/exllama_kernels/setup.py @@ -0,0 +1,32 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +extra_cuda_cflags = [] +extra_cflags = [] +if torch.version.hip: + extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] + extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] + +extra_compile_args = { + "cxx": extra_cflags, + "nvcc": extra_cuda_cflags, +} + +setup( + name="exllama_kernels", + ext_modules=[ + CUDAExtension( + name="exllama_kernels", + sources=[ + "exllama_kernels/exllama_ext.cpp", + "exllama_kernels/cuda_buffers.cu", + "exllama_kernels/cuda_func/column_remap.cu", + "exllama_kernels/cuda_func/q4_matmul.cu", + "exllama_kernels/cuda_func/q4_matrix.cu", + ], + extra_compile_args=extra_compile_args, + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/config.h b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/config.h new file mode 100644 index 000000000..32a1a37d4 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/config.h @@ -0,0 +1,15 @@ +#ifndef _config_h +#define _config_h + +#define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS + +#define QMODE_2BIT 1 +#define QMODE_3BIT 1 +#define QMODE_4BIT 1 +#define QMODE_5BIT 1 +#define QMODE_6BIT 0 +#define QMODE_8BIT 0 + + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h new file mode 100644 index 000000000..919703a89 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h @@ -0,0 +1,12 @@ +#ifndef _util_h +#define _util_h + +#define DBGS(__x) printf("%s\n", __x) +#define DBGI(__x) printf("%s: %i\n", #__x, __x) +#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) +#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF(__x) printf("%s: %f\n", #__x, __x) +#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) +#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh new file mode 100644 index 000000000..12684ff8b --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh @@ -0,0 +1,56 @@ +#ifndef _compat_cuh +#define _compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh new file mode 100644 index 000000000..a72bc7bc1 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh @@ -0,0 +1,121 @@ +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "quant/qdq_util.cuh" + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu new file mode 100644 index 000000000..5b99f1ba8 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu @@ -0,0 +1,220 @@ +#include "q_gemm.cuh" +#include "util.cuh" +#include "matrix_view.cuh" +#include "../config.h" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define GPTQ_BLOCK_KN_SIZE 128 +#define GPTQ_BLOCK_M_SIZE_MAX 8 +#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32) + +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) + +#define CLEAR_N_SIZE 256 + +#include "q_gemm_kernel.cuh" +#include "q_gemm_kernel_gptq.cuh" + +void gemm_half_q_half_cuda_part +( + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + bool clear, + const half* r_weights, + int r_weights_stride, + bool mul_r_weights +) +{ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (!b->is_gptq) + { + dim3 blockDim, gridDim; + blockDim.x = EXL2_BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE); + + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights); + + kernel<<>> + ( + a, + b->cuda_q_weight, + b->cuda_q_scale, + b->cuda_q_scale_max, + c, + size_m, + size_n, + size_k, + b->groups, + b->cuda_q_group_map, + b->cuda_q_perm, + b->rows_8, + b->rows_6, + b->rows_5, + b->rows_4, + b->rows_3, + b->rows_2, + clear, + r_weights, + r_weights_stride + ); + } + else + { + dim3 blockDim, gridDim; + blockDim.x = GPTQ_BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights); + +// DBGX((uint64_t) r_weights); +// if (r_weights) +// print_global_mem(r_weights, 1, 1, 1); +// DBGI(r_weights_stride); + + kernel<<>> + ( + a, + b->cuda_q_weight, + b->cuda_gptq_qzeros, + b->cuda_gptq_scales, + c, + size_m, + size_n, + size_k, + b->groups, + b->gptq_groupsize, + b->cuda_q_perm, + b->rows_4, + clear, + r_weights, + r_weights_stride + ); + } +} + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear, + half* temp_dq, + bool force_cuda, + const half* r_weights, + const int r_weights_stride, + bool mul_r_weights +) +{ + if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) + { + // Reconstruct FP16 matrix, then cuBLAS + + if (!temp_dq) temp_dq = b->temp_dq; + b->reconstruct(temp_dq); + + //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); + + const half alpha = __float2half(1.0f); + const half beta = clear ? __float2half(0.0f) : __float2half(1.0f); + cublasHgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha, temp_dq, size_n, + a, size_k, + &beta, c, size_n); + + //const float alpha = 1.0f; + //const float beta = clear ? 0.0f : 1.0f; + //cublasSgemmEx(cublas_handle, + // CUBLAS_OP_N, + // CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n); + + //const float alpha = 1.0f; + //const float beta = clear ? 0.0f : 1.0f; + //cublasGemmEx(cublas_handle, + // CUBLAS_OP_N, CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n, + // CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP); + } + else + { + // Quantized matmul + + int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX; + int max_chunks = size_m / block_m_size_max; + int last_chunk = max_chunks * block_m_size_max; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) + { + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights); + } + + if (last_chunk_size) + { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights); + } + } +} + +__global__ void clear_kernel +( + half* __restrict__ c, + const int size_m, + const int size_n +) +{ + int m = blockIdx.y; + int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8; + if (n >= size_n) return; + int4* c_ptr = (int4*)(c + m * size_n + n); + *c_ptr = {}; +} + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +) +{ +// dim3 blockDim, gridDim; +// blockDim.x = CLEAR_N_SIZE; +// blockDim.y = 1; +// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); +// gridDim.y = size_m; +// clear_kernel<<>>(c, size_m, size_n); +} diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh new file mode 100644 index 000000000..e49457f3a --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh @@ -0,0 +1,36 @@ +#ifndef _q_gemm_cuh +#define _q_gemm_cuh + +#include +#include +#include +#include +#include + +#include "q_matrix.cuh" + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear = false, + half* reconstruct = NULL, + bool force_cuda = false, + const half* r_weights = NULL, + const int r_weights_stride = 0, + bool mul_r_weights = false +); + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +); + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh new file mode 100644 index 000000000..9cd2ba015 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh @@ -0,0 +1,580 @@ +#include "compat.cuh" + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) +{ + // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 + + float result = {}; + #pragma unroll + for (int i = 0; i < 4; i++) + { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + + +typedef void (*fp_gemm_half_q_half_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const uint16_t*, + const uint16_t*, + const int, + const int, + const int, + const int, + const int, + const int, + const bool, + const half*, + const int +); + +template +__global__ void gemm_half_q_half_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const uint16_t* __restrict__ b_q_group_map, + const uint16_t* __restrict__ b_q_perm, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2, + const bool clear, + const half* r_weights, + const int r_weights_stride +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; + + int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k); + int n = offset_n + t * 4; + + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) + } + + // Preload block_a + + __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + half a0 = a_ptr[b_q_perm[offset_k + t]]; +// half a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Clear + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + //int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + +// if (offset_m == 0 && t == 0) +// DBGI2(offset_k, group); + + // Preload scales + + half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; + + //int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + int temp_k = offset_k; + for (int g = 0; temp_k < end_k; g++) + { + int qscales[4]; + b_q_scale_.item4(qscales, group + g, n); + qscales[0]++; + qscales[1]++; + qscales[2]++; + qscales[3]++; + half maxscale = b_q_scale_max[group + g]; + scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); + scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); + scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); + scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); + temp_k += b_q_group_map[temp_k * 2 + 1]; + } + + // a, b offset + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = EXL2_BLOCK_KN_SIZE; + + // Initial group + + int scales_idx = 0; + half qs_h0 = scales[scales_idx][0]; + half qs_h1 = scales[scales_idx][1]; + half qs_h2 = scales[scales_idx][2]; + half qs_h3 = scales[scales_idx][3]; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + // Column result + + half block_c[m_count][4] = {}; + + // Dequantize groups + + int k = offset_k; + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[2]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 2; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 16; + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[5]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[3] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[4] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n); + dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n); + dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n); + dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_4bit_8(load_int4[0].x, dq[0], size_n); + dequant_4bit_8(load_int4[0].y, dq[1], size_n); + dequant_4bit_8(load_int4[0].z, dq[2], size_n); + dequant_4bit_8(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_2bit_16(load_int4[0].x, dq[0], size_n); + dequant_2bit_16(load_int4[0].y, dq[1], size_n); + dequant_2bit_16(load_int4[0].z, dq[2], size_n); + dequant_2bit_16(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + + a_ptr += 16; + } + k += 16; + } + + // Accumulate column sums in c + + for (int m = 0; m < m_count; m++) + { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = __half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + + atomicAdd(out , result01); + atomicAdd(out + 1, result23); +// *out = result01; +// *(out + 1) = result23; + } +} + +template +struct map_m_count_exl2 { + static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) + { + #if EXL2_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights) +{ + if (!r_weights && !mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count); + return NULL; +} diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh new file mode 100644 index 000000000..f816fd9dd --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -0,0 +1,273 @@ +#include "compat.cuh" + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return result; +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int, + const uint16_t*, + const int, + const bool, + const half*, + const int +); + +template +__global__ void gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int groupsize, + const uint16_t* __restrict__ b_q_perm, + const int rows_4, + const bool clear, + const half* r_weights, + const int r_weights_stride +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE; + + int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) + } + + // Preload block_a + + __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = GPTQ_BLOCK_KN_SIZE; + + // Initial group + + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + +// __syncthreads(); + + // Column result + + half2 block_c[m_count][4] = {}; + + // Dequantize and multiply + + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0])); + half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1])); + half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2])); + half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3])); + half2 result01 = __halves2half2(result0, result1); + half2 result23 = __halves2half2(result2, result3); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = __half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +template +struct map_m_count_gptq { + static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count) + { + #if GPTQ_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights) +{ + if (!r_weights && !mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); + return NULL; +} diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu new file mode 100644 index 000000000..f7a91e29e --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -0,0 +1,650 @@ +#include "q_matrix.cuh" +#include "matrix_view.cuh" +#include "util.cuh" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define BLOCK_KN_SIZE 128 + +#define THREADS_X 32 +#define THREADS_Y 32 + +// Shuffle quantized data on load + +__global__ void shuffle_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2 +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } + while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; } + while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; } + while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } + while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } + while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +} + + +// QMatrix constructor + +QMatrix::QMatrix +( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + uint16_t* _q_group_map, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq +) : + device(_device), + height(_height), + width(_width), + groups(_groups), + temp_dq(_temp_dq) +{ + cudaSetDevice(device); + + failed = false; + + cuda_q_weight = _q_weight; + cuda_q_perm = _q_perm; + cuda_q_invperm = _q_invperm; + cuda_q_scale = _q_scale; + cuda_q_scale_max = _q_scale_max; + cuda_q_groups = _q_groups; + cuda_q_group_map = _q_group_map; + cuda_gptq_qzeros = _gptq_qzeros; + cuda_gptq_scales = _gptq_scales; + + is_gptq = (_gptq_qzeros != NULL); + + if (is_gptq) + { + gptq_groupsize = 1; + while (gptq_groupsize * groups < height) gptq_groupsize *= 2; + } + + // Create group map + + rows_8 = 0; + rows_6 = 0; + rows_5 = 0; + rows_4 = 0; + rows_3 = 0; + rows_2 = 0; + + if (!is_gptq) + { + uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); + cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); + + int row = 0; + for (int i = 0; i < groups; i++) + { + int bits = cpu_q_groups[i * 2]; + + int rows; + if (i < groups - 1) + { + int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; + rows = qrows * 32 / bits; + } + else rows = height - row; + + if (bits == 8) rows_8 += rows; + if (bits == 6) rows_6 += rows; + if (bits == 5) rows_5 += rows; + if (bits == 4) rows_4 += rows; + if (bits == 3) rows_3 += rows; + if (bits == 2) rows_2 += rows; + row += rows; + } + + free(cpu_q_groups); + + rows_6 += rows_8; + rows_5 += rows_6; + rows_4 += rows_5; + rows_3 += rows_4; + rows_2 += rows_3; + } + else + { + rows_4 = height; + rows_3 = height; + rows_2 = height; + + if (_gptq_g_idx) + { + if (!make_sequential(_gptq_g_idx)) + { + failed = true; + //printf("FAIL\n"); + return; + } + } + } + +// DBGI(rows_8); +// DBGI(rows_6); +// DBGI(rows_5); +// DBGI(rows_4); +// DBGI(rows_3); +// DBGI(rows_2); + + // Shuffle quantized data + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); +} + +QMatrix::~QMatrix() +{ +} + +// Reconstruct b[k,n] (GPTQ) + +__global__ void reconstruct_gptq_kernel +( + const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + //const uint16_t* __restrict__ b_q_groups, + const int size_k, + const int size_n, + const int groupsize, + const int groups, + half* __restrict__ b, + const int rows_4 +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + + +// Reconstruct b[k,n] + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + const uint16_t* __restrict__ b_q_group_map, + const int size_k, + const int size_n, + //const int groupsize, + const int groups, + half* __restrict__ b, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2 +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x; + + // Preload remapping table + + int t = threadIdx.x; + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + + // Column + + int n = offset_n + t; + if (n >= size_n) return; + + // Find initial group + + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + half2 qs_h2 = __halves2half2(qs_h, qs_h); + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int k = offset_k; + int lk = 0; + + __syncthreads(); + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + dequant_8bit_8(q_0, q_1, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 2; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + dequant_6bit_16(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + uint32_t q_3 = *b_ptr; b_ptr += size_n; + uint32_t q_4 = *b_ptr; b_ptr += size_n; + dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + dequant_4bit_8(q_0, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + dequant_3bit_32(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + dequant_2bit_16(q_0, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 16; + } +} + +void QMatrix::reconstruct(half* out) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (!is_gptq) + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_kernel<<>> + ( + cuda_q_weight, + cuda_q_perm, + cuda_q_scale, + cuda_q_scale_max, + cuda_q_group_map, + height, + width, + //groupsize, + groups, + out, + rows_8, + rows_6, + rows_5, + rows_4, + rows_3, + rows_2 + ); + } + else + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4); + reconstruct_gptq_kernel<<>> + ( + cuda_q_weight, + cuda_q_perm, + cuda_gptq_qzeros, + cuda_gptq_scales, + //const uint16_t* __restrict__ b_q_groups, + height, + width, + gptq_groupsize, + groups, + out, + rows_4 + ); + } +} + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint16_t* __restrict__ q_perm, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + + int w_new2_row = blockIdx.y; + + int q_perm_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) +{ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + uint32_t* cuda_new_qweight = NULL; + cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + if (err != cudaSuccess) { + cudaError_t cuda_status = cudaGetLastError(); // Clear error + return false; + } + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Reduce to uint16_t + + uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map; + uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv; + for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row]; + for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row]; + + // Move to CUDA + + cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 8; + + make_sequential_kernel<<>> + ( + cuda_q_weight, + cuda_new_qweight, + cuda_q_perm, + height / 8, + width + ); + + // Replace qweights + + cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); + + return true; +} diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh new file mode 100644 index 000000000..d36b8d667 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh @@ -0,0 +1,75 @@ +#ifndef _q_matrix_cuh +#define _q_matrix_cuh + +#include +#include +#include +#include + +#define MAX_SUPERGROUPS 16 + +class QMatrix +{ +public: + + int device; + bool is_gptq; + + int height; + int width; + int groups; + int gptq_groupsize; + + int rows_8; + int rows_6; + int rows_5; + int rows_4; + int rows_3; + int rows_2; + + uint32_t* cuda_q_weight = NULL; + uint16_t* cuda_q_perm = NULL; + uint16_t* cuda_q_invperm = NULL; + uint32_t* cuda_q_scale = NULL; + half* cuda_q_scale_max = NULL; + uint16_t* cuda_q_groups = NULL; + uint16_t* cuda_q_group_map = NULL; + uint32_t* cuda_gptq_qzeros = NULL; + half* cuda_gptq_scales = NULL; + + half* temp_dq; + + bool failed; + + QMatrix + ( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + uint16_t* _q_group_map, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq + ); + + ~QMatrix(); + + void reconstruct(half* out); + bool make_sequential(const uint32_t* cpu_g_idx); + +private: + +}; + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh new file mode 100644 index 000000000..90c18a0c6 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh @@ -0,0 +1,103 @@ +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_2BIT == 1 + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 2.0f); + const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z4 = __halves2half2(z4_, z4_); + const half2 z16 = __halves2half2(z16_, z16_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +#else + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh new file mode 100644 index 000000000..101173763 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh @@ -0,0 +1,169 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_3BIT == 1 + +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 4.0f); + const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y8, z8); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hfma2( q3.as_half2, y8, z8); + dq[ 4] = __hfma2( q4.as_half2, y64, z64); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hfma2( q6.as_half2, y8, z8); + dq[ 7] = __hadd2( q7.as_half2, z1); + dq[ 8] = __hfma2( q8.as_half2, y8, z8); + dq[ 9] = __hfma2( q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#else + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + half dqh[32]; + for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4); + dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4); + for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4); + dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4); + for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4); + + for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh new file mode 100644 index 000000000..ad95edb46 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh @@ -0,0 +1,227 @@ +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_4BIT == 1 + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} + +#else + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1)[2], + half2 (&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z = __hmul(z, scale); + z1[0] = __half2half2(z); + y1[0] = __half2half2(scale); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1)[2], + half2(&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z1[0] = __half2half2(z); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1)[2], + half2 (&y1)[2], + int stride, + bool scaled +) +{ + half2 dqh2[8]; + + uint32_t qa = q_0; + for (int i = 0; i < 4; i++) + { + half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; + half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; + dqh2[i] = __halves2half2(d0, d1); + } + + if (scaled) + { + dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); + dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); + dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); + dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); + } + else + { + dq[0] = __hadd2(dqh2[0], z1[0]); + dq[1] = __hadd2(dqh2[1], z1[0]); + dq[2] = __hadd2(dqh2[2], z1[0]); + dq[3] = __hadd2(dqh2[3], z1[0]); + } +} + +#endif + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh new file mode 100644 index 000000000..78d81f92e --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh @@ -0,0 +1,207 @@ +#ifndef _qdq_5_cuh +#define _qdq_5_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_5BIT == 1 + +// Permutation: +// +// v5555533 33311111 u4444422 22200000 (u, v lsb) +// vbbbbb99 99977777 uaaaaa88 88866666 +// vhhhhhff fffddddd ugggggee eeeccccc +// vnnnnnll llljjjjj ummmmmkk kkkiiiii +// vtttttrr rrrppppp usssssqq qqqooooo + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + uint32_t qd = q[3 * stride]; + uint32_t qe = q[4 * stride]; + + // qa: 66555554 44443333 32222211 11100000 + // qb: ccccbbbb baaaaa99 99988888 77777666 + // qc: jiiiiihh hhhggggg fffffeee eedddddc + // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj + // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp + + uint32_t qf = qe >> 22; + qe <<= 8; + qe |= qd >> 24; + qd <<= 6; + qd |= qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: 555554 44443333 32222211 11100000 + // qb: bbbbba aaaa9999 98888877 77766666 + // qc: hhhhhg ggggffff feeeeedd dddccccc + // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii + // qe: ttttts ssssrrrr rqqqqqpp pppooooo + // qf: vv vvvuuuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + uint32_t zd = 0; + uint32_t ze = 0; + + for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); } + + // za: 5555533 33311111 4444422 22200000 + // zb: bbbbb99 99977777 aaaaa88 88866666 + // zc: hhhhhff fffddddd gggggee eeeccccc + // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii + // ze: tttttrr rrrppppp sssssqq qqqooooo + // qf: vv vvvuuuuu + + za |= ((qf & 0x001) >> 0) << 15; + zb |= ((qf & 0x002) >> 1) << 15; + zc |= ((qf & 0x004) >> 2) << 15; + zd |= ((qf & 0x008) >> 3) << 15; + ze |= ((qf & 0x010) >> 4) << 15; + za |= ((qf & 0x020) >> 5) << 31; + zb |= ((qf & 0x040) >> 6) << 31; + zc |= ((qf & 0x080) >> 7) << 31; + zd |= ((qf & 0x100) >> 8) << 31; + ze |= ((qf & 0x200) >> 9) << 31; + + // za: v5555533 33311111 u4444422 22200000 (u, v lsb) + // zb: vbbbbb99 99977777 uaaaaa88 88866666 + // zc: vhhhhhff fffddddd ugggggee eeeccccc + // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii + // ze: vtttttrr rrrppppp usssssqq qqqooooo + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; + q[3 * stride] = zd; + q[4 * stride] = ze; +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y32_ = __float2half_rn(1.0f / 32.0f); + const half2 y32 = __halves2half2(y32_, y32_); + const half z1_ = __float2half_rn(-1024.0f - 16.0f); + const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z32 = __halves2half2(z32_, z32_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + uint32_t qd = q_3; + uint32_t qe = q_4; + + half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024 + qa >>= 10; + half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024 + qa >>= 5; + qa &= 0x00010001; + half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024 + half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024 + qb >>= 10; + half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024 + qb >>= 4; + qb &= 0x00020002; + half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024 + half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024 + qc >>= 10; + half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024 + qc >>= 3; + qc &= 0x00040004; + half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024 + half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024 + qd >>= 10; + half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024 + qd >>= 2; + qd &= 0x00080008; + half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024 + qe >>= 10; + half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024 + qe >>= 1; + qe &= 0x00100010; + half2_uint32 q15((qa | qb | qc | qd | qe) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y32, z32); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hadd2( q3.as_half2, z1); + dq[ 4] = __hfma2( q4.as_half2, y32, z32); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hadd2( q6.as_half2, z1); + dq[ 7] = __hfma2( q7.as_half2, y32, z32); + dq[ 8] = __hadd2( q8.as_half2, z1); + dq[ 9] = __hadd2( q9.as_half2, z1); + dq[10] = __hfma2(q10.as_half2, y32, z32); + dq[11] = __hadd2(q11.as_half2, z1); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y32, z32); + dq[14] = __hadd2(q14.as_half2, z1); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#else + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + half dqh[32]; + for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16); + dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16); + for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16); + dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16); + for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16); + dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16); + for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16); + dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16); + for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16); + + for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh new file mode 100644 index 000000000..562fe695f --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh @@ -0,0 +1,42 @@ +#ifndef _qdq_6_cuh +#define _qdq_6_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_6BIT == 1 + + // Not implemented + +#else + +__forceinline__ __device__ void shuffle_6bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_6bit_16 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32); + dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32); + for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32); + dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32); + for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh new file mode 100644 index 000000000..6e6bedbdf --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh @@ -0,0 +1,38 @@ +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_8BIT == 1 + + // Not implemented + +#else + +__forceinline__ __device__ void shuffle_8bit_4 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_8bit_8 +( + const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh new file mode 100644 index 000000000..cac9df9c1 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh @@ -0,0 +1,53 @@ +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh new file mode 100644 index 000000000..e167bc23e --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh @@ -0,0 +1,54 @@ +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include +#include + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#define DBGS(__x) printf("%s\n", __x) +#define DBGI(__x) printf("%s: %i\n", #__x, __x) +#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) +#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGX(__x) printf("%s: %x\n", #__x, __x) +#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y) +#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF(__x) printf("%s: %f\n", #__x, __x) +#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) +#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x)) +#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y)) +#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z)) + +#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y)) +#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z)) + +__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale) +{ + half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f)); + qs_h = __hmul(qs_h, qs_h); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ float clamp(float x, float a, float b) +{ + return fmaxf(a, fminf(b, x)); +} + +#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); } +inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) exit(code); + } +} + +void print_global_mem(const half* ptr, int rows, int columns, int stride); + +#endif diff --git a/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/ext.cpp b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/ext.cpp new file mode 100644 index 000000000..ff4e1851f --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/exllamav2_kernels/ext.cpp @@ -0,0 +1,139 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "config.h" + +#include "cuda/q_matrix.cuh" +#include "cuda/q_gemm.cuh" + +#include "cpp/util.h" + +// Some decluttering macros + +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") + + +// Quant matrix + +uintptr_t make_q_matrix +( + torch::Tensor q_weight, + torch::Tensor q_perm, + torch::Tensor q_invperm, + torch::Tensor q_scale, + torch::Tensor q_scale_max, + torch::Tensor q_groups, + torch::Tensor q_group_map, + torch::Tensor gptq_qzeros, + torch::Tensor gptq_scales, + torch::Tensor gptq_g_idx, + torch::Tensor temp_dq +) +{ + TORCH_CHECK_DTYPE(q_weight, kInt); + TORCH_CHECK_DTYPE_OPT(q_perm, kShort); + TORCH_CHECK_DTYPE_OPT(q_invperm, kShort); + TORCH_CHECK_DTYPE_OPT(q_scale, kInt); + TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); + TORCH_CHECK_DTYPE_OPT(q_groups, kShort); + TORCH_CHECK_DTYPE_OPT(q_group_map, kShort); + TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); + TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); + TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); + + TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); + + int device = q_weight.device().index(); + int width = q_weight.size(1); + int groups; + int height; + + if (!q_scale.device().is_meta()) + { + TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8); + TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1); + groups = q_scale.size(0); + height = q_invperm.size(0); + } + else + { + TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8); + TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1); + groups = gptq_qzeros.size(0); + height = q_weight.size(0) * 8; + } + + TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer") + + QMatrix* m = new QMatrix + ( + device, + height, + width, + groups, + (uint32_t*) q_weight.data_ptr(), + q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), + q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), + q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), + q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), + q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), + q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(), + gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), + gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), + gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), + (half*) temp_dq.data_ptr() + ); + + if (m->failed) throw std::runtime_error("CUDA out of memory"); + + return reinterpret_cast (m); +} + +void gemm_half_q_half +( + torch::Tensor a, + uintptr_t b, + torch::Tensor c, + bool force_cuda +) +{ + QMatrix* qm = reinterpret_cast (b); + + TORCH_CHECK_DTYPE(a, kHalf); + TORCH_CHECK_DTYPE(c, kHalf); + TORCH_CHECK_SHAPES(a, 0, c, 0, 1); + TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes") + TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + + gemm_half_q_half_cuda + ( + at::cuda::getCurrentCUDABlasHandle(), + (const half*) a.data_ptr(), + qm, + (half*) c.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + true, + NULL, + force_cuda + ); +} + +// Bindings + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("make_q_matrix", &make_q_matrix, "make_q_matrix"); + m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half"); +} diff --git a/backends/gaudi/server/exllamav2_kernels/setup.py b/backends/gaudi/server/exllamav2_kernels/setup.py new file mode 100644 index 000000000..56ffa9733 --- /dev/null +++ b/backends/gaudi/server/exllamav2_kernels/setup.py @@ -0,0 +1,30 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +extra_cuda_cflags = ["-lineinfo", "-O3"] +extra_cflags = [] +if torch.version.hip: + extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] + extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"] + +extra_compile_args = { + "cxx": extra_cflags, + "nvcc": extra_cuda_cflags, +} + +setup( + name="exllamav2_kernels", + ext_modules=[ + CUDAExtension( + name="exllamav2_kernels", + sources=[ + "exllamav2_kernels/ext.cpp", + "exllamav2_kernels/cuda/q_matrix.cu", + "exllamav2_kernels/cuda/q_gemm.cu", + ], + extra_compile_args=extra_compile_args, + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/backends/gaudi/server/poetry.lock b/backends/gaudi/server/poetry.lock new file mode 100644 index 000000000..1987c3b31 --- /dev/null +++ b/backends/gaudi/server/poetry.lock @@ -0,0 +1,4019 @@ +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. + +[[package]] +name = "accelerate" +version = "0.29.3" +description = "Accelerate" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "accelerate-0.29.3-py3-none-any.whl", hash = "sha256:99d633d4b6126817c5e554487406748be95c8d1d1e659dd2fd60657e35f532dd"}, + {file = "accelerate-0.29.3.tar.gz", hash = "sha256:1a5a845b06b24b41736b219b2b20fd021ca5dff4070a252445fd6de736e347ac"}, +] + +[package.dependencies] +huggingface-hub = "*" +numpy = ">=1.17" +packaging = ">=20.0" +psutil = "*" +pyyaml = "*" +safetensors = ">=0.3.1" +torch = ">=1.10.0" + +[package.extras] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"] +rich = ["rich"] +sagemaker = ["sagemaker"] +test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"] +test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] +testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] + +[[package]] +name = "aiohappyeyeballs" +version = "2.4.3" +description = "Happy Eyeballs for asyncio" +optional = true +python-versions = ">=3.8" +files = [ + {file = "aiohappyeyeballs-2.4.3-py3-none-any.whl", hash = "sha256:8a7a83727b2756f394ab2895ea0765a0a8c475e3c71e98d43d76f22b4b435572"}, + {file = "aiohappyeyeballs-2.4.3.tar.gz", hash = "sha256:75cf88a15106a5002a8eb1dab212525c00d1f4c0fa96e551c9fbe6f09a621586"}, +] + +[[package]] +name = "aiohttp" +version = "3.10.10" +description = "Async http client/server framework (asyncio)" +optional = true +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.10.10-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:be7443669ae9c016b71f402e43208e13ddf00912f47f623ee5994e12fc7d4b3f"}, + {file = "aiohttp-3.10.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b06b7843929e41a94ea09eb1ce3927865387e3e23ebe108e0d0d09b08d25be9"}, + {file = "aiohttp-3.10.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:333cf6cf8e65f6a1e06e9eb3e643a0c515bb850d470902274239fea02033e9a8"}, + {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:274cfa632350225ce3fdeb318c23b4a10ec25c0e2c880eff951a3842cf358ac1"}, + {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9e5e4a85bdb56d224f412d9c98ae4cbd032cc4f3161818f692cd81766eee65a"}, + {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b606353da03edcc71130b52388d25f9a30a126e04caef1fd637e31683033abd"}, + {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab5a5a0c7a7991d90446a198689c0535be89bbd6b410a1f9a66688f0880ec026"}, + {file = "aiohttp-3.10.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:578a4b875af3e0daaf1ac6fa983d93e0bbfec3ead753b6d6f33d467100cdc67b"}, + {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8105fd8a890df77b76dd3054cddf01a879fc13e8af576805d667e0fa0224c35d"}, + {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3bcd391d083f636c06a68715e69467963d1f9600f85ef556ea82e9ef25f043f7"}, + {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fbc6264158392bad9df19537e872d476f7c57adf718944cc1e4495cbabf38e2a"}, + {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e48d5021a84d341bcaf95c8460b152cfbad770d28e5fe14a768988c461b821bc"}, + {file = "aiohttp-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2609e9ab08474702cc67b7702dbb8a80e392c54613ebe80db7e8dbdb79837c68"}, + {file = "aiohttp-3.10.10-cp310-cp310-win32.whl", hash = "sha256:84afcdea18eda514c25bc68b9af2a2b1adea7c08899175a51fe7c4fb6d551257"}, + {file = "aiohttp-3.10.10-cp310-cp310-win_amd64.whl", hash = "sha256:9c72109213eb9d3874f7ac8c0c5fa90e072d678e117d9061c06e30c85b4cf0e6"}, + {file = "aiohttp-3.10.10-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c30a0eafc89d28e7f959281b58198a9fa5e99405f716c0289b7892ca345fe45f"}, + {file = "aiohttp-3.10.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:258c5dd01afc10015866114e210fb7365f0d02d9d059c3c3415382ab633fcbcb"}, + {file = "aiohttp-3.10.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:15ecd889a709b0080f02721255b3f80bb261c2293d3c748151274dfea93ac871"}, + {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3935f82f6f4a3820270842e90456ebad3af15810cf65932bd24da4463bc0a4c"}, + {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:413251f6fcf552a33c981c4709a6bba37b12710982fec8e558ae944bfb2abd38"}, + {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1720b4f14c78a3089562b8875b53e36b51c97c51adc53325a69b79b4b48ebcb"}, + {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:679abe5d3858b33c2cf74faec299fda60ea9de62916e8b67e625d65bf069a3b7"}, + {file = "aiohttp-3.10.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79019094f87c9fb44f8d769e41dbb664d6e8fcfd62f665ccce36762deaa0e911"}, + {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe2fb38c2ed905a2582948e2de560675e9dfbee94c6d5ccdb1301c6d0a5bf092"}, + {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a3f00003de6eba42d6e94fabb4125600d6e484846dbf90ea8e48a800430cc142"}, + {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:1bbb122c557a16fafc10354b9d99ebf2f2808a660d78202f10ba9d50786384b9"}, + {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:30ca7c3b94708a9d7ae76ff281b2f47d8eaf2579cd05971b5dc681db8caac6e1"}, + {file = "aiohttp-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:df9270660711670e68803107d55c2b5949c2e0f2e4896da176e1ecfc068b974a"}, + {file = "aiohttp-3.10.10-cp311-cp311-win32.whl", hash = "sha256:aafc8ee9b742ce75044ae9a4d3e60e3d918d15a4c2e08a6c3c3e38fa59b92d94"}, + {file = "aiohttp-3.10.10-cp311-cp311-win_amd64.whl", hash = "sha256:362f641f9071e5f3ee6f8e7d37d5ed0d95aae656adf4ef578313ee585b585959"}, + {file = "aiohttp-3.10.10-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9294bbb581f92770e6ed5c19559e1e99255e4ca604a22c5c6397b2f9dd3ee42c"}, + {file = "aiohttp-3.10.10-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a8fa23fe62c436ccf23ff930149c047f060c7126eae3ccea005f0483f27b2e28"}, + {file = "aiohttp-3.10.10-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c6a5b8c7926ba5d8545c7dd22961a107526562da31a7a32fa2456baf040939f"}, + {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:007ec22fbc573e5eb2fb7dec4198ef8f6bf2fe4ce20020798b2eb5d0abda6138"}, + {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9627cc1a10c8c409b5822a92d57a77f383b554463d1884008e051c32ab1b3742"}, + {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:50edbcad60d8f0e3eccc68da67f37268b5144ecc34d59f27a02f9611c1d4eec7"}, + {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a45d85cf20b5e0d0aa5a8dca27cce8eddef3292bc29d72dcad1641f4ed50aa16"}, + {file = "aiohttp-3.10.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b00807e2605f16e1e198f33a53ce3c4523114059b0c09c337209ae55e3823a8"}, + {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f2d4324a98062be0525d16f768a03e0bbb3b9fe301ceee99611dc9a7953124e6"}, + {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:438cd072f75bb6612f2aca29f8bd7cdf6e35e8f160bc312e49fbecab77c99e3a"}, + {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:baa42524a82f75303f714108fea528ccacf0386af429b69fff141ffef1c534f9"}, + {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a7d8d14fe962153fc681f6366bdec33d4356f98a3e3567782aac1b6e0e40109a"}, + {file = "aiohttp-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c1277cd707c465cd09572a774559a3cc7c7a28802eb3a2a9472588f062097205"}, + {file = "aiohttp-3.10.10-cp312-cp312-win32.whl", hash = "sha256:59bb3c54aa420521dc4ce3cc2c3fe2ad82adf7b09403fa1f48ae45c0cbde6628"}, + {file = "aiohttp-3.10.10-cp312-cp312-win_amd64.whl", hash = "sha256:0e1b370d8007c4ae31ee6db7f9a2fe801a42b146cec80a86766e7ad5c4a259cf"}, + {file = "aiohttp-3.10.10-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ad7593bb24b2ab09e65e8a1d385606f0f47c65b5a2ae6c551db67d6653e78c28"}, + {file = "aiohttp-3.10.10-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1eb89d3d29adaf533588f209768a9c02e44e4baf832b08118749c5fad191781d"}, + {file = "aiohttp-3.10.10-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3fe407bf93533a6fa82dece0e74dbcaaf5d684e5a51862887f9eaebe6372cd79"}, + {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50aed5155f819873d23520919e16703fc8925e509abbb1a1491b0087d1cd969e"}, + {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f05e9727ce409358baa615dbeb9b969db94324a79b5a5cea45d39bdb01d82e6"}, + {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dffb610a30d643983aeb185ce134f97f290f8935f0abccdd32c77bed9388b42"}, + {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa6658732517ddabe22c9036479eabce6036655ba87a0224c612e1ae6af2087e"}, + {file = "aiohttp-3.10.10-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:741a46d58677d8c733175d7e5aa618d277cd9d880301a380fd296975a9cdd7bc"}, + {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e00e3505cd80440f6c98c6d69269dcc2a119f86ad0a9fd70bccc59504bebd68a"}, + {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ffe595f10566f8276b76dc3a11ae4bb7eba1aac8ddd75811736a15b0d5311414"}, + {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:bdfcf6443637c148c4e1a20c48c566aa694fa5e288d34b20fcdc58507882fed3"}, + {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d183cf9c797a5291e8301790ed6d053480ed94070637bfaad914dd38b0981f67"}, + {file = "aiohttp-3.10.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:77abf6665ae54000b98b3c742bc6ea1d1fb31c394bcabf8b5d2c1ac3ebfe7f3b"}, + {file = "aiohttp-3.10.10-cp313-cp313-win32.whl", hash = "sha256:4470c73c12cd9109db8277287d11f9dd98f77fc54155fc71a7738a83ffcc8ea8"}, + {file = "aiohttp-3.10.10-cp313-cp313-win_amd64.whl", hash = "sha256:486f7aabfa292719a2753c016cc3a8f8172965cabb3ea2e7f7436c7f5a22a151"}, + {file = "aiohttp-3.10.10-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:1b66ccafef7336a1e1f0e389901f60c1d920102315a56df85e49552308fc0486"}, + {file = "aiohttp-3.10.10-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:acd48d5b80ee80f9432a165c0ac8cbf9253eaddb6113269a5e18699b33958dbb"}, + {file = "aiohttp-3.10.10-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3455522392fb15ff549d92fbf4b73b559d5e43dc522588f7eb3e54c3f38beee7"}, + {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45c3b868724137f713a38376fef8120c166d1eadd50da1855c112fe97954aed8"}, + {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:da1dee8948d2137bb51fbb8a53cce6b1bcc86003c6b42565f008438b806cccd8"}, + {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5ce2ce7c997e1971b7184ee37deb6ea9922ef5163c6ee5aa3c274b05f9e12fa"}, + {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28529e08fde6f12eba8677f5a8608500ed33c086f974de68cc65ab218713a59d"}, + {file = "aiohttp-3.10.10-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7db54c7914cc99d901d93a34704833568d86c20925b2762f9fa779f9cd2e70f"}, + {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:03a42ac7895406220124c88911ebee31ba8b2d24c98507f4a8bf826b2937c7f2"}, + {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:7e338c0523d024fad378b376a79faff37fafb3c001872a618cde1d322400a572"}, + {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:038f514fe39e235e9fef6717fbf944057bfa24f9b3db9ee551a7ecf584b5b480"}, + {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:64f6c17757251e2b8d885d728b6433d9d970573586a78b78ba8929b0f41d045a"}, + {file = "aiohttp-3.10.10-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:93429602396f3383a797a2a70e5f1de5df8e35535d7806c9f91df06f297e109b"}, + {file = "aiohttp-3.10.10-cp38-cp38-win32.whl", hash = "sha256:c823bc3971c44ab93e611ab1a46b1eafeae474c0c844aff4b7474287b75fe49c"}, + {file = "aiohttp-3.10.10-cp38-cp38-win_amd64.whl", hash = "sha256:54ca74df1be3c7ca1cf7f4c971c79c2daf48d9aa65dea1a662ae18926f5bc8ce"}, + {file = "aiohttp-3.10.10-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:01948b1d570f83ee7bbf5a60ea2375a89dfb09fd419170e7f5af029510033d24"}, + {file = "aiohttp-3.10.10-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9fc1500fd2a952c5c8e3b29aaf7e3cc6e27e9cfc0a8819b3bce48cc1b849e4cc"}, + {file = "aiohttp-3.10.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f614ab0c76397661b90b6851a030004dac502e48260ea10f2441abd2207fbcc7"}, + {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00819de9e45d42584bed046314c40ea7e9aea95411b38971082cad449392b08c"}, + {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05646ebe6b94cc93407b3bf34b9eb26c20722384d068eb7339de802154d61bc5"}, + {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:998f3bd3cfc95e9424a6acd7840cbdd39e45bc09ef87533c006f94ac47296090"}, + {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9010c31cd6fa59438da4e58a7f19e4753f7f264300cd152e7f90d4602449762"}, + {file = "aiohttp-3.10.10-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ea7ffc6d6d6f8a11e6f40091a1040995cdff02cfc9ba4c2f30a516cb2633554"}, + {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ef9c33cc5cbca35808f6c74be11eb7f5f6b14d2311be84a15b594bd3e58b5527"}, + {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ce0cdc074d540265bfeb31336e678b4e37316849d13b308607efa527e981f5c2"}, + {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:597a079284b7ee65ee102bc3a6ea226a37d2b96d0418cc9047490f231dc09fe8"}, + {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:7789050d9e5d0c309c706953e5e8876e38662d57d45f936902e176d19f1c58ab"}, + {file = "aiohttp-3.10.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e7f8b04d83483577fd9200461b057c9f14ced334dcb053090cea1da9c8321a91"}, + {file = "aiohttp-3.10.10-cp39-cp39-win32.whl", hash = "sha256:c02a30b904282777d872266b87b20ed8cc0d1501855e27f831320f471d54d983"}, + {file = "aiohttp-3.10.10-cp39-cp39-win_amd64.whl", hash = "sha256:edfe3341033a6b53a5c522c802deb2079eee5cbfbb0af032a55064bd65c73a23"}, + {file = "aiohttp-3.10.10.tar.gz", hash = "sha256:0631dd7c9f0822cc61c88586ca76d5b5ada26538097d0f1df510b082bad3411a"}, +] + +[package.dependencies] +aiohappyeyeballs = ">=2.3.0" +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.12.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] + +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = true +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + +[[package]] +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = true +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = true +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + +[[package]] +name = "attrs" +version = "24.2.0" +description = "Classes Without Boilerplate" +optional = true +python-versions = ">=3.7" +files = [ + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, +] + +[package.extras] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] + +[[package]] +name = "bitsandbytes" +version = "0.43.3" +description = "k-bit optimizers and matrix multiplication routines." +optional = true +python-versions = "*" +files = [ + {file = "bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:cc99507c352be0715098b2c7577b690dd158972dc4ea10c7495bac104c7c79f0"}, + {file = "bitsandbytes-0.43.3-py3-none-win_amd64.whl", hash = "sha256:257f6552f2144748a84e6c44e1f7a98f3da888f675ed74e18fd7f7eb13c6cafa"}, +] + +[package.dependencies] +numpy = "*" +torch = "*" + +[package.extras] +benchmark = ["matplotlib", "pandas"] +test = ["scipy"] + +[[package]] +name = "certifi" +version = "2024.8.30" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"}, + {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.0" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4f9fc98dad6c2eaa32fc3af1417d95b5e3d08aff968df0cd320066def971f9a6"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0de7b687289d3c1b3e8660d0741874abe7888100efe14bd0f9fd7141bcbda92b"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5ed2e36c3e9b4f21dd9422f6893dec0abf2cca553af509b10cd630f878d3eb99"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d3ff7fc90b98c637bda91c89d51264a3dcf210cade3a2c6f838c7268d7a4ca"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1110e22af8ca26b90bd6364fe4c763329b0ebf1ee213ba32b68c73de5752323d"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86f4e8cca779080f66ff4f191a685ced73d2f72d50216f7112185dc02b90b9b7"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f683ddc7eedd742e2889d2bfb96d69573fde1d92fcb811979cdb7165bb9c7d3"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27623ba66c183eca01bf9ff833875b459cad267aeeb044477fedac35e19ba907"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f606a1881d2663630ea5b8ce2efe2111740df4b687bd78b34a8131baa007f79b"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0b309d1747110feb25d7ed6b01afdec269c647d382c857ef4663bbe6ad95a912"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:136815f06a3ae311fae551c3df1f998a1ebd01ddd424aa5603a4336997629e95"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:14215b71a762336254351b00ec720a8e85cada43b987da5a042e4ce3e82bd68e"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:79983512b108e4a164b9c8d34de3992f76d48cadc9554c9e60b43f308988aabe"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win32.whl", hash = "sha256:c94057af19bc953643a33581844649a7fdab902624d2eb739738a30e2b3e60fc"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55f56e2ebd4e3bc50442fbc0888c9d8c94e4e06a933804e2af3e89e2f9c1c749"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0d99dd8ff461990f12d6e42c7347fd9ab2532fb70e9621ba520f9e8637161d7c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c57516e58fd17d03ebe67e181a4e4e2ccab1168f8c2976c6a334d4f819fe5944"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dba5d19c4dfab08e58d5b36304b3f92f3bd5d42c1a3fa37b5ba5cdf6dfcbcee"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf4475b82be41b07cc5e5ff94810e6a01f276e37c2d55571e3fe175e467a1a1c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce031db0408e487fd2775d745ce30a7cd2923667cf3b69d48d219f1d8f5ddeb6"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ff4e7cdfdb1ab5698e675ca622e72d58a6fa2a8aa58195de0c0061288e6e3ea"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3710a9751938947e6327ea9f3ea6332a09bf0ba0c09cae9cb1f250bd1f1549bc"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82357d85de703176b5587dbe6ade8ff67f9f69a41c0733cf2425378b49954de5"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47334db71978b23ebcf3c0f9f5ee98b8d65992b65c9c4f2d34c2eaf5bcaf0594"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8ce7fd6767a1cc5a92a639b391891bf1c268b03ec7e021c7d6d902285259685c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f1a2f519ae173b5b6a2c9d5fa3116ce16e48b3462c8b96dfdded11055e3d6365"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:63bc5c4ae26e4bc6be6469943b8253c0fd4e4186c43ad46e713ea61a0ba49129"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bcb4f8ea87d03bc51ad04add8ceaf9b0f085ac045ab4d74e73bbc2dc033f0236"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win32.whl", hash = "sha256:9ae4ef0b3f6b41bad6366fb0ea4fc1d7ed051528e113a60fa2a65a9abb5b1d99"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cee4373f4d3ad28f1ab6290684d8e2ebdb9e7a1b74fdc39e4c211995f77bec27"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0713f3adb9d03d49d365b70b84775d0a0d18e4ab08d12bc46baa6132ba78aaf6"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:de7376c29d95d6719048c194a9cf1a1b0393fbe8488a22008610b0361d834ecf"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a51b48f42d9358460b78725283f04bddaf44a9358197b889657deba38f329db"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b295729485b06c1a0683af02a9e42d2caa9db04a373dc38a6a58cdd1e8abddf1"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee803480535c44e7f5ad00788526da7d85525cfefaf8acf8ab9a310000be4b03"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d59d125ffbd6d552765510e3f31ed75ebac2c7470c7274195b9161a32350284"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cda06946eac330cbe6598f77bb54e690b4ca93f593dee1568ad22b04f347c15"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07afec21bbbbf8a5cc3651aa96b980afe2526e7f048fdfb7f1014d84acc8b6d8"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6b40e8d38afe634559e398cc32b1472f376a4099c75fe6299ae607e404c033b2"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b8dcd239c743aa2f9c22ce674a145e0a25cb1566c495928440a181ca1ccf6719"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:84450ba661fb96e9fd67629b93d2941c871ca86fc38d835d19d4225ff946a631"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44aeb140295a2f0659e113b31cfe92c9061622cadbc9e2a2f7b8ef6b1e29ef4b"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1db4e7fefefd0f548d73e2e2e041f9df5c59e178b4c72fbac4cc6f535cfb1565"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win32.whl", hash = "sha256:5726cf76c982532c1863fb64d8c6dd0e4c90b6ece9feb06c9f202417a31f7dd7"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b197e7094f232959f8f20541ead1d9862ac5ebea1d58e9849c1bf979255dfac9"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dbe03226baf438ac4fda9e2d0715022fd579cb641c4cf639fa40d53b2fe6f3e2"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd9a8bd8900e65504a305bf8ae6fa9fbc66de94178c420791d0293702fce2df7"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8831399554b92b72af5932cdbbd4ddc55c55f631bb13ff8fe4e6536a06c5c51"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a14969b8691f7998e74663b77b4c36c0337cb1df552da83d5c9004a93afdb574"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaf7c1524c0542ee2fc82cc8ec337f7a9f7edee2532421ab200d2b920fc97cf"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425c5f215d0eecee9a56cdb703203dda90423247421bf0d67125add85d0c4455"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:d5b054862739d276e09928de37c79ddeec42a6e1bfc55863be96a36ba22926f6"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:f3e73a4255342d4eb26ef6df01e3962e73aa29baa3124a8e824c5d3364a65748"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:2f6c34da58ea9c1a9515621f4d9ac379871a8f21168ba1b5e09d74250de5ad62"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f09cb5a7bbe1ecae6e87901a2eb23e0256bb524a79ccc53eb0b7629fbe7677c4"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:0099d79bdfcf5c1f0c2c72f91516702ebf8b0b8ddd8905f97a8aecf49712c621"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win32.whl", hash = "sha256:9c98230f5042f4945f957d006edccc2af1e03ed5e37ce7c373f00a5a4daa6149"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62f60aebecfc7f4b82e3f639a7d1433a20ec32824db2199a11ad4f5e146ef5ee"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:af73657b7a68211996527dbfeffbb0864e043d270580c5aef06dc4b659a4b578"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cab5d0b79d987c67f3b9e9c53f54a61360422a5a0bc075f43cab5621d530c3b6"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9289fd5dddcf57bab41d044f1756550f9e7cf0c8e373b8cdf0ce8773dc4bd417"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b493a043635eb376e50eedf7818f2f322eabbaa974e948bd8bdd29eb7ef2a51"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fa2566ca27d67c86569e8c85297aaf413ffab85a8960500f12ea34ff98e4c41"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8e538f46104c815be19c975572d74afb53f29650ea2025bbfaef359d2de2f7f"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fd30dc99682dc2c603c2b315bded2799019cea829f8bf57dc6b61efde6611c8"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2006769bd1640bdf4d5641c69a3d63b71b81445473cac5ded39740a226fa88ab"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc15e99b2d8a656f8e666854404f1ba54765871104e50c8e9813af8a7db07f12"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ab2e5bef076f5a235c3774b4f4028a680432cded7cad37bba0fd90d64b187d19"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:4ec9dd88a5b71abfc74e9df5ebe7921c35cbb3b641181a531ca65cdb5e8e4dea"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:43193c5cda5d612f247172016c4bb71251c784d7a4d9314677186a838ad34858"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:aa693779a8b50cd97570e5a0f343538a8dbd3e496fa5dcb87e29406ad0299654"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win32.whl", hash = "sha256:7706f5850360ac01d80c89bcef1640683cc12ed87f42579dab6c5d3ed6888613"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:c3e446d253bd88f6377260d07c895816ebf33ffffd56c1c792b13bff9c3e1ade"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:980b4f289d1d90ca5efcf07958d3eb38ed9c0b7676bf2831a54d4f66f9c27dfa"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f28f891ccd15c514a0981f3b9db9aa23d62fe1a99997512b0491d2ed323d229a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8aacce6e2e1edcb6ac625fb0f8c3a9570ccc7bfba1f63419b3769ccf6a00ed0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7af3717683bea4c87acd8c0d3d5b44d56120b26fd3f8a692bdd2d5260c620a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff2ed8194587faf56555927b3aa10e6fb69d931e33953943bc4f837dfee2242"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e91f541a85298cf35433bf66f3fab2a4a2cff05c127eeca4af174f6d497f0d4b"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309a7de0a0ff3040acaebb35ec45d18db4b28232f21998851cfa709eeff49d62"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:285e96d9d53422efc0d7a17c60e59f37fbf3dfa942073f666db4ac71e8d726d0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5d447056e2ca60382d460a604b6302d8db69476fd2015c81e7c35417cfabe4cd"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:20587d20f557fe189b7947d8e7ec5afa110ccf72a3128d61a2a387c3313f46be"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:130272c698667a982a5d0e626851ceff662565379baf0ff2cc58067b81d4f11d"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ab22fbd9765e6954bc0bcff24c25ff71dcbfdb185fcdaca49e81bac68fe724d3"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7782afc9b6b42200f7362858f9e73b1f8316afb276d316336c0ec3bd73312742"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win32.whl", hash = "sha256:2de62e8801ddfff069cd5c504ce3bc9672b23266597d4e4f50eda28846c322f2"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:95c3c157765b031331dd4db3c775e58deaee050a3042fcad72cbc4189d7c8dca"}, + {file = "charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079"}, + {file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"}, +] + +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "cloudpickle" +version = "3.1.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = true +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e"}, + {file = "cloudpickle-3.1.0.tar.gz", hash = "sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "datasets" +version = "2.14.4" +description = "HuggingFace community-driven open-source library of datasets" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, + {file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, +] + +[package.dependencies] +aiohttp = "*" +dill = ">=0.3.0,<0.3.8" +fsspec = {version = ">=2021.11.1", extras = ["http"]} +huggingface-hub = ">=0.14.0,<1.0.0" +multiprocess = "*" +numpy = ">=1.17" +packaging = "*" +pandas = "*" +pyarrow = ">=8.0.0" +pyyaml = ">=5.1" +requests = ">=2.19.0" +tqdm = ">=4.62.1" +xxhash = "*" + +[package.extras] +apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] +audio = ["librosa", "soundfile (>=0.12.1)"] +benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] +jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] +quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] +s3 = ["s3fs"] +tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] +tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +torch = ["torch"] +vision = ["Pillow (>=6.2.1)"] + +[[package]] +name = "deprecated" +version = "1.2.14" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, + {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"}, +] + +[package.dependencies] +wrapt = ">=1.10,<2" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] + +[[package]] +name = "dill" +version = "0.3.7" +description = "serialize all of Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, + {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] + +[[package]] +name = "diskcache" +version = "5.6.3" +description = "Disk Cache -- Disk and file backed persistent cache." +optional = true +python-versions = ">=3" +files = [ + {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"}, + {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, +] + +[[package]] +name = "einops" +version = "0.6.1" +description = "A new flavour of deep learning operations" +optional = false +python-versions = ">=3.7" +files = [ + {file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"}, + {file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "filelock" +version = "3.16.1" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] +typing = ["typing-extensions (>=4.12.2)"] + +[[package]] +name = "frozenlist" +version = "1.4.1" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = true +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + +[[package]] +name = "fsspec" +version = "2024.10.0" +description = "File-system specification" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fsspec-2024.10.0-py3-none-any.whl", hash = "sha256:03b9a6785766a4de40368b88906366755e2819e758b83705c88cd7cb5fe81871"}, + {file = "fsspec-2024.10.0.tar.gz", hash = "sha256:eda2d8a4116d4f2429db8550f2457da57279247dd930bb12f821b58391359493"}, +] + +[package.dependencies] +aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +dev = ["pre-commit", "ruff"] +doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] +test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] +test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] +tqdm = ["tqdm"] + +[[package]] +name = "googleapis-common-protos" +version = "1.65.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis_common_protos-1.65.0-py2.py3-none-any.whl", hash = "sha256:2972e6c496f435b92590fd54045060867f3fe9be2c82ab148fc8885035479a63"}, + {file = "googleapis_common_protos-1.65.0.tar.gz", hash = "sha256:334a29d07cddc3aa01dee4988f9afd9b2916ee2ff49d6b757155dc0d197852c0"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + +[[package]] +name = "grpc-interceptor" +version = "0.15.4" +description = "Simplifies gRPC interceptors" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "grpc-interceptor-0.15.4.tar.gz", hash = "sha256:1f45c0bcb58b6f332f37c637632247c9b02bc6af0fdceb7ba7ce8d2ebbfb0926"}, + {file = "grpc_interceptor-0.15.4-py3-none-any.whl", hash = "sha256:0035f33228693ed3767ee49d937bac424318db173fef4d2d0170b3215f254d9d"}, +] + +[package.dependencies] +grpcio = ">=1.49.1,<2.0.0" + +[package.extras] +testing = ["protobuf (>=4.21.9)"] + +[[package]] +name = "grpcio" +version = "1.67.0" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.67.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:bd79929b3bb96b54df1296cd3bf4d2b770bd1df6c2bdf549b49bab286b925cdc"}, + {file = "grpcio-1.67.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:16724ffc956ea42967f5758c2f043faef43cb7e48a51948ab593570570d1e68b"}, + {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:2b7183c80b602b0ad816315d66f2fb7887614ead950416d60913a9a71c12560d"}, + {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:efe32b45dd6d118f5ea2e5deaed417d8a14976325c93812dd831908522b402c9"}, + {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe89295219b9c9e47780a0f1c75ca44211e706d1c598242249fe717af3385ec8"}, + {file = "grpcio-1.67.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa8d025fae1595a207b4e47c2e087cb88d47008494db258ac561c00877d4c8f8"}, + {file = "grpcio-1.67.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f95e15db43e75a534420e04822df91f645664bf4ad21dfaad7d51773c80e6bb4"}, + {file = "grpcio-1.67.0-cp310-cp310-win32.whl", hash = "sha256:a6b9a5c18863fd4b6624a42e2712103fb0f57799a3b29651c0e5b8119a519d65"}, + {file = "grpcio-1.67.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6eb68493a05d38b426604e1dc93bfc0137c4157f7ab4fac5771fd9a104bbaa6"}, + {file = "grpcio-1.67.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:e91d154689639932305b6ea6f45c6e46bb51ecc8ea77c10ef25aa77f75443ad4"}, + {file = "grpcio-1.67.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cb204a742997277da678611a809a8409657b1398aaeebf73b3d9563b7d154c13"}, + {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:ae6de510f670137e755eb2a74b04d1041e7210af2444103c8c95f193340d17ee"}, + {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74b900566bdf68241118f2918d312d3bf554b2ce0b12b90178091ea7d0a17b3d"}, + {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e95e43447a02aa603abcc6b5e727d093d161a869c83b073f50b9390ecf0fa8"}, + {file = "grpcio-1.67.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0bb94e66cd8f0baf29bd3184b6aa09aeb1a660f9ec3d85da615c5003154bc2bf"}, + {file = "grpcio-1.67.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:82e5bd4b67b17c8c597273663794a6a46a45e44165b960517fe6d8a2f7f16d23"}, + {file = "grpcio-1.67.0-cp311-cp311-win32.whl", hash = "sha256:7fc1d2b9fd549264ae585026b266ac2db53735510a207381be509c315b4af4e8"}, + {file = "grpcio-1.67.0-cp311-cp311-win_amd64.whl", hash = "sha256:ac11ecb34a86b831239cc38245403a8de25037b448464f95c3315819e7519772"}, + {file = "grpcio-1.67.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:227316b5631260e0bef8a3ce04fa7db4cc81756fea1258b007950b6efc90c05d"}, + {file = "grpcio-1.67.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d90cfdafcf4b45a7a076e3e2a58e7bc3d59c698c4f6470b0bb13a4d869cf2273"}, + {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:77196216d5dd6f99af1c51e235af2dd339159f657280e65ce7e12c1a8feffd1d"}, + {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15c05a26a0f7047f720da41dc49406b395c1470eef44ff7e2c506a47ac2c0591"}, + {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3840994689cc8cbb73d60485c594424ad8adb56c71a30d8948d6453083624b52"}, + {file = "grpcio-1.67.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5a1e03c3102b6451028d5dc9f8591131d6ab3c8a0e023d94c28cb930ed4b5f81"}, + {file = "grpcio-1.67.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:682968427a63d898759474e3b3178d42546e878fdce034fd7474ef75143b64e3"}, + {file = "grpcio-1.67.0-cp312-cp312-win32.whl", hash = "sha256:d01793653248f49cf47e5695e0a79805b1d9d4eacef85b310118ba1dfcd1b955"}, + {file = "grpcio-1.67.0-cp312-cp312-win_amd64.whl", hash = "sha256:985b2686f786f3e20326c4367eebdaed3e7aa65848260ff0c6644f817042cb15"}, + {file = "grpcio-1.67.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:8c9a35b8bc50db35ab8e3e02a4f2a35cfba46c8705c3911c34ce343bd777813a"}, + {file = "grpcio-1.67.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:42199e704095b62688998c2d84c89e59a26a7d5d32eed86d43dc90e7a3bd04aa"}, + {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:c4c425f440fb81f8d0237c07b9322fc0fb6ee2b29fbef5f62a322ff8fcce240d"}, + {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:323741b6699cd2b04a71cb38f502db98f90532e8a40cb675393d248126a268af"}, + {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:662c8e105c5e5cee0317d500eb186ed7a93229586e431c1bf0c9236c2407352c"}, + {file = "grpcio-1.67.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:f6bd2ab135c64a4d1e9e44679a616c9bc944547357c830fafea5c3caa3de5153"}, + {file = "grpcio-1.67.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:2f55c1e0e2ae9bdd23b3c63459ee4c06d223b68aeb1961d83c48fb63dc29bc03"}, + {file = "grpcio-1.67.0-cp313-cp313-win32.whl", hash = "sha256:fd6bc27861e460fe28e94226e3673d46e294ca4673d46b224428d197c5935e69"}, + {file = "grpcio-1.67.0-cp313-cp313-win_amd64.whl", hash = "sha256:cf51d28063338608cd8d3cd64677e922134837902b70ce00dad7f116e3998210"}, + {file = "grpcio-1.67.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:7f200aca719c1c5dc72ab68be3479b9dafccdf03df530d137632c534bb6f1ee3"}, + {file = "grpcio-1.67.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0892dd200ece4822d72dd0952f7112c542a487fc48fe77568deaaa399c1e717d"}, + {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:f4d613fbf868b2e2444f490d18af472ccb47660ea3df52f068c9c8801e1f3e85"}, + {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c69bf11894cad9da00047f46584d5758d6ebc9b5950c0dc96fec7e0bce5cde9"}, + {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9bca3ca0c5e74dea44bf57d27e15a3a3996ce7e5780d61b7c72386356d231db"}, + {file = "grpcio-1.67.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:014dfc020e28a0d9be7e93a91f85ff9f4a87158b7df9952fe23cc42d29d31e1e"}, + {file = "grpcio-1.67.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d4ea4509d42c6797539e9ec7496c15473177ce9abc89bc5c71e7abe50fc25737"}, + {file = "grpcio-1.67.0-cp38-cp38-win32.whl", hash = "sha256:9d75641a2fca9ae1ae86454fd25d4c298ea8cc195dbc962852234d54a07060ad"}, + {file = "grpcio-1.67.0-cp38-cp38-win_amd64.whl", hash = "sha256:cff8e54d6a463883cda2fab94d2062aad2f5edd7f06ae3ed030f2a74756db365"}, + {file = "grpcio-1.67.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:62492bd534979e6d7127b8a6b29093161a742dee3875873e01964049d5250a74"}, + {file = "grpcio-1.67.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eef1dce9d1a46119fd09f9a992cf6ab9d9178b696382439446ca5f399d7b96fe"}, + {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f623c57a5321461c84498a99dddf9d13dac0e40ee056d884d6ec4ebcab647a78"}, + {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54d16383044e681f8beb50f905249e4e7261dd169d4aaf6e52eab67b01cbbbe2"}, + {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2a44e572fb762c668e4812156b81835f7aba8a721b027e2d4bb29fb50ff4d33"}, + {file = "grpcio-1.67.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:391df8b0faac84d42f5b8dfc65f5152c48ed914e13c522fd05f2aca211f8bfad"}, + {file = "grpcio-1.67.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfd9306511fdfc623a1ba1dc3bc07fbd24e6cfbe3c28b4d1e05177baa2f99617"}, + {file = "grpcio-1.67.0-cp39-cp39-win32.whl", hash = "sha256:30d47dbacfd20cbd0c8be9bfa52fdb833b395d4ec32fe5cff7220afc05d08571"}, + {file = "grpcio-1.67.0-cp39-cp39-win_amd64.whl", hash = "sha256:f55f077685f61f0fbd06ea355142b71e47e4a26d2d678b3ba27248abfe67163a"}, + {file = "grpcio-1.67.0.tar.gz", hash = "sha256:e090b2553e0da1c875449c8e75073dd4415dd71c9bde6a406240fdf4c0ee467c"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.67.0)"] + +[[package]] +name = "grpcio-reflection" +version = "1.62.3" +description = "Standard Protobuf Reflection Service for gRPC" +optional = false +python-versions = ">=3.6" +files = [ + {file = "grpcio-reflection-1.62.3.tar.gz", hash = "sha256:cb84682933c400bddf94dd94f928d1c6570f500b6dd255973d4bfb495b82585f"}, + {file = "grpcio_reflection-1.62.3-py3-none-any.whl", hash = "sha256:a48ef37df81a3bada78261fc92ef382f061112f989d1312398b945cc69838b9c"}, +] + +[package.dependencies] +grpcio = ">=1.62.3" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-status" +version = "1.62.3" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.6" +files = [ + {file = "grpcio-status-1.62.3.tar.gz", hash = "sha256:289bdd7b2459794a12cf95dc0cb727bd4a1742c37bd823f760236c937e53a485"}, + {file = "grpcio_status-1.62.3-py3-none-any.whl", hash = "sha256:f9049b762ba8de6b1086789d8315846e094edac2c50beaf462338b301a8fd4b8"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.62.3" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-tools" +version = "1.62.3" +description = "Protobuf code generator for gRPC" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-tools-1.62.3.tar.gz", hash = "sha256:7c7136015c3d62c3eef493efabaf9e3380e3e66d24ee8e94c01cb71377f57833"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2f968b049c2849540751ec2100ab05e8086c24bead769ca734fdab58698408c1"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:0a8c0c4724ae9c2181b7dbc9b186df46e4f62cb18dc184e46d06c0ebeccf569e"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5782883a27d3fae8c425b29a9d3dcf5f47d992848a1b76970da3b5a28d424b26"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3d812daffd0c2d2794756bd45a353f89e55dc8f91eb2fc840c51b9f6be62667"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b47d0dda1bdb0a0ba7a9a6de88e5a1ed61f07fad613964879954961e36d49193"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ca246dffeca0498be9b4e1ee169b62e64694b0f92e6d0be2573e65522f39eea9"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-win32.whl", hash = "sha256:6a56d344b0bab30bf342a67e33d386b0b3c4e65868ffe93c341c51e1a8853ca5"}, + {file = "grpcio_tools-1.62.3-cp310-cp310-win_amd64.whl", hash = "sha256:710fecf6a171dcbfa263a0a3e7070e0df65ba73158d4c539cec50978f11dad5d"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:703f46e0012af83a36082b5f30341113474ed0d91e36640da713355cd0ea5d23"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:7cc83023acd8bc72cf74c2edbe85b52098501d5b74d8377bfa06f3e929803492"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ff7d58a45b75df67d25f8f144936a3e44aabd91afec833ee06826bd02b7fbe7"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f2483ea232bd72d98a6dc6d7aefd97e5bc80b15cd909b9e356d6f3e326b6e43"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:962c84b4da0f3b14b3cdb10bc3837ebc5f136b67d919aea8d7bb3fd3df39528a"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8ad0473af5544f89fc5a1ece8676dd03bdf160fb3230f967e05d0f4bf89620e3"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-win32.whl", hash = "sha256:db3bc9fa39afc5e4e2767da4459df82b095ef0cab2f257707be06c44a1c2c3e5"}, + {file = "grpcio_tools-1.62.3-cp311-cp311-win_amd64.whl", hash = "sha256:e0898d412a434e768a0c7e365acabe13ff1558b767e400936e26b5b6ed1ee51f"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d102b9b21c4e1e40af9a2ab3c6d41afba6bd29c0aa50ca013bf85c99cdc44ac5"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:0a52cc9444df978438b8d2332c0ca99000521895229934a59f94f37ed896b133"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141d028bf5762d4a97f981c501da873589df3f7e02f4c1260e1921e565b376fa"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47a5c093ab256dec5714a7a345f8cc89315cb57c298b276fa244f37a0ba507f0"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:f6831fdec2b853c9daa3358535c55eed3694325889aa714070528cf8f92d7d6d"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e02d7c1a02e3814c94ba0cfe43d93e872c758bd8fd5c2797f894d0c49b4a1dfc"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-win32.whl", hash = "sha256:b881fd9505a84457e9f7e99362eeedd86497b659030cf57c6f0070df6d9c2b9b"}, + {file = "grpcio_tools-1.62.3-cp312-cp312-win_amd64.whl", hash = "sha256:11c625eebefd1fd40a228fc8bae385e448c7e32a6ae134e43cf13bbc23f902b7"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:ec6fbded0c61afe6f84e3c2a43e6d656791d95747d6d28b73eff1af64108c434"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:bfda6ee8990997a9df95c5606f3096dae65f09af7ca03a1e9ca28f088caca5cf"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b77f9f9cee87cd798f0fe26b7024344d1b03a7cd2d2cba7035f8433b13986325"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e02d3b96f2d0e4bab9ceaa30f37d4f75571e40c6272e95364bff3125a64d184"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1da38070738da53556a4b35ab67c1b9884a5dd48fa2f243db35dc14079ea3d0c"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ace43b26d88a58dcff16c20d23ff72b04d0a415f64d2820f4ff06b1166f50557"}, + {file = "grpcio_tools-1.62.3-cp37-cp37m-win_amd64.whl", hash = "sha256:350a80485e302daaa95d335a931f97b693e170e02d43767ab06552c708808950"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:c3a1ac9d394f8e229eb28eec2e04b9a6f5433fa19c9d32f1cb6066e3c5114a1d"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:11f363570dea661dde99e04a51bd108a5807b5df32a6f8bdf4860e34e94a4dbf"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9ad9950119d8ae27634e68b7663cc8d340ae535a0f80d85a55e56a6973ab1f"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c5d22b252dcef11dd1e0fbbe5bbfb9b4ae048e8880d33338215e8ccbdb03edc"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:27cd9ef5c5d68d5ed104b6dcb96fe9c66b82050e546c9e255716903c3d8f0373"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f4b1615adf67bd8bb71f3464146a6f9949972d06d21a4f5e87e73f6464d97f57"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-win32.whl", hash = "sha256:e18e15287c31baf574fcdf8251fb7f997d64e96c6ecf467906e576da0a079af6"}, + {file = "grpcio_tools-1.62.3-cp38-cp38-win_amd64.whl", hash = "sha256:6c3064610826f50bd69410c63101954676edc703e03f9e8f978a135f1aaf97c1"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:8e62cc7164b0b7c5128e637e394eb2ef3db0e61fc798e80c301de3b2379203ed"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:c8ad5cce554e2fcaf8842dee5d9462583b601a3a78f8b76a153c38c963f58c10"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec279dcf3518201fc592c65002754f58a6b542798cd7f3ecd4af086422f33f29"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c989246c2aebc13253f08be32538a4039a64e12d9c18f6d662d7aee641dc8b5"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ca4f5eeadbb57cf03317d6a2857823239a63a59cc935f5bd6cf6e8b7af7a7ecc"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0cb3a3436ac119cbd37a7d3331d9bdf85dad21a6ac233a3411dff716dcbf401e"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-win32.whl", hash = "sha256:3eae6ea76d62fcac091e1f15c2dcedf1dc3f114f8df1a972a8a0745e89f4cf61"}, + {file = "grpcio_tools-1.62.3-cp39-cp39-win_amd64.whl", hash = "sha256:eec73a005443061f4759b71a056f745e3b000dc0dc125c9f20560232dfbcbd14"}, +] + +[package.dependencies] +grpcio = ">=1.62.3" +protobuf = ">=4.21.6,<5.0dev" +setuptools = "*" + +[[package]] +name = "hf-transfer" +version = "0.1.8" +description = "Speed up file transfers with the Hugging Face Hub." +optional = false +python-versions = ">=3.7" +files = [ + {file = "hf_transfer-0.1.8-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:70858f9e94286738ed300484a45beb5cfee6a7ddac4c5886f9c6fce7823ac5ab"}, + {file = "hf_transfer-0.1.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:38adc73f0a8526319d90f7cc5dc2d5e4bb66f487a513d94b98aa6725be732e4a"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44d2f0c08198d8d899fe9d66e86aee2dd844bd7ce33888f261373fcec81d2a54"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1de2a4ef36f9e60b3d3bec00193c0aafd75771709f2ca51b9b162373f5af3d32"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e319269e3606a5ff2979296841766649ac73598a4a8eee2a968f86c8071fea5a"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f6026cf3be6a53ea42f92172f60c1c0675baaa9073f865e671b661dde5fd157"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f865c33ada5bd3650c2b46e59979f2d7755c3f517f8d0facc78576a0c7d26406"}, + {file = "hf_transfer-0.1.8-cp310-none-win32.whl", hash = "sha256:2054730e8d8ed21917c64be7199e06424b2bd08df1c43a72766afaed7992f2d3"}, + {file = "hf_transfer-0.1.8-cp310-none-win_amd64.whl", hash = "sha256:2b4f1a9446ba31170b5b1eca4e916504d18378a6b5fe959896bdac8a736a5ecb"}, + {file = "hf_transfer-0.1.8-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:e27c15fcc5869ad7e52bbc0bdec6106b288d1c463f8d2da92f28615a3b181361"}, + {file = "hf_transfer-0.1.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:871a0032d011ebc6409a73a8406b98b84ff2cd3ed7d9e1af8cdf4d660b9fab9b"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:686fa756e1e0214bb6327d33c66732c52274d94a8460beb50604ad988b391cf6"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:36a03b1b2911b0cf15b1b9d971a34b32dadcc4f2fd979aaff5979d6ce4017c34"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:079db90c81f41f4cf3227dfaaa855a9b8e9aef45bc7c2be29ce7232cd83ff881"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac08a4524127fdd14c234d4bcbe49d1c498acf5335c781714823179bcc8dc039"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837432e73cb17274a6782b6216e8ce058aa325a475dc44a5a6a753d48b86d18a"}, + {file = "hf_transfer-0.1.8-cp311-none-win32.whl", hash = "sha256:b180f9823dde35aba9bc0f1d0c04ac8a873baebd3732a7ffe4f11940abc7df0d"}, + {file = "hf_transfer-0.1.8-cp311-none-win_amd64.whl", hash = "sha256:37907d2135cebcf8b6d419bb575148d89c224f16b69357f027bd29d0e85c6529"}, + {file = "hf_transfer-0.1.8-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baf948f4f493949309cbe60529620b9b0aef854a22b6e526753364acc57c09b6"}, + {file = "hf_transfer-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bce5c8bdefa478c5d5eaa646cc4ce1df5cfe764d98572ad0c6b8773e98d49f6"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54d6f8a1a86128d651a3799e1267c343d60f81f2c565d7c5416eb8e674e4cf0e"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f79fd1b0c2ed93efb4c5f684118d7a762ecdd218e170df8208c4e13d3dcd4959"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:414df35692670683bf5623498ef9d88a8df5d77e9516515da6e2b34d1054c11f"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c9798d5f951f66b96d40a7a53910260cb5874fda56cf5944dddb7c571f37ec3"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:060c661691f85a61392e57579c80eb64b5ee277434e81fb582f605c1c8ff05d5"}, + {file = "hf_transfer-0.1.8-cp312-none-win32.whl", hash = "sha256:f7840e32379820c3e1571a480238e05ea043e970c99d2e999578004a2eb17788"}, + {file = "hf_transfer-0.1.8-cp312-none-win_amd64.whl", hash = "sha256:9a3204ec423cc5e659872e8179f8704ad9ce2abb1e6a991f8838aedf1dc07830"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09949e86ad63ee139e463fd0dfaf401515ae70445854199f61d545514c65f744"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf1a74552845b93ea972e6e7131ef54e56056aa54137e93a40faf3fbcb2442ff"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959bcb3afb4ee6f2a07031a947dba98ec0b64c001bc914fbd8fc32e13a287162"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e01eecdb8162bd61dab9090fbd9f8034dd8b5755ef727a21ca8a057f80cb91ee"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50650a38e9d31f5ad8f010e4598bf304ecd99c17162e7d93f67e031571b864ee"}, + {file = "hf_transfer-0.1.8-cp37-none-win32.whl", hash = "sha256:e29b9d1d378138f2f4eae0e93ca94af3b5d45f4532eef69f1ab97fe06f9c9d9e"}, + {file = "hf_transfer-0.1.8-cp37-none-win_amd64.whl", hash = "sha256:cfd6cef43ae883103117a371f8ebae4e7f9637bc6fb480f1be5568e2fe22a8a7"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92a68f7a0043cca8a0de4decc760dca177530944cbab502afac503bd1b2fa01a"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e3138e408179f80a5480598e32f8e1abb564915cbde4d3bc8da52811c75dc3ea"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4544d148930ad34442d43b8fa911c8479c04a95b858b1d1f91e0b7da77082fad"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a851794b9f029965664f8c3002c957fccf21685e9397ceb4f9f19c986dee8ad3"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:791aaf87c5319ac83edb6ab2994b3db19924c49d6ff667dd3d8a610b455ff70a"}, + {file = "hf_transfer-0.1.8-cp38-none-win32.whl", hash = "sha256:8f71e5d35d3a3160dcca12fdcc8119033aeacaa6a32838a7ad9f9cb1008bbe58"}, + {file = "hf_transfer-0.1.8-cp38-none-win_amd64.whl", hash = "sha256:543287b4ceb1e25501580b99690f7f0df9d3631d29306f37cbd97e918c732944"}, + {file = "hf_transfer-0.1.8-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:7ce02a18bd0bb2343e707ac85b68c946bc37623ee24150c69158f6b2b2c7a98f"}, + {file = "hf_transfer-0.1.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64d7f8dbd64ba183ed1df75d47c84e075ff666ceaa335bff1de16b09eaac5b80"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e7858694e11419ae27e542fb8fc0d0e54d46ff7768fe73bc359d70b8f5aa578"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed116cd9d1edfa32c0136d7cb8e5f1afd2b32df43c49085d428f108fc8e1c8f"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e385d0da9c6b3472ab29285d2d46c9f9903205b8d108f88a82f3f85aafae0ab"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98f75fa4b86ef15433cd907807ac77d1fb39d7e7b790bfd39c7ae9c385bf0200"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1a63ad947d2901425ac0a3ed70c3696dfde27fadb0482ed763bdd5cc946b278"}, + {file = "hf_transfer-0.1.8-cp39-none-win32.whl", hash = "sha256:3e74096915813ae842ea6a5bdf10c0fef960aa51a35a560955b3e61cdfe3db57"}, + {file = "hf_transfer-0.1.8-cp39-none-win_amd64.whl", hash = "sha256:05ea16307bf4a5eb097cbc6e5057e4eb5e080a138af23ef639fd38857723c288"}, + {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:928ff036c3e98e10dcfbdb4fcdfc4592d37a5cc8e365a7ba8dfd4337e849d675"}, + {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d49ba3ce67035f460ae1924fe2feafec155cb535eec7f31ed5109c19064cd294"}, + {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b01f5872c62cfee3ec9ca5c738818296f69f8adf84b4d8d15f2a5601d9dda339"}, + {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:659d4212d50847a5165666bf43d67727679b4f694ef9c413613cc27093136527"}, + {file = "hf_transfer-0.1.8.tar.gz", hash = "sha256:26d229468152e7a3ec12664cac86b8c2800695fd85f9c9a96677a775cc04f0b3"}, +] + +[[package]] +name = "huggingface-hub" +version = "0.26.1" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "huggingface_hub-0.26.1-py3-none-any.whl", hash = "sha256:5927a8fc64ae68859cd954b7cc29d1c8390a5e15caba6d3d349c973be8fdacf3"}, + {file = "huggingface_hub-0.26.1.tar.gz", hash = "sha256:414c0d9b769eecc86c70f9d939d0f48bb28e8461dd1130021542eff0212db890"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +hf-transfer = ["hf-transfer (>=0.1.4)"] +inference = ["aiohttp"] +quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.5.0)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors[torch]", "torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] + +[[package]] +name = "idna" +version = "3.10" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.6" +files = [ + {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, + {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, +] + +[package.extras] +all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] + +[[package]] +name = "importlib-metadata" +version = "8.5.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b"}, + {file = "importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7"}, +] + +[package.dependencies] +zipp = ">=3.20" + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +perf = ["ipython"] +test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] +type = ["pytest-mypy"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "interegular" +version = "0.3.3" +description = "a regex intersection checker" +optional = true +python-versions = ">=3.7" +files = [ + {file = "interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c"}, + {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"}, +] + +[[package]] +name = "jinja2" +version = "3.1.4" +description = "A very fast and expressive template engine." +optional = true +python-versions = ">=3.7" +files = [ + {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, + {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = true +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + +[[package]] +name = "jsonschema" +version = "4.23.0" +description = "An implementation of JSON Schema validation for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] + +[[package]] +name = "jsonschema-specifications" +version = "2024.10.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = true +python-versions = ">=3.9" +files = [ + {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"}, + {file = "jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272"}, +] + +[package.dependencies] +referencing = ">=0.31.0" + +[[package]] +name = "lark" +version = "1.2.2" +description = "a modern parsing library" +optional = true +python-versions = ">=3.8" +files = [ + {file = "lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c"}, + {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"}, +] + +[package.extras] +atomic-cache = ["atomicwrites"] +interegular = ["interegular (>=0.3.1,<0.4.0)"] +nearley = ["js2py"] +regex = ["regex"] + +[[package]] +name = "llvmlite" +version = "0.43.0" +description = "lightweight wrapper around basic LLVM functionality" +optional = true +python-versions = ">=3.9" +files = [ + {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, + {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"}, + {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"}, + {file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"}, + {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"}, + {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"}, + {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"}, +] + +[[package]] +name = "loguru" +version = "0.6.0" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ + {file = "loguru-0.6.0-py3-none-any.whl", hash = "sha256:4e2414d534a2ab57573365b3e6d0234dfb1d84b68b7f3b948e6fb743860a77c3"}, + {file = "loguru-0.6.0.tar.gz", hash = "sha256:066bd06758d0a513e9836fd9c6b5a75bfb3fd36841f4b996bc60b547a309d41c"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] + +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + +[[package]] +name = "markupsafe" +version = "3.0.2" +description = "Safely add untrusted strings to HTML/XML markup." +optional = false +python-versions = ">=3.9" +files = [ + {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-win32.whl", hash = "sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a"}, + {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, +] + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" + +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + +[[package]] +name = "moe-kernels" +version = "0.4.0" +description = "MoE kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"}, +] + +[package.dependencies] +nvidia-ml-py = "*" +torch = "*" +triton = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" + +[[package]] +name = "moe-kernels" +version = "0.4.0" +description = "MoE kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"}, +] + +[package.dependencies] +nvidia-ml-py = "*" +torch = "*" +triton = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" + +[[package]] +name = "moe-kernels" +version = "0.4.0" +description = "MoE kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"}, +] + +[package.dependencies] +nvidia-ml-py = "*" +torch = "*" +triton = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" + +[[package]] +name = "moe-kernels" +version = "0.4.0" +description = "MoE kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"}, +] + +[package.dependencies] +nvidia-ml-py = "*" +torch = "*" +triton = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" + +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = true +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + +[[package]] +name = "multidict" +version = "6.1.0" +description = "multidict implementation" +optional = false +python-versions = ">=3.8" +files = [ + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a114d03b938376557927ab23f1e950827c3b893ccb94b62fd95d430fd0e5cf53"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1c416351ee6271b2f49b56ad7f308072f6f44b37118d69c2cad94f3fa8a40d5"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b5d83030255983181005e6cfbac1617ce9746b219bc2aad52201ad121226581"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e97b5e938051226dc025ec80980c285b053ffb1e25a3db2a3aa3bc046bf7f56"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d618649d4e70ac6efcbba75be98b26ef5078faad23592f9b51ca492953012429"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10524ebd769727ac77ef2278390fb0068d83f3acb7773792a5080f2b0abf7748"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ff3827aef427c89a25cc96ded1759271a93603aba9fb977a6d264648ebf989db"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:06809f4f0f7ab7ea2cabf9caca7d79c22c0758b58a71f9d32943ae13c7ace056"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f179dee3b863ab1c59580ff60f9d99f632f34ccb38bf67a33ec6b3ecadd0fd76"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:aaed8b0562be4a0876ee3b6946f6869b7bcdb571a5d1496683505944e268b160"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3c8b88a2ccf5493b6c8da9076fb151ba106960a2df90c2633f342f120751a9e7"}, + {file = "multidict-6.1.0-cp310-cp310-win32.whl", hash = "sha256:4a9cb68166a34117d6646c0023c7b759bf197bee5ad4272f420a0141d7eb03a0"}, + {file = "multidict-6.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:20b9b5fbe0b88d0bdef2012ef7dee867f874b72528cf1d08f1d59b0e3850129d"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3efe2c2cb5763f2f1b275ad2bf7a287d3f7ebbef35648a9726e3b69284a4f3d6"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7053d3b0353a8b9de430a4f4b4268ac9a4fb3481af37dfe49825bf45ca24156"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27e5fc84ccef8dfaabb09d82b7d179c7cf1a3fbc8a966f8274fcb4ab2eb4cadb"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2b90b43e696f25c62656389d32236e049568b39320e2735d51f08fd362761b"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d83a047959d38a7ff552ff94be767b7fd79b831ad1cd9920662db05fec24fe72"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1a9dd711d0877a1ece3d2e4fea11a8e75741ca21954c919406b44e7cf971304"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec2abea24d98246b94913b76a125e855eb5c434f7c46546046372fe60f666351"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4867cafcbc6585e4b678876c489b9273b13e9fff9f6d6d66add5e15d11d926cb"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5b48204e8d955c47c55b72779802b219a39acc3ee3d0116d5080c388970b76e3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8fff389528cad1618fb4b26b95550327495462cd745d879a8c7c2115248e399"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a7a9541cd308eed5e30318430a9c74d2132e9a8cb46b901326272d780bf2d423"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:da1758c76f50c39a2efd5e9859ce7d776317eb1dd34317c8152ac9251fc574a3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c943a53e9186688b45b323602298ab727d8865d8c9ee0b17f8d62d14b56f0753"}, + {file = "multidict-6.1.0-cp311-cp311-win32.whl", hash = "sha256:90f8717cb649eea3504091e640a1b8568faad18bd4b9fcd692853a04475a4b80"}, + {file = "multidict-6.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:82176036e65644a6cc5bd619f65f6f19781e8ec2e5330f51aa9ada7504cc1926"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b04772ed465fa3cc947db808fa306d79b43e896beb677a56fb2347ca1a49c1fa"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6180c0ae073bddeb5a97a38c03f30c233e0a4d39cd86166251617d1bbd0af436"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:071120490b47aa997cca00666923a83f02c7fbb44f71cf7f136df753f7fa8761"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50b3a2710631848991d0bf7de077502e8994c804bb805aeb2925a981de58ec2e"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b58c621844d55e71c1b7f7c498ce5aa6985d743a1a59034c57a905b3f153c1ef"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b6d90641869892caa9ca42ff913f7ff1c5ece06474fbd32fb2cf6834726c95"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b820514bfc0b98a30e3d85462084779900347e4d49267f747ff54060cc33925"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10a9b09aba0c5b48c53761b7c720aaaf7cf236d5fe394cd399c7ba662d5f9966"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e16bf3e5fc9f44632affb159d30a437bfe286ce9e02754759be5536b169b305"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:76f364861c3bfc98cbbcbd402d83454ed9e01a5224bb3a28bf70002a230f73e2"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:820c661588bd01a0aa62a1283f20d2be4281b086f80dad9e955e690c75fb54a2"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:0e5f362e895bc5b9e67fe6e4ded2492d8124bdf817827f33c5b46c2fe3ffaca6"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ec660d19bbc671e3a6443325f07263be452c453ac9e512f5eb935e7d4ac28b3"}, + {file = "multidict-6.1.0-cp312-cp312-win32.whl", hash = "sha256:58130ecf8f7b8112cdb841486404f1282b9c86ccb30d3519faf301b2e5659133"}, + {file = "multidict-6.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:188215fc0aafb8e03341995e7c4797860181562380f81ed0a87ff455b70bf1f1"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d569388c381b24671589335a3be6e1d45546c2988c2ebe30fdcada8457a31008"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:052e10d2d37810b99cc170b785945421141bf7bb7d2f8799d431e7db229c385f"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f90c822a402cb865e396a504f9fc8173ef34212a342d92e362ca498cad308e28"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b225d95519a5bf73860323e633a664b0d85ad3d5bede6d30d95b35d4dfe8805b"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:23bfd518810af7de1116313ebd9092cb9aa629beb12f6ed631ad53356ed6b86c"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c09fcfdccdd0b57867577b719c69e347a436b86cd83747f179dbf0cc0d4c1f3"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf6bea52ec97e95560af5ae576bdac3aa3aae0b6758c6efa115236d9e07dae44"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57feec87371dbb3520da6192213c7d6fc892d5589a93db548331954de8248fd2"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0c3f390dc53279cbc8ba976e5f8035eab997829066756d811616b652b00a23a3"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:59bfeae4b25ec05b34f1956eaa1cb38032282cd4dfabc5056d0a1ec4d696d3aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b2f59caeaf7632cc633b5cf6fc449372b83bbdf0da4ae04d5be36118e46cc0aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:37bb93b2178e02b7b618893990941900fd25b6b9ac0fa49931a40aecdf083fe4"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4e9f48f58c2c523d5a06faea47866cd35b32655c46b443f163d08c6d0ddb17d6"}, + {file = "multidict-6.1.0-cp313-cp313-win32.whl", hash = "sha256:3a37ffb35399029b45c6cc33640a92bef403c9fd388acce75cdc88f58bd19a81"}, + {file = "multidict-6.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:e9aa71e15d9d9beaad2c6b9319edcdc0a49a43ef5c0a4c8265ca9ee7d6c67774"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:db7457bac39421addd0c8449933ac32d8042aae84a14911a757ae6ca3eef1392"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d094ddec350a2fb899fec68d8353c78233debde9b7d8b4beeafa70825f1c281a"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5845c1fd4866bb5dd3125d89b90e57ed3138241540897de748cdf19de8a2fca2"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9079dfc6a70abe341f521f78405b8949f96db48da98aeb43f9907f342f627cdc"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3914f5aaa0f36d5d60e8ece6a308ee1c9784cd75ec8151062614657a114c4478"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c08be4f460903e5a9d0f76818db3250f12e9c344e79314d1d570fc69d7f4eae4"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d093be959277cb7dee84b801eb1af388b6ad3ca6a6b6bf1ed7585895789d027d"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3702ea6872c5a2a4eeefa6ffd36b042e9773f05b1f37ae3ef7264b1163c2dcf6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2090f6a85cafc5b2db085124d752757c9d251548cedabe9bd31afe6363e0aff2"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:f67f217af4b1ff66c68a87318012de788dd95fcfeb24cc889011f4e1c7454dfd"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:189f652a87e876098bbc67b4da1049afb5f5dfbaa310dd67c594b01c10388db6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:6bb5992037f7a9eff7991ebe4273ea7f51f1c1c511e6a2ce511d0e7bdb754492"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f4c2b9e770c4e393876e35a7046879d195cd123b4f116d299d442b335bcd"}, + {file = "multidict-6.1.0-cp38-cp38-win32.whl", hash = "sha256:e27bbb6d14416713a8bd7aaa1313c0fc8d44ee48d74497a0ff4c3a1b6ccb5167"}, + {file = "multidict-6.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:22f3105d4fb15c8f57ff3959a58fcab6ce36814486500cd7485651230ad4d4ef"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4e18b656c5e844539d506a0a06432274d7bd52a7487e6828c63a63d69185626c"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a185f876e69897a6f3325c3f19f26a297fa058c5e456bfcff8015e9a27e83ae1"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab7c4ceb38d91570a650dba194e1ca87c2b543488fe9309b4212694174fd539c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e617fb6b0b6953fffd762669610c1c4ffd05632c138d61ac7e14ad187870669c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:16e5f4bf4e603eb1fdd5d8180f1a25f30056f22e55ce51fb3d6ad4ab29f7d96f"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4c035da3f544b1882bac24115f3e2e8760f10a0107614fc9839fd232200b875"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:957cf8e4b6e123a9eea554fa7ebc85674674b713551de587eb318a2df3e00255"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:483a6aea59cb89904e1ceabd2b47368b5600fb7de78a6e4a2c2987b2d256cf30"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:87701f25a2352e5bf7454caa64757642734da9f6b11384c1f9d1a8e699758057"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:682b987361e5fd7a139ed565e30d81fd81e9629acc7d925a205366877d8c8657"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce2186a7df133a9c895dea3331ddc5ddad42cdd0d1ea2f0a51e5d161e4762f28"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9f636b730f7e8cb19feb87094949ba54ee5357440b9658b2a32a5ce4bce53972"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:73eae06aa53af2ea5270cc066dcaf02cc60d2994bbb2c4ef5764949257d10f43"}, + {file = "multidict-6.1.0-cp39-cp39-win32.whl", hash = "sha256:1ca0083e80e791cffc6efce7660ad24af66c8d4079d2a750b29001b53ff59ada"}, + {file = "multidict-6.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:aa466da5b15ccea564bdab9c89175c762bc12825f4659c11227f515cee76fa4a"}, + {file = "multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506"}, + {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} + +[[package]] +name = "multiprocess" +version = "0.70.15" +description = "better multiprocessing and multithreading in Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"}, + {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"}, + {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"}, + {file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"}, + {file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"}, + {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"}, + {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"}, +] + +[package.dependencies] +dill = ">=0.3.7" + +[[package]] +name = "nest-asyncio" +version = "1.6.0" +description = "Patch asyncio to allow nested event loops" +optional = true +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + +[[package]] +name = "networkx" +version = "3.2.1" +description = "Python package for creating and manipulating graphs and networks" +optional = true +python-versions = ">=3.9" +files = [ + {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, + {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, +] + +[package.extras] +default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + +[[package]] +name = "numba" +version = "0.60.0" +description = "compiling Python code using LLVM" +optional = true +python-versions = ">=3.9" +files = [ + {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, + {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, + {file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"}, + {file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"}, + {file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"}, + {file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"}, + {file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"}, + {file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"}, + {file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"}, + {file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"}, + {file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"}, + {file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"}, + {file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"}, + {file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"}, + {file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"}, + {file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"}, + {file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"}, + {file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"}, + {file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"}, + {file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"}, + {file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"}, +] + +[package.dependencies] +llvmlite = "==0.43.*" +numpy = ">=1.22,<2.1" + +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.1.0.70" +description = "cuDNN runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, + {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-ml-py" +version = "12.560.30" +description = "Python Bindings for the NVIDIA Management Library" +optional = true +python-versions = "*" +files = [ + {file = "nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97"}, + {file = "nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73"}, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.20.5" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.6.77" +description = "Nvidia JIT LTO Library" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:3bf10d85bb1801e9c894c6e197e44dd137d2a0a9e43f8450e9ad13f2df0dd52d"}, + {file = "nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9ae346d16203ae4ea513be416495167a0101d33d2d14935aa9c1829a3fb45142"}, + {file = "nvidia_nvjitlink_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:410718cd44962bed862a31dd0318620f6f9a8b28a6291967bcfcb446a6516771"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + +[[package]] +name = "opentelemetry-api" +version = "1.25.0" +description = "OpenTelemetry Python API" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_api-1.25.0-py3-none-any.whl", hash = "sha256:757fa1aa020a0f8fa139f8959e53dec2051cc26b832e76fa839a6d76ecefd737"}, + {file = "opentelemetry_api-1.25.0.tar.gz", hash = "sha256:77c4985f62f2614e42ce77ee4c9da5fa5f0bc1e1821085e9a47533a9323ae869"}, +] + +[package.dependencies] +deprecated = ">=1.2.6" +importlib-metadata = ">=6.0,<=7.1" + +[[package]] +name = "opentelemetry-exporter-otlp" +version = "1.25.0" +description = "OpenTelemetry Collector Exporters" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_exporter_otlp-1.25.0-py3-none-any.whl", hash = "sha256:d67a831757014a3bc3174e4cd629ae1493b7ba8d189e8a007003cacb9f1a6b60"}, + {file = "opentelemetry_exporter_otlp-1.25.0.tar.gz", hash = "sha256:ce03199c1680a845f82e12c0a6a8f61036048c07ec7a0bd943142aca8fa6ced0"}, +] + +[package.dependencies] +opentelemetry-exporter-otlp-proto-grpc = "1.25.0" +opentelemetry-exporter-otlp-proto-http = "1.25.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.25.0" +description = "OpenTelemetry Protobuf encoding" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_exporter_otlp_proto_common-1.25.0-py3-none-any.whl", hash = "sha256:15637b7d580c2675f70246563363775b4e6de947871e01d0f4e3881d1848d693"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.25.0.tar.gz", hash = "sha256:c93f4e30da4eee02bacd1e004eb82ce4da143a2f8e15b987a9f603e0a85407d3"}, +] + +[package.dependencies] +opentelemetry-proto = "1.25.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-grpc" +version = "1.25.0" +description = "OpenTelemetry Collector Protobuf over gRPC Exporter" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0-py3-none-any.whl", hash = "sha256:3131028f0c0a155a64c430ca600fd658e8e37043cb13209f0109db5c1a3e4eb4"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0.tar.gz", hash = "sha256:c0b1661415acec5af87625587efa1ccab68b873745ca0ee96b69bb1042087eac"}, +] + +[package.dependencies] +deprecated = ">=1.2.6" +googleapis-common-protos = ">=1.52,<2.0" +grpcio = ">=1.0.0,<2.0.0" +opentelemetry-api = ">=1.15,<2.0" +opentelemetry-exporter-otlp-proto-common = "1.25.0" +opentelemetry-proto = "1.25.0" +opentelemetry-sdk = ">=1.25.0,<1.26.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.25.0" +description = "OpenTelemetry Collector Protobuf over HTTP Exporter" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_exporter_otlp_proto_http-1.25.0-py3-none-any.whl", hash = "sha256:2eca686ee11b27acd28198b3ea5e5863a53d1266b91cda47c839d95d5e0541a6"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.25.0.tar.gz", hash = "sha256:9f8723859e37c75183ea7afa73a3542f01d0fd274a5b97487ea24cb683d7d684"}, +] + +[package.dependencies] +deprecated = ">=1.2.6" +googleapis-common-protos = ">=1.52,<2.0" +opentelemetry-api = ">=1.15,<2.0" +opentelemetry-exporter-otlp-proto-common = "1.25.0" +opentelemetry-proto = "1.25.0" +opentelemetry-sdk = ">=1.25.0,<1.26.0" +requests = ">=2.7,<3.0" + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.46b0" +description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_instrumentation-0.46b0-py3-none-any.whl", hash = "sha256:89cd721b9c18c014ca848ccd11181e6b3fd3f6c7669e35d59c48dc527408c18b"}, + {file = "opentelemetry_instrumentation-0.46b0.tar.gz", hash = "sha256:974e0888fb2a1e01c38fbacc9483d024bb1132aad92d6d24e2e5543887a7adda"}, +] + +[package.dependencies] +opentelemetry-api = ">=1.4,<2.0" +setuptools = ">=16.0" +wrapt = ">=1.0.0,<2.0.0" + +[[package]] +name = "opentelemetry-instrumentation-grpc" +version = "0.46b0" +description = "OpenTelemetry gRPC instrumentation" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_instrumentation_grpc-0.46b0-py3-none-any.whl", hash = "sha256:cccfb28db07c28849709f2dcf330237fae0fca9f86971bfce27b28bb9a8b0577"}, + {file = "opentelemetry_instrumentation_grpc-0.46b0.tar.gz", hash = "sha256:9c5738592cf82672805099826b676d352324b54e03f9ac72a1368ba0605d6ff9"}, +] + +[package.dependencies] +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-instrumentation = "0.46b0" +opentelemetry-semantic-conventions = "0.46b0" +wrapt = ">=1.0.0,<2.0.0" + +[package.extras] +instruments = ["grpcio (>=1.27,<2.0)"] + +[[package]] +name = "opentelemetry-proto" +version = "1.25.0" +description = "OpenTelemetry Python Proto" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_proto-1.25.0-py3-none-any.whl", hash = "sha256:f07e3341c78d835d9b86665903b199893befa5e98866f63d22b00d0b7ca4972f"}, + {file = "opentelemetry_proto-1.25.0.tar.gz", hash = "sha256:35b6ef9dc4a9f7853ecc5006738ad40443701e52c26099e197895cbda8b815a3"}, +] + +[package.dependencies] +protobuf = ">=3.19,<5.0" + +[[package]] +name = "opentelemetry-sdk" +version = "1.25.0" +description = "OpenTelemetry Python SDK" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opentelemetry_sdk-1.25.0-py3-none-any.whl", hash = "sha256:d97ff7ec4b351692e9d5a15af570c693b8715ad78b8aafbec5c7100fe966b4c9"}, + {file = "opentelemetry_sdk-1.25.0.tar.gz", hash = "sha256:ce7fc319c57707ef5bf8b74fb9f8ebdb8bfafbe11898410e0d2a761d08a98ec7"}, +] + +[package.dependencies] +opentelemetry-api = "1.25.0" +opentelemetry-semantic-conventions = "0.46b0" +typing-extensions = ">=3.7.4" + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.46b0" +description = "OpenTelemetry Semantic Conventions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_semantic_conventions-0.36b0-py3-none-any.whl", hash = "sha256:adc05635e87b9d3e007c9f530eed487fc3ef2177d02f82f674f28ebf9aff8243"}, + {file = "opentelemetry_semantic_conventions-0.36b0.tar.gz", hash = "sha256:829dc221795467d98b773c04096e29be038d77526dc8d6ac76f546fb6279bf01"}, +] + +[[package]] +name = "optimum" +version = "1.23.2" +description = "Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality." +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "optimum-1.23.2-py3-none-any.whl", hash = "sha256:1b6f450b2c0203377b76cf4dd1e05e45b116d8069612fe2ad73df8b5743d5a57"}, + {file = "optimum-1.23.2.tar.gz", hash = "sha256:245415add6e621c6c5b7c29eee968bc325fc1a982d9c42eb83d97b7ec844ef39"}, +] + +[package.dependencies] +coloredlogs = "*" +datasets = "*" +huggingface-hub = ">=0.8.0" +numpy = "*" +packaging = "*" +sympy = "*" +torch = ">=1.11" +transformers = {version = ">=4.29", extras = ["sentencepiece"]} + +[package.extras] +amd = ["optimum-amd"] +benchmark = ["evaluate (>=0.2.0)", "optuna", "scikit-learn", "seqeval", "torchvision", "tqdm"] +dev = ["Pillow", "accelerate", "black (>=23.1,<24.0)", "diffusers (>=0.17.0)", "einops", "invisible-watermark", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "ruff (==0.1.5)", "sacremoses", "scikit-learn", "timm", "torchaudio", "torchvision"] +diffusers = ["diffusers"] +doc-build = ["accelerate"] +exporters = ["onnx", "onnxruntime", "timm", "transformers (<4.46.0)"] +exporters-gpu = ["onnx", "onnxruntime-gpu", "timm", "transformers (<4.46.0)"] +exporters-tf = ["datasets (<=2.16)", "h5py", "numpy (<1.24.0)", "onnx", "onnxruntime", "tensorflow (>=2.4,<=2.12.1)", "tf2onnx", "timm", "transformers[sentencepiece] (>=4.26,<4.38)"] +furiosa = ["optimum-furiosa"] +graphcore = ["optimum-graphcore"] +habana = ["optimum-habana", "transformers (>=4.45.0,<4.46.0)"] +intel = ["optimum-intel (>=1.18.0)"] +ipex = ["optimum-intel[ipex] (>=1.18.0)"] +neural-compressor = ["optimum-intel[neural-compressor] (>=1.18.0)"] +neuron = ["optimum-neuron[neuron] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"] +neuronx = ["optimum-neuron[neuronx] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"] +nncf = ["optimum-intel[nncf] (>=1.18.0)"] +onnxruntime = ["datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime (>=1.11.0)", "protobuf (>=3.20.1)", "transformers (<4.46.0)"] +onnxruntime-gpu = ["accelerate", "datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime-gpu (>=1.11.0)", "protobuf (>=3.20.1)", "transformers (<4.46.0)"] +openvino = ["optimum-intel[openvino] (>=1.18.0)"] +quality = ["black (>=23.1,<24.0)", "ruff (==0.1.5)"] +quanto = ["optimum-quanto (>=0.2.4)"] +tests = ["Pillow", "accelerate", "diffusers (>=0.17.0)", "einops", "invisible-watermark", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "sacremoses", "scikit-learn", "timm", "torchaudio", "torchvision"] + +[[package]] +name = "optimum-habana" +version = "1.14.1" +description = "Optimum Habana is the interface between the Hugging Face Transformers and Diffusers libraries and Habana's Gaudi processor (HPU). It provides a set of tools enabling easy model loading, training and inference on single- and multi-HPU settings for different downstream tasks." +optional = false +python-versions = "*" +files = [ + {file = "optimum_habana-1.14.1-py3-none-any.whl", hash = "sha256:557dd06a5fa2597b1487384a354c3a1618390c35e6b088f5d7a80c0611fe2bfd"}, + {file = "optimum_habana-1.14.1.tar.gz", hash = "sha256:627626860b82452b75f73087df3e14004278ed30dacf767c100f1324675c8120"}, +] + +[package.dependencies] +accelerate = ">=0.33.0,<0.34.0" +diffusers = "0.29.2" +huggingface-hub = ">=0.24.7" +optimum = "*" +sentence-transformers = {version = "3.0.1", extras = ["train"]} +torch = "*" +transformers = ">=4.45.2,<4.46.0" + +[package.extras] +quality = ["hf-doc-builder", "ruff"] +tests = ["GitPython", "datasets", "optuna", "parameterized", "peft", "psutil", "pytest (<8.0.0)", "safetensors", "scipy", "sentencepiece", "timm", "torchsde"] + +[[package]] +name = "outlines" +version = "0.0.34" +description = "Probabilistic Generative Model Programming" +optional = true +python-versions = ">=3.8" +files = [ + {file = "outlines-0.0.34-py3-none-any.whl", hash = "sha256:911588a7e64a4f193b97fb4c501d98ccfd4e95a98f6a3ada67a280bf0c373c50"}, + {file = "outlines-0.0.34.tar.gz", hash = "sha256:594e7204c770b47a62eb5c2ba7d25ea0ab2e16882b5f04556712a0228d3d3309"}, +] + +[package.dependencies] +cloudpickle = "*" +diskcache = "*" +interegular = "*" +jinja2 = "*" +joblib = "*" +jsonschema = "*" +lark = "*" +nest-asyncio = "*" +numba = "*" +numpy = "*" +pydantic = ">=2.0" +referencing = "*" +requests = "*" +scipy = "*" +torch = ">=2.1.0" +transformers = "*" + +[package.extras] +serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"] +test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python (>=0.2.42)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] + +[[package]] +name = "packaging" +version = "24.1" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, +] + +[[package]] +name = "pandas" +version = "2.2.3" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = true +python-versions = ">=3.9" +files = [ + {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, + {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, + {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc6b93f9b966093cb0fd62ff1a7e4c09e6d546ad7c1de191767baffc57628f39"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5dbca4c1acd72e8eeef4753eeca07de9b1db4f398669d5994086f788a5d7cc30"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8cd6d7cc958a3910f934ea8dbdf17b2364827bb4dafc38ce6eef6bb3d65ff09c"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99df71520d25fade9db7c1076ac94eb994f4d2673ef2aa2e86ee039b6746d20c"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31d0ced62d4ea3e231a9f228366919a5ea0b07440d9d4dac345376fd8e1477ea"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7eee9e7cea6adf3e3d24e304ac6b8300646e2a5d1cd3a3c2abed9101b0846761"}, + {file = "pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.7" + +[package.extras] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] + +[[package]] +name = "peft" +version = "0.10.0" +description = "Parameter-Efficient Fine-Tuning (PEFT)" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "peft-0.10.0-py3-none-any.whl", hash = "sha256:d5249c97e818d3e31f92553c73c2953acd0ec12649b8b749afff7152cbc86cbb"}, + {file = "peft-0.10.0.tar.gz", hash = "sha256:36a7628c15f88d37abb26cfc74c22468f9037ee02e9c9b65de943cfe7c672049"}, +] + +[package.dependencies] +accelerate = ">=0.21.0" +huggingface-hub = ">=0.17.0" +numpy = ">=1.17" +packaging = ">=20.0" +psutil = "*" +pyyaml = "*" +safetensors = "*" +torch = ">=1.13.0" +tqdm = "*" +transformers = "*" + +[package.extras] +dev = ["black", "hf-doc-builder", "ruff (>=0.2.1,<0.3.0)"] +docs-specific = ["black", "hf-doc-builder"] +quality = ["black", "hf-doc-builder", "ruff (>=0.2.1,<0.3.0)"] +test = ["black", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.2.1,<0.3.0)", "scipy"] + +[[package]] +name = "pillow" +version = "11.0.0" +description = "Python Imaging Library (Fork)" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pillow-11.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947"}, + {file = "pillow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97"}, + {file = "pillow-11.0.0-cp310-cp310-win32.whl", hash = "sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50"}, + {file = "pillow-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c"}, + {file = "pillow-11.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9"}, + {file = "pillow-11.0.0-cp311-cp311-win32.whl", hash = "sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5"}, + {file = "pillow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291"}, + {file = "pillow-11.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc"}, + {file = "pillow-11.0.0-cp312-cp312-win32.whl", hash = "sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6"}, + {file = "pillow-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47"}, + {file = "pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb"}, + {file = "pillow-11.0.0-cp313-cp313-win32.whl", hash = "sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798"}, + {file = "pillow-11.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de"}, + {file = "pillow-11.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a"}, + {file = "pillow-11.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8"}, + {file = "pillow-11.0.0-cp313-cp313t-win32.whl", hash = "sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8"}, + {file = "pillow-11.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904"}, + {file = "pillow-11.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae"}, + {file = "pillow-11.0.0-cp39-cp39-win32.whl", hash = "sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4"}, + {file = "pillow-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd"}, + {file = "pillow-11.0.0-cp39-cp39-win_arm64.whl", hash = "sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944"}, + {file = "pillow-11.0.0.tar.gz", hash = "sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +typing = ["typing-extensions"] +xmp = ["defusedxml"] + +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "prometheus-client" +version = "0.20.0" +description = "Python client for the Prometheus monitoring system." +optional = false +python-versions = ">=3.8" +files = [ + {file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"}, + {file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"}, +] + +[package.extras] +twisted = ["twisted"] + +[[package]] +name = "propcache" +version = "0.2.0" +description = "Accelerated property cache" +optional = false +python-versions = ">=3.8" +files = [ + {file = "propcache-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c5869b8fd70b81835a6f187c5fdbe67917a04d7e52b6e7cc4e5fe39d55c39d58"}, + {file = "propcache-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:952e0d9d07609d9c5be361f33b0d6d650cd2bae393aabb11d9b719364521984b"}, + {file = "propcache-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:33ac8f098df0585c0b53009f039dfd913b38c1d2edafed0cedcc0c32a05aa110"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97e48e8875e6c13909c800fa344cd54cc4b2b0db1d5f911f840458a500fde2c2"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:388f3217649d6d59292b722d940d4d2e1e6a7003259eb835724092a1cca0203a"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f571aea50ba5623c308aa146eb650eebf7dbe0fd8c5d946e28343cb3b5aad577"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3dfafb44f7bb35c0c06eda6b2ab4bfd58f02729e7c4045e179f9a861b07c9850"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3ebe9a75be7ab0b7da2464a77bb27febcb4fab46a34f9288f39d74833db7f61"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d2f0d0f976985f85dfb5f3d685697ef769faa6b71993b46b295cdbbd6be8cc37"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a3dc1a4b165283bd865e8f8cb5f0c64c05001e0718ed06250d8cac9bec115b48"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9e0f07b42d2a50c7dd2d8675d50f7343d998c64008f1da5fef888396b7f84630"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e63e3e1e0271f374ed489ff5ee73d4b6e7c60710e1f76af5f0e1a6117cd26394"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:56bb5c98f058a41bb58eead194b4db8c05b088c93d94d5161728515bd52b052b"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7665f04d0c7f26ff8bb534e1c65068409bf4687aa2534faf7104d7182debb336"}, + {file = "propcache-0.2.0-cp310-cp310-win32.whl", hash = "sha256:7cf18abf9764746b9c8704774d8b06714bcb0a63641518a3a89c7f85cc02c2ad"}, + {file = "propcache-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:cfac69017ef97db2438efb854edf24f5a29fd09a536ff3a992b75990720cdc99"}, + {file = "propcache-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:63f13bf09cc3336eb04a837490b8f332e0db41da66995c9fd1ba04552e516354"}, + {file = "propcache-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:608cce1da6f2672a56b24a015b42db4ac612ee709f3d29f27a00c943d9e851de"}, + {file = "propcache-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:466c219deee4536fbc83c08d09115249db301550625c7fef1c5563a584c9bc87"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc2db02409338bf36590aa985a461b2c96fce91f8e7e0f14c50c5fcc4f229016"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6ed8db0a556343d566a5c124ee483ae113acc9a557a807d439bcecc44e7dfbb"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91997d9cb4a325b60d4e3f20967f8eb08dfcb32b22554d5ef78e6fd1dda743a2"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c7dde9e533c0a49d802b4f3f218fa9ad0a1ce21f2c2eb80d5216565202acab4"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffcad6c564fe6b9b8916c1aefbb37a362deebf9394bd2974e9d84232e3e08504"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:97a58a28bcf63284e8b4d7b460cbee1edaab24634e82059c7b8c09e65284f178"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:945db8ee295d3af9dbdbb698cce9bbc5c59b5c3fe328bbc4387f59a8a35f998d"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:39e104da444a34830751715f45ef9fc537475ba21b7f1f5b0f4d71a3b60d7fe2"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:c5ecca8f9bab618340c8e848d340baf68bcd8ad90a8ecd7a4524a81c1764b3db"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:c436130cc779806bdf5d5fae0d848713105472b8566b75ff70048c47d3961c5b"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:191db28dc6dcd29d1a3e063c3be0b40688ed76434622c53a284e5427565bbd9b"}, + {file = "propcache-0.2.0-cp311-cp311-win32.whl", hash = "sha256:5f2564ec89058ee7c7989a7b719115bdfe2a2fb8e7a4543b8d1c0cc4cf6478c1"}, + {file = "propcache-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:6e2e54267980349b723cff366d1e29b138b9a60fa376664a157a342689553f71"}, + {file = "propcache-0.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:2ee7606193fb267be4b2e3b32714f2d58cad27217638db98a60f9efb5efeccc2"}, + {file = "propcache-0.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:91ee8fc02ca52e24bcb77b234f22afc03288e1dafbb1f88fe24db308910c4ac7"}, + {file = "propcache-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e900bad2a8456d00a113cad8c13343f3b1f327534e3589acc2219729237a2e8"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f52a68c21363c45297aca15561812d542f8fc683c85201df0bebe209e349f793"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e41d67757ff4fbc8ef2af99b338bfb955010444b92929e9e55a6d4dcc3c4f09"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a64e32f8bd94c105cc27f42d3b658902b5bcc947ece3c8fe7bc1b05982f60e89"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55346705687dbd7ef0d77883ab4f6fabc48232f587925bdaf95219bae072491e"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00181262b17e517df2cd85656fcd6b4e70946fe62cd625b9d74ac9977b64d8d9"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6994984550eaf25dd7fc7bd1b700ff45c894149341725bb4edc67f0ffa94efa4"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:56295eb1e5f3aecd516d91b00cfd8bf3a13991de5a479df9e27dd569ea23959c"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:439e76255daa0f8151d3cb325f6dd4a3e93043e6403e6491813bcaaaa8733887"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f6475a1b2ecb310c98c28d271a30df74f9dd436ee46d09236a6b750a7599ce57"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3444cdba6628accf384e349014084b1cacd866fbb88433cd9d279d90a54e0b23"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4a9d9b4d0a9b38d1c391bb4ad24aa65f306c6f01b512e10a8a34a2dc5675d348"}, + {file = "propcache-0.2.0-cp312-cp312-win32.whl", hash = "sha256:69d3a98eebae99a420d4b28756c8ce6ea5a29291baf2dc9ff9414b42676f61d5"}, + {file = "propcache-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ad9c9b99b05f163109466638bd30ada1722abb01bbb85c739c50b6dc11f92dc3"}, + {file = "propcache-0.2.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ecddc221a077a8132cf7c747d5352a15ed763b674c0448d811f408bf803d9ad7"}, + {file = "propcache-0.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0e53cb83fdd61cbd67202735e6a6687a7b491c8742dfc39c9e01e80354956763"}, + {file = "propcache-0.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92fe151145a990c22cbccf9ae15cae8ae9eddabfc949a219c9f667877e40853d"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6a21ef516d36909931a2967621eecb256018aeb11fc48656e3257e73e2e247a"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3f88a4095e913f98988f5b338c1d4d5d07dbb0b6bad19892fd447484e483ba6b"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a5b3bb545ead161be780ee85a2b54fdf7092815995661947812dde94a40f6fb"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67aeb72e0f482709991aa91345a831d0b707d16b0257e8ef88a2ad246a7280bf"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c997f8c44ec9b9b0bcbf2d422cc00a1d9b9c681f56efa6ca149a941e5560da2"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2a66df3d4992bc1d725b9aa803e8c5a66c010c65c741ad901e260ece77f58d2f"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:3ebbcf2a07621f29638799828b8d8668c421bfb94c6cb04269130d8de4fb7136"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1235c01ddaa80da8235741e80815ce381c5267f96cc49b1477fdcf8c047ef325"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3947483a381259c06921612550867b37d22e1df6d6d7e8361264b6d037595f44"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d5bed7f9805cc29c780f3aee05de3262ee7ce1f47083cfe9f77471e9d6777e83"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4a91d44379f45f5e540971d41e4626dacd7f01004826a18cb048e7da7e96544"}, + {file = "propcache-0.2.0-cp313-cp313-win32.whl", hash = "sha256:f902804113e032e2cdf8c71015651c97af6418363bea8d78dc0911d56c335032"}, + {file = "propcache-0.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:8f188cfcc64fb1266f4684206c9de0e80f54622c3f22a910cbd200478aeae61e"}, + {file = "propcache-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:53d1bd3f979ed529f0805dd35ddaca330f80a9a6d90bc0121d2ff398f8ed8861"}, + {file = "propcache-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:83928404adf8fb3d26793665633ea79b7361efa0287dfbd372a7e74311d51ee6"}, + {file = "propcache-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77a86c261679ea5f3896ec060be9dc8e365788248cc1e049632a1be682442063"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:218db2a3c297a3768c11a34812e63b3ac1c3234c3a086def9c0fee50d35add1f"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7735e82e3498c27bcb2d17cb65d62c14f1100b71723b68362872bca7d0913d90"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:20a617c776f520c3875cf4511e0d1db847a076d720714ae35ffe0df3e440be68"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67b69535c870670c9f9b14a75d28baa32221d06f6b6fa6f77a0a13c5a7b0a5b9"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4569158070180c3855e9c0791c56be3ceeb192defa2cdf6a3f39e54319e56b89"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:db47514ffdbd91ccdc7e6f8407aac4ee94cc871b15b577c1c324236b013ddd04"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_armv7l.whl", hash = "sha256:2a60ad3e2553a74168d275a0ef35e8c0a965448ffbc3b300ab3a5bb9956c2162"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:662dd62358bdeaca0aee5761de8727cfd6861432e3bb828dc2a693aa0471a563"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:25a1f88b471b3bc911d18b935ecb7115dff3a192b6fef46f0bfaf71ff4f12418"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:f60f0ac7005b9f5a6091009b09a419ace1610e163fa5deaba5ce3484341840e7"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:74acd6e291f885678631b7ebc85d2d4aec458dd849b8c841b57ef04047833bed"}, + {file = "propcache-0.2.0-cp38-cp38-win32.whl", hash = "sha256:d9b6ddac6408194e934002a69bcaadbc88c10b5f38fb9307779d1c629181815d"}, + {file = "propcache-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:676135dcf3262c9c5081cc8f19ad55c8a64e3f7282a21266d05544450bffc3a5"}, + {file = "propcache-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:25c8d773a62ce0451b020c7b29a35cfbc05de8b291163a7a0f3b7904f27253e6"}, + {file = "propcache-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:375a12d7556d462dc64d70475a9ee5982465fbb3d2b364f16b86ba9135793638"}, + {file = "propcache-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1ec43d76b9677637a89d6ab86e1fef70d739217fefa208c65352ecf0282be957"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f45eec587dafd4b2d41ac189c2156461ebd0c1082d2fe7013571598abb8505d1"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc092ba439d91df90aea38168e11f75c655880c12782facf5cf9c00f3d42b562"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa1076244f54bb76e65e22cb6910365779d5c3d71d1f18b275f1dfc7b0d71b4d"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:682a7c79a2fbf40f5dbb1eb6bfe2cd865376deeac65acf9beb607505dced9e12"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e40876731f99b6f3c897b66b803c9e1c07a989b366c6b5b475fafd1f7ba3fb8"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:363ea8cd3c5cb6679f1c2f5f1f9669587361c062e4899fce56758efa928728f8"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:140fbf08ab3588b3468932974a9331aff43c0ab8a2ec2c608b6d7d1756dbb6cb"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e70fac33e8b4ac63dfc4c956fd7d85a0b1139adcfc0d964ce288b7c527537fea"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:b33d7a286c0dc1a15f5fc864cc48ae92a846df287ceac2dd499926c3801054a6"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:f6d5749fdd33d90e34c2efb174c7e236829147a2713334d708746e94c4bde40d"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22aa8f2272d81d9317ff5756bb108021a056805ce63dd3630e27d042c8092798"}, + {file = "propcache-0.2.0-cp39-cp39-win32.whl", hash = "sha256:73e4b40ea0eda421b115248d7e79b59214411109a5bc47d0d48e4c73e3b8fcf9"}, + {file = "propcache-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:9517d5e9e0731957468c29dbfd0f976736a0e55afaea843726e887f36fe017df"}, + {file = "propcache-0.2.0-py3-none-any.whl", hash = "sha256:2ccc28197af5313706511fab3a8b66dcd6da067a1331372c82ea1cb74285e036"}, + {file = "propcache-0.2.0.tar.gz", hash = "sha256:df81779732feb9d01e5d513fad0122efb3d53bbc75f61b2a4f29a020bc985e70"}, +] + +[[package]] +name = "protobuf" +version = "4.25.5" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-4.25.5-cp310-abi3-win32.whl", hash = "sha256:5e61fd921603f58d2f5acb2806a929b4675f8874ff5f330b7d6f7e2e784bbcd8"}, + {file = "protobuf-4.25.5-cp310-abi3-win_amd64.whl", hash = "sha256:4be0571adcbe712b282a330c6e89eae24281344429ae95c6d85e79e84780f5ea"}, + {file = "protobuf-4.25.5-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b2fde3d805354df675ea4c7c6338c1aecd254dfc9925e88c6d31a2bcb97eb173"}, + {file = "protobuf-4.25.5-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:919ad92d9b0310070f8356c24b855c98df2b8bd207ebc1c0c6fcc9ab1e007f3d"}, + {file = "protobuf-4.25.5-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:fe14e16c22be926d3abfcb500e60cab068baf10b542b8c858fa27e098123e331"}, + {file = "protobuf-4.25.5-cp38-cp38-win32.whl", hash = "sha256:98d8d8aa50de6a2747efd9cceba361c9034050ecce3e09136f90de37ddba66e1"}, + {file = "protobuf-4.25.5-cp38-cp38-win_amd64.whl", hash = "sha256:b0234dd5a03049e4ddd94b93400b67803c823cfc405689688f59b34e0742381a"}, + {file = "protobuf-4.25.5-cp39-cp39-win32.whl", hash = "sha256:abe32aad8561aa7cc94fc7ba4fdef646e576983edb94a73381b03c53728a626f"}, + {file = "protobuf-4.25.5-cp39-cp39-win_amd64.whl", hash = "sha256:7a183f592dc80aa7c8da7ad9e55091c4ffc9497b3054452d629bb85fa27c2a45"}, + {file = "protobuf-4.25.5-py3-none-any.whl", hash = "sha256:0aebecb809cae990f8129ada5ca273d9d670b76d9bfc9b1809f0a9c02b7dbf41"}, + {file = "protobuf-4.25.5.tar.gz", hash = "sha256:7f8249476b4a9473645db7f8ab42b02fe1488cbe5fb72fddd445e0665afd8584"}, +] + +[[package]] +name = "psutil" +version = "6.1.0" +description = "Cross-platform lib for process and system monitoring in Python." +optional = true +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:000d1d1ebd634b4efb383f4034437384e44a6d455260aaee2eca1e9c1b55f047"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:5cd2bcdc75b452ba2e10f0e8ecc0b57b827dd5d7aaffbc6821b2a9a242823a76"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:045f00a43c737f960d273a83973b2511430d61f283a44c96bf13a6e829ba8fdc"}, + {file = "psutil-6.1.0-cp27-none-win32.whl", hash = "sha256:9118f27452b70bb1d9ab3198c1f626c2499384935aaf55388211ad982611407e"}, + {file = "psutil-6.1.0-cp27-none-win_amd64.whl", hash = "sha256:a8506f6119cff7015678e2bce904a4da21025cc70ad283a53b099e7620061d85"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6e2dcd475ce8b80522e51d923d10c7871e45f20918e027ab682f94f1c6351688"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0895b8414afafc526712c498bd9de2b063deaac4021a3b3c34566283464aff8e"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dcbfce5d89f1d1f2546a2090f4fcf87c7f669d1d90aacb7d7582addece9fb38"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498c6979f9c6637ebc3a73b3f87f9eb1ec24e1ce53a7c5173b8508981614a90b"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d905186d647b16755a800e7263d43df08b790d709d575105d419f8b6ef65423a"}, + {file = "psutil-6.1.0-cp36-cp36m-win32.whl", hash = "sha256:6d3fbbc8d23fcdcb500d2c9f94e07b1342df8ed71b948a2649b5cb060a7c94ca"}, + {file = "psutil-6.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1209036fbd0421afde505a4879dee3b2fd7b1e14fee81c0069807adcbbcca747"}, + {file = "psutil-6.1.0-cp37-abi3-win32.whl", hash = "sha256:1ad45a1f5d0b608253b11508f80940985d1d0c8f6111b5cb637533a0e6ddc13e"}, + {file = "psutil-6.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be"}, + {file = "psutil-6.1.0.tar.gz", hash = "sha256:353815f59a7f64cdaca1c0307ee13558a0512f6db064e92fe833784f08539c7a"}, +] + +[package.extras] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["pytest", "pytest-xdist", "setuptools"] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +description = "Get CPU info with pure Python" +optional = false +python-versions = "*" +files = [ + {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"}, + {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, +] + +[[package]] +name = "pyarrow" +version = "17.0.0" +description = "Python library for Apache Arrow" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + +[[package]] +name = "pydantic" +version = "2.9.2" +description = "Data validation using Python type hints" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"}, + {file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"}, +] + +[package.dependencies] +annotated-types = ">=0.6.0" +pydantic-core = "2.23.4" +typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} + +[package.extras] +email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata"] + +[[package]] +name = "pydantic-core" +version = "2.23.4" +description = "Core functionality for Pydantic validation and serialization" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"}, + {file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"}, + {file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"}, + {file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"}, + {file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"}, + {file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"}, + {file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"}, + {file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"}, + {file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb"}, + {file = "pydantic_core-2.23.4-cp38-none-win32.whl", hash = "sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6"}, + {file = "pydantic_core-2.23.4-cp38-none-win_amd64.whl", hash = "sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605"}, + {file = "pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6"}, + {file = "pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e"}, + {file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pyreadline3" +version = "3.5.4" +description = "A python implementation of GNU readline." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6"}, + {file = "pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7"}, +] + +[package.extras] +dev = ["build", "flake8", "mypy", "pytest", "twine"] + +[[package]] +name = "pytest" +version = "7.4.4" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, + {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = true +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytz" +version = "2024.2" +description = "World timezone definitions, modern and historical" +optional = true +python-versions = "*" +files = [ + {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, + {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, +] + +[[package]] +name = "pyyaml" +version = "6.0.2" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, +] + +[[package]] +name = "referencing" +version = "0.35.1" +description = "JSON Referencing + Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + +[[package]] +name = "regex" +version = "2024.9.11" +description = "Alternative regular expression module, to replace re." +optional = false +python-versions = ">=3.8" +files = [ + {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1494fa8725c285a81d01dc8c06b55287a1ee5e0e382d8413adc0a9197aac6408"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0e12c481ad92d129c78f13a2a3662317e46ee7ef96c94fd332e1c29131875b7d"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16e13a7929791ac1216afde26f712802e3df7bf0360b32e4914dca3ab8baeea5"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46989629904bad940bbec2106528140a218b4a36bb3042d8406980be1941429c"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a906ed5e47a0ce5f04b2c981af1c9acf9e8696066900bf03b9d7879a6f679fc8"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a091b0550b3b0207784a7d6d0f1a00d1d1c8a11699c1a4d93db3fbefc3ad35"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ddcd9a179c0a6fa8add279a4444015acddcd7f232a49071ae57fa6e278f1f71"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6b41e1adc61fa347662b09398e31ad446afadff932a24807d3ceb955ed865cc8"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ced479f601cd2f8ca1fd7b23925a7e0ad512a56d6e9476f79b8f381d9d37090a"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:635a1d96665f84b292e401c3d62775851aedc31d4f8784117b3c68c4fcd4118d"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c0256beda696edcf7d97ef16b2a33a8e5a875affd6fa6567b54f7c577b30a137"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ce4f1185db3fbde8ed8aa223fc9620f276c58de8b0d4f8cc86fd1360829edb6"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:09d77559e80dcc9d24570da3745ab859a9cf91953062e4ab126ba9d5993688ca"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a22ccefd4db3f12b526eccb129390942fe874a3a9fdbdd24cf55773a1faab1a"}, + {file = "regex-2024.9.11-cp310-cp310-win32.whl", hash = "sha256:f745ec09bc1b0bd15cfc73df6fa4f726dcc26bb16c23a03f9e3367d357eeedd0"}, + {file = "regex-2024.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:01c2acb51f8a7d6494c8c5eafe3d8e06d76563d8a8a4643b37e9b2dd8a2ff623"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2cce2449e5927a0bf084d346da6cd5eb016b2beca10d0013ab50e3c226ffc0df"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b37fa423beefa44919e009745ccbf353d8c981516e807995b2bd11c2c77d268"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:64ce2799bd75039b480cc0360907c4fb2f50022f030bf9e7a8705b636e408fad"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4cc92bb6db56ab0c1cbd17294e14f5e9224f0cc6521167ef388332604e92679"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d05ac6fa06959c4172eccd99a222e1fbf17b5670c4d596cb1e5cde99600674c4"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:040562757795eeea356394a7fb13076ad4f99d3c62ab0f8bdfb21f99a1f85664"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6113c008a7780792efc80f9dfe10ba0cd043cbf8dc9a76ef757850f51b4edc50"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e5fb5f77c8745a60105403a774fe2c1759b71d3e7b4ca237a5e67ad066c7199"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:54d9ff35d4515debf14bc27f1e3b38bfc453eff3220f5bce159642fa762fe5d4"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:df5cbb1fbc74a8305b6065d4ade43b993be03dbe0f8b30032cced0d7740994bd"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7fb89ee5d106e4a7a51bce305ac4efb981536301895f7bdcf93ec92ae0d91c7f"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a738b937d512b30bf75995c0159c0ddf9eec0775c9d72ac0202076c72f24aa96"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e28f9faeb14b6f23ac55bfbbfd3643f5c7c18ede093977f1df249f73fd22c7b1"}, + {file = "regex-2024.9.11-cp311-cp311-win32.whl", hash = "sha256:18e707ce6c92d7282dfce370cd205098384b8ee21544e7cb29b8aab955b66fa9"}, + {file = "regex-2024.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:313ea15e5ff2a8cbbad96ccef6be638393041b0a7863183c2d31e0c6116688cf"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b0d0a6c64fcc4ef9c69bd5b3b3626cc3776520a1637d8abaa62b9edc147a58f7"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:49b0e06786ea663f933f3710a51e9385ce0cba0ea56b67107fd841a55d56a231"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5b513b6997a0b2f10e4fd3a1313568e373926e8c252bd76c960f96fd039cd28d"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee439691d8c23e76f9802c42a95cfeebf9d47cf4ffd06f18489122dbb0a7ad64"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a8f877c89719d759e52783f7fe6e1c67121076b87b40542966c02de5503ace42"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23b30c62d0f16827f2ae9f2bb87619bc4fba2044911e2e6c2eb1af0161cdb766"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85ab7824093d8f10d44330fe1e6493f756f252d145323dd17ab6b48733ff6c0a"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8dee5b4810a89447151999428fe096977346cf2f29f4d5e29609d2e19e0199c9"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98eeee2f2e63edae2181c886d7911ce502e1292794f4c5ee71e60e23e8d26b5d"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:57fdd2e0b2694ce6fc2e5ccf189789c3e2962916fb38779d3e3521ff8fe7a822"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d552c78411f60b1fdaafd117a1fca2f02e562e309223b9d44b7de8be451ec5e0"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a0b2b80321c2ed3fcf0385ec9e51a12253c50f146fddb2abbb10f033fe3d049a"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:18406efb2f5a0e57e3a5881cd9354c1512d3bb4f5c45d96d110a66114d84d23a"}, + {file = "regex-2024.9.11-cp312-cp312-win32.whl", hash = "sha256:e464b467f1588e2c42d26814231edecbcfe77f5ac414d92cbf4e7b55b2c2a776"}, + {file = "regex-2024.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:9e8719792ca63c6b8340380352c24dcb8cd7ec49dae36e963742a275dfae6009"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c157bb447303070f256e084668b702073db99bbb61d44f85d811025fcf38f784"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4db21ece84dfeefc5d8a3863f101995de646c6cb0536952c321a2650aa202c36"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:220e92a30b426daf23bb67a7962900ed4613589bab80382be09b48896d211e92"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb1ae19e64c14c7ec1995f40bd932448713d3c73509e82d8cd7744dc00e29e86"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f47cd43a5bfa48f86925fe26fbdd0a488ff15b62468abb5d2a1e092a4fb10e85"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d4a76b96f398697fe01117093613166e6aa8195d63f1b4ec3f21ab637632963"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ea51dcc0835eea2ea31d66456210a4e01a076d820e9039b04ae8d17ac11dee6"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7aaa315101c6567a9a45d2839322c51c8d6e81f67683d529512f5bcfb99c802"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c57d08ad67aba97af57a7263c2d9006d5c404d721c5f7542f077f109ec2a4a29"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8404bf61298bb6f8224bb9176c1424548ee1181130818fcd2cbffddc768bed8"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:dd4490a33eb909ef5078ab20f5f000087afa2a4daa27b4c072ccb3cb3050ad84"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:eee9130eaad130649fd73e5cd92f60e55708952260ede70da64de420cdcad554"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6a2644a93da36c784e546de579ec1806bfd2763ef47babc1b03d765fe560c9f8"}, + {file = "regex-2024.9.11-cp313-cp313-win32.whl", hash = "sha256:e997fd30430c57138adc06bba4c7c2968fb13d101e57dd5bb9355bf8ce3fa7e8"}, + {file = "regex-2024.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:042c55879cfeb21a8adacc84ea347721d3d83a159da6acdf1116859e2427c43f"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:35f4a6f96aa6cb3f2f7247027b07b15a374f0d5b912c0001418d1d55024d5cb4"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:55b96e7ce3a69a8449a66984c268062fbaa0d8ae437b285428e12797baefce7e"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb130fccd1a37ed894824b8c046321540263013da72745d755f2d35114b81a60"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:323c1f04be6b2968944d730e5c2091c8c89767903ecaa135203eec4565ed2b2b"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be1c8ed48c4c4065ecb19d882a0ce1afe0745dfad8ce48c49586b90a55f02366"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b5b029322e6e7b94fff16cd120ab35a253236a5f99a79fb04fda7ae71ca20ae8"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6fff13ef6b5f29221d6904aa816c34701462956aa72a77f1f151a8ec4f56aeb"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:587d4af3979376652010e400accc30404e6c16b7df574048ab1f581af82065e4"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:079400a8269544b955ffa9e31f186f01d96829110a3bf79dc338e9910f794fca"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f9268774428ec173654985ce55fc6caf4c6d11ade0f6f914d48ef4719eb05ebb"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:23f9985c8784e544d53fc2930fc1ac1a7319f5d5332d228437acc9f418f2f168"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2941333154baff9838e88aa71c1d84f4438189ecc6021a12c7573728b5838e"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:e93f1c331ca8e86fe877a48ad64e77882c0c4da0097f2212873a69bbfea95d0c"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:846bc79ee753acf93aef4184c040d709940c9d001029ceb7b7a52747b80ed2dd"}, + {file = "regex-2024.9.11-cp38-cp38-win32.whl", hash = "sha256:c94bb0a9f1db10a1d16c00880bdebd5f9faf267273b8f5bd1878126e0fbde771"}, + {file = "regex-2024.9.11-cp38-cp38-win_amd64.whl", hash = "sha256:2b08fce89fbd45664d3df6ad93e554b6c16933ffa9d55cb7e01182baaf971508"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:07f45f287469039ffc2c53caf6803cd506eb5f5f637f1d4acb37a738f71dd066"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4838e24ee015101d9f901988001038f7f0d90dc0c3b115541a1365fb439add62"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6edd623bae6a737f10ce853ea076f56f507fd7726bee96a41ee3d68d347e4d16"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c69ada171c2d0e97a4b5aa78fbb835e0ffbb6b13fc5da968c09811346564f0d3"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02087ea0a03b4af1ed6ebab2c54d7118127fee8d71b26398e8e4b05b78963199"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69dee6a020693d12a3cf892aba4808fe168d2a4cef368eb9bf74f5398bfd4ee8"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297f54910247508e6e5cae669f2bc308985c60540a4edd1c77203ef19bfa63ca"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ecea58b43a67b1b79805f1a0255730edaf5191ecef84dbc4cc85eb30bc8b63b9"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eab4bb380f15e189d1313195b062a6aa908f5bd687a0ceccd47c8211e9cf0d4a"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0cbff728659ce4bbf4c30b2a1be040faafaa9eca6ecde40aaff86f7889f4ab39"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:54c4a097b8bc5bb0dfc83ae498061d53ad7b5762e00f4adaa23bee22b012e6ba"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:73d6d2f64f4d894c96626a75578b0bf7d9e56dcda8c3d037a2118fdfe9b1c664"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:e53b5fbab5d675aec9f0c501274c467c0f9a5d23696cfc94247e1fb56501ed89"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ffbcf9221e04502fc35e54d1ce9567541979c3fdfb93d2c554f0ca583a19b35"}, + {file = "regex-2024.9.11-cp39-cp39-win32.whl", hash = "sha256:e4c22e1ac1f1ec1e09f72e6c44d8f2244173db7eb9629cc3a346a8d7ccc31142"}, + {file = "regex-2024.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:faa3c142464efec496967359ca99696c896c591c56c53506bac1ad465f66e919"}, + {file = "regex-2024.9.11.tar.gz", hash = "sha256:6c188c307e8433bcb63dc1915022deb553b4203a70722fc542c363bf120a01fd"}, +] + +[[package]] +name = "requests" +version = "2.32.3" +description = "Python HTTP for Humans." +optional = false +python-versions = ">=3.8" +files = [ + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "rich" +version = "13.8.1" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"}, + {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + +[[package]] +name = "rpds-py" +version = "0.20.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = true +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, +] + +[[package]] +name = "safetensors" +version = "0.4.5" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "safetensors-0.4.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a63eaccd22243c67e4f2b1c3e258b257effc4acd78f3b9d397edc8cf8f1298a7"}, + {file = "safetensors-0.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:23fc9b4ec7b602915cbb4ec1a7c1ad96d2743c322f20ab709e2c35d1b66dad27"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6885016f34bef80ea1085b7e99b3c1f92cb1be78a49839203060f67b40aee761"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:133620f443450429322f238fda74d512c4008621227fccf2f8cf4a76206fea7c"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4fb3e0609ec12d2a77e882f07cced530b8262027f64b75d399f1504ffec0ba56"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0f1dd769f064adc33831f5e97ad07babbd728427f98e3e1db6902e369122737"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6d156bdb26732feada84f9388a9f135528c1ef5b05fae153da365ad4319c4c5"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e347d77e2c77eb7624400ccd09bed69d35c0332f417ce8c048d404a096c593b"}, + {file = "safetensors-0.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9f556eea3aec1d3d955403159fe2123ddd68e880f83954ee9b4a3f2e15e716b6"}, + {file = "safetensors-0.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9483f42be3b6bc8ff77dd67302de8ae411c4db39f7224dec66b0eb95822e4163"}, + {file = "safetensors-0.4.5-cp310-none-win32.whl", hash = "sha256:7389129c03fadd1ccc37fd1ebbc773f2b031483b04700923c3511d2a939252cc"}, + {file = "safetensors-0.4.5-cp310-none-win_amd64.whl", hash = "sha256:e98ef5524f8b6620c8cdef97220c0b6a5c1cef69852fcd2f174bb96c2bb316b1"}, + {file = "safetensors-0.4.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:21f848d7aebd5954f92538552d6d75f7c1b4500f51664078b5b49720d180e47c"}, + {file = "safetensors-0.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb07000b19d41e35eecef9a454f31a8b4718a185293f0d0b1c4b61d6e4487971"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09dedf7c2fda934ee68143202acff6e9e8eb0ddeeb4cfc24182bef999efa9f42"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59b77e4b7a708988d84f26de3ebead61ef1659c73dcbc9946c18f3b1786d2688"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d3bc83e14d67adc2e9387e511097f254bd1b43c3020440e708858c684cbac68"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39371fc551c1072976073ab258c3119395294cf49cdc1f8476794627de3130df"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6c19feda32b931cae0acd42748a670bdf56bee6476a046af20181ad3fee4090"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a659467495de201e2f282063808a41170448c78bada1e62707b07a27b05e6943"}, + {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bad5e4b2476949bcd638a89f71b6916fa9a5cae5c1ae7eede337aca2100435c0"}, + {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a3a315a6d0054bc6889a17f5668a73f94f7fe55121ff59e0a199e3519c08565f"}, + {file = "safetensors-0.4.5-cp311-none-win32.whl", hash = "sha256:a01e232e6d3d5cf8b1667bc3b657a77bdab73f0743c26c1d3c5dd7ce86bd3a92"}, + {file = "safetensors-0.4.5-cp311-none-win_amd64.whl", hash = "sha256:cbd39cae1ad3e3ef6f63a6f07296b080c951f24cec60188378e43d3713000c04"}, + {file = "safetensors-0.4.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:473300314e026bd1043cef391bb16a8689453363381561b8a3e443870937cc1e"}, + {file = "safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:801183a0f76dc647f51a2d9141ad341f9665602a7899a693207a82fb102cc53e"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1524b54246e422ad6fb6aea1ac71edeeb77666efa67230e1faf6999df9b2e27f"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3139098e3e8b2ad7afbca96d30ad29157b50c90861084e69fcb80dec7430461"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65573dc35be9059770808e276b017256fa30058802c29e1038eb1c00028502ea"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd33da8e9407559f8779c82a0448e2133737f922d71f884da27184549416bfed"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3685ce7ed036f916316b567152482b7e959dc754fcc4a8342333d222e05f407c"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dde2bf390d25f67908278d6f5d59e46211ef98e44108727084d4637ee70ab4f1"}, + {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7469d70d3de970b1698d47c11ebbf296a308702cbaae7fcb993944751cf985f4"}, + {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a6ba28118636a130ccbb968bc33d4684c48678695dba2590169d5ab03a45646"}, + {file = "safetensors-0.4.5-cp312-none-win32.whl", hash = "sha256:c859c7ed90b0047f58ee27751c8e56951452ed36a67afee1b0a87847d065eec6"}, + {file = "safetensors-0.4.5-cp312-none-win_amd64.whl", hash = "sha256:b5a8810ad6a6f933fff6c276eae92c1da217b39b4d8b1bc1c0b8af2d270dc532"}, + {file = "safetensors-0.4.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:25e5f8e2e92a74f05b4ca55686234c32aac19927903792b30ee6d7bd5653d54e"}, + {file = "safetensors-0.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:81efb124b58af39fcd684254c645e35692fea81c51627259cdf6d67ff4458916"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:585f1703a518b437f5103aa9cf70e9bd437cb78eea9c51024329e4fb8a3e3679"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b99fbf72e3faf0b2f5f16e5e3458b93b7d0a83984fe8d5364c60aa169f2da89"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b17b299ca9966ca983ecda1c0791a3f07f9ca6ab5ded8ef3d283fff45f6bcd5f"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76ded72f69209c9780fdb23ea89e56d35c54ae6abcdec67ccb22af8e696e449a"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2783956926303dcfeb1de91a4d1204cd4089ab441e622e7caee0642281109db3"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d94581aab8c6b204def4d7320f07534d6ee34cd4855688004a4354e63b639a35"}, + {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:67e1e7cb8678bb1b37ac48ec0df04faf689e2f4e9e81e566b5c63d9f23748523"}, + {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142"}, + {file = "safetensors-0.4.5-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:77d9b228da8374c7262046a36c1f656ba32a93df6cc51cd4453af932011e77f1"}, + {file = "safetensors-0.4.5-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:500cac01d50b301ab7bb192353317035011c5ceeef0fca652f9f43c000bb7f8d"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75331c0c746f03158ded32465b7d0b0e24c5a22121743662a2393439c43a45cf"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670e95fe34e0d591d0529e5e59fd9d3d72bc77b1444fcaa14dccda4f36b5a38b"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:098923e2574ff237c517d6e840acada8e5b311cb1fa226019105ed82e9c3b62f"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13ca0902d2648775089fa6a0c8fc9e6390c5f8ee576517d33f9261656f851e3f"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f0032bedc869c56f8d26259fe39cd21c5199cd57f2228d817a0e23e8370af25"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4b15f51b4f8f2a512341d9ce3475cacc19c5fdfc5db1f0e19449e75f95c7dc8"}, + {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f6594d130d0ad933d885c6a7b75c5183cb0e8450f799b80a39eae2b8508955eb"}, + {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:60c828a27e852ded2c85fc0f87bf1ec20e464c5cd4d56ff0e0711855cc2e17f8"}, + {file = "safetensors-0.4.5-cp37-none-win32.whl", hash = "sha256:6d3de65718b86c3eeaa8b73a9c3d123f9307a96bbd7be9698e21e76a56443af5"}, + {file = "safetensors-0.4.5-cp37-none-win_amd64.whl", hash = "sha256:5a2d68a523a4cefd791156a4174189a4114cf0bf9c50ceb89f261600f3b2b81a"}, + {file = "safetensors-0.4.5-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:e7a97058f96340850da0601a3309f3d29d6191b0702b2da201e54c6e3e44ccf0"}, + {file = "safetensors-0.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:63bfd425e25f5c733f572e2246e08a1c38bd6f2e027d3f7c87e2e43f228d1345"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3664ac565d0e809b0b929dae7ccd74e4d3273cd0c6d1220c6430035befb678e"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:313514b0b9b73ff4ddfb4edd71860696dbe3c1c9dc4d5cc13dbd74da283d2cbf"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31fa33ee326f750a2f2134a6174773c281d9a266ccd000bd4686d8021f1f3dac"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:09566792588d77b68abe53754c9f1308fadd35c9f87be939e22c623eaacbed6b"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309aaec9b66cbf07ad3a2e5cb8a03205663324fea024ba391594423d0f00d9fe"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:53946c5813b8f9e26103c5efff4a931cc45d874f45229edd68557ffb35ffb9f8"}, + {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:868f9df9e99ad1e7f38c52194063a982bc88fedc7d05096f4f8160403aaf4bd6"}, + {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9cc9449bd0b0bc538bd5e268221f0c5590bc5c14c1934a6ae359d44410dc68c4"}, + {file = "safetensors-0.4.5-cp38-none-win32.whl", hash = "sha256:83c4f13a9e687335c3928f615cd63a37e3f8ef072a3f2a0599fa09f863fb06a2"}, + {file = "safetensors-0.4.5-cp38-none-win_amd64.whl", hash = "sha256:b98d40a2ffa560653f6274e15b27b3544e8e3713a44627ce268f419f35c49478"}, + {file = "safetensors-0.4.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cf727bb1281d66699bef5683b04d98c894a2803442c490a8d45cd365abfbdeb2"}, + {file = "safetensors-0.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:96f1d038c827cdc552d97e71f522e1049fef0542be575421f7684756a748e457"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:139fbee92570ecea774e6344fee908907db79646d00b12c535f66bc78bd5ea2c"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c36302c1c69eebb383775a89645a32b9d266878fab619819ce660309d6176c9b"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d641f5b8149ea98deb5ffcf604d764aad1de38a8285f86771ce1abf8e74c4891"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b4db6a61d968de73722b858038c616a1bebd4a86abe2688e46ca0cc2d17558f2"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b75a616e02f21b6f1d5785b20cecbab5e2bd3f6358a90e8925b813d557666ec1"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:788ee7d04cc0e0e7f944c52ff05f52a4415b312f5efd2ee66389fb7685ee030c"}, + {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87bc42bd04fd9ca31396d3ca0433db0be1411b6b53ac5a32b7845a85d01ffc2e"}, + {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4037676c86365a721a8c9510323a51861d703b399b78a6b4486a54a65a975fca"}, + {file = "safetensors-0.4.5-cp39-none-win32.whl", hash = "sha256:1500418454529d0ed5c1564bda376c4ddff43f30fce9517d9bee7bcce5a8ef50"}, + {file = "safetensors-0.4.5-cp39-none-win_amd64.whl", hash = "sha256:9d1a94b9d793ed8fe35ab6d5cea28d540a46559bafc6aae98f30ee0867000cab"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fdadf66b5a22ceb645d5435a0be7a0292ce59648ca1d46b352f13cff3ea80410"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d42ffd4c2259f31832cb17ff866c111684c87bd930892a1ba53fed28370c918c"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd8a1f6d2063a92cd04145c7fd9e31a1c7d85fbec20113a14b487563fdbc0597"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951d2fcf1817f4fb0ef0b48f6696688a4e852a95922a042b3f96aaa67eedc920"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ac85d9a8c1af0e3132371d9f2d134695a06a96993c2e2f0bbe25debb9e3f67a"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e3cec4a29eb7fe8da0b1c7988bc3828183080439dd559f720414450de076fcab"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c7db3006a4915151ce1913652e907cdede299b974641a83fbc092102ac41b644"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f68bf99ea970960a237f416ea394e266e0361895753df06e3e06e6ea7907d98b"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8158938cf3324172df024da511839d373c40fbfaa83e9abf467174b2910d7b4c"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:540ce6c4bf6b58cb0fd93fa5f143bc0ee341c93bb4f9287ccd92cf898cc1b0dd"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bfeaa1a699c6b9ed514bd15e6a91e74738b71125a9292159e3d6b7f0a53d2cde"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:01c8f00da537af711979e1b42a69a8ec9e1d7112f208e0e9b8a35d2c381085ef"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a0dd565f83b30f2ca79b5d35748d0d99dd4b3454f80e03dfb41f0038e3bdf180"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:023b6e5facda76989f4cba95a861b7e656b87e225f61811065d5c501f78cdb3f"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9633b663393d5796f0b60249549371e392b75a0b955c07e9c6f8708a87fc841f"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78dd8adfb48716233c45f676d6e48534d34b4bceb50162c13d1f0bdf6f78590a"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8e8deb16c4321d61ae72533b8451ec4a9af8656d1c61ff81aa49f966406e4b68"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:52452fa5999dc50c4decaf0c53aa28371f7f1e0fe5c2dd9129059fbe1e1599c7"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d5f23198821e227cfc52d50fa989813513db381255c6d100927b012f0cfec63d"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f4beb84b6073b1247a773141a6331117e35d07134b3bb0383003f39971d414bb"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:68814d599d25ed2fdd045ed54d370d1d03cf35e02dce56de44c651f828fb9b7b"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0b6453c54c57c1781292c46593f8a37254b8b99004c68d6c3ce229688931a22"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:adaa9c6dead67e2dd90d634f89131e43162012479d86e25618e821a03d1eb1dc"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73e7d408e9012cd17511b382b43547850969c7979efc2bc353f317abaf23c84c"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:775409ce0fcc58b10773fdb4221ed1eb007de10fe7adbdf8f5e8a56096b6f0bc"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:834001bed193e4440c4a3950a31059523ee5090605c907c66808664c932b549c"}, + {file = "safetensors-0.4.5.tar.gz", hash = "sha256:d73de19682deabb02524b3d5d1f8b3aaba94c72f1bbfc7911b9b9d5d391c0310"}, +] + +[package.extras] +all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"] +dev = ["safetensors[all]"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"] +mlx = ["mlx (>=0.0.9)"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"] +pinned-tf = ["safetensors[numpy]", "tensorflow (==2.11.0)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] +torch = ["safetensors[numpy]", "torch (>=1.10)"] + +[[package]] +name = "scikit-learn" +version = "1.5.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + +[[package]] +name = "scipy" +version = "1.13.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, + {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, + {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, + {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, + {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, + {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, + {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, + {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, + {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, + {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, + {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, +] + +[package.dependencies] +numpy = ">=1.22.4,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + +[[package]] +name = "sentencepiece" +version = "0.2.0" +description = "SentencePiece python wrapper" +optional = false +python-versions = "*" +files = [ + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"}, + {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"}, +] + +[[package]] +name = "setuptools" +version = "75.2.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, + {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "sympy" +version = "1.13.3" +description = "Computer algebra system (CAS) in Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73"}, + {file = "sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9"}, +] + +[package.dependencies] +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] + +[[package]] +name = "texttable" +version = "1.7.0" +description = "module to create simple ASCII tables" +optional = true +python-versions = "*" +files = [ + {file = "texttable-1.7.0-py2.py3-none-any.whl", hash = "sha256:72227d592c82b3d7f672731ae73e4d1f88cd8e2ef5b075a7a7f01a23a3743917"}, + {file = "texttable-1.7.0.tar.gz", hash = "sha256:2d2068fb55115807d3ac77a4ca68fa48803e84ebb0ee2340f858107a36522638"}, +] + +[[package]] +name = "tokenizers" +version = "0.20.1" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tokenizers-0.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:439261da7c0a5c88bda97acb284d49fbdaf67e9d3b623c0bfd107512d22787a9"}, + {file = "tokenizers-0.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:03dae629d99068b1ea5416d50de0fea13008f04129cc79af77a2a6392792d93c"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b61f561f329ffe4b28367798b89d60c4abf3f815d37413b6352bc6412a359867"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec870fce1ee5248a10be69f7a8408a234d6f2109f8ea827b4f7ecdbf08c9fd15"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d388d1ea8b7447da784e32e3b86a75cce55887e3b22b31c19d0b186b1c677800"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:299c85c1d21135bc01542237979bf25c32efa0d66595dd0069ae259b97fb2dbe"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e96f6c14c9752bb82145636b614d5a78e9cde95edfbe0a85dad0dd5ddd6ec95c"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9e95ad49c932b80abfbfeaf63b155761e695ad9f8a58c52a47d962d76e310f"}, + {file = "tokenizers-0.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f22dee205329a636148c325921c73cf3e412e87d31f4d9c3153b302a0200057b"}, + {file = "tokenizers-0.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2ffd9a8895575ac636d44500c66dffaef133823b6b25067604fa73bbc5ec09d"}, + {file = "tokenizers-0.20.1-cp310-none-win32.whl", hash = "sha256:2847843c53f445e0f19ea842a4e48b89dd0db4e62ba6e1e47a2749d6ec11f50d"}, + {file = "tokenizers-0.20.1-cp310-none-win_amd64.whl", hash = "sha256:f9aa93eacd865f2798b9e62f7ce4533cfff4f5fbd50c02926a78e81c74e432cd"}, + {file = "tokenizers-0.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4a717dcb08f2dabbf27ae4b6b20cbbb2ad7ed78ce05a829fae100ff4b3c7ff15"}, + {file = "tokenizers-0.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f84dad1ff1863c648d80628b1b55353d16303431283e4efbb6ab1af56a75832"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:929c8f3afa16a5130a81ab5079c589226273ec618949cce79b46d96e59a84f61"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d10766473954397e2d370f215ebed1cc46dcf6fd3906a2a116aa1d6219bfedc3"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9300fac73ddc7e4b0330acbdda4efaabf74929a4a61e119a32a181f534a11b47"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ecaf7b0e39caeb1aa6dd6e0975c405716c82c1312b55ac4f716ef563a906969"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5170be9ec942f3d1d317817ced8d749b3e1202670865e4fd465e35d8c259de83"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f1ae08fa9aea5891cbd69df29913e11d3841798e0bfb1ff78b78e4e7ea0a4"}, + {file = "tokenizers-0.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ee86d4095d3542d73579e953c2e5e07d9321af2ffea6ecc097d16d538a2dea16"}, + {file = "tokenizers-0.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:86dcd08da163912e17b27bbaba5efdc71b4fbffb841530fdb74c5707f3c49216"}, + {file = "tokenizers-0.20.1-cp311-none-win32.whl", hash = "sha256:9af2dc4ee97d037bc6b05fa4429ddc87532c706316c5e11ce2f0596dfcfa77af"}, + {file = "tokenizers-0.20.1-cp311-none-win_amd64.whl", hash = "sha256:899152a78b095559c287b4c6d0099469573bb2055347bb8154db106651296f39"}, + {file = "tokenizers-0.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:407ab666b38e02228fa785e81f7cf79ef929f104bcccf68a64525a54a93ceac9"}, + {file = "tokenizers-0.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f13a2d16032ebc8bd812eb8099b035ac65887d8f0c207261472803b9633cf3e"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e98eee4dca22849fbb56a80acaa899eec5b72055d79637dd6aa15d5e4b8628c9"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47c1bcdd61e61136087459cb9e0b069ff23b5568b008265e5cbc927eae3387ce"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:128c1110e950534426e2274837fc06b118ab5f2fa61c3436e60e0aada0ccfd67"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2e2d47a819d2954f2c1cd0ad51bb58ffac6f53a872d5d82d65d79bf76b9896d"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bdd67a0e3503a9a7cf8bc5a4a49cdde5fa5bada09a51e4c7e1c73900297539bd"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b93d2e26d04da337ac407acec8b5d081d8d135e3e5066a88edd5bdb5aff89"}, + {file = "tokenizers-0.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0c6a796ddcd9a19ad13cf146997cd5895a421fe6aec8fd970d69f9117bddb45c"}, + {file = "tokenizers-0.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3ea919687aa7001a8ff1ba36ac64f165c4e89035f57998fa6cedcfd877be619d"}, + {file = "tokenizers-0.20.1-cp312-none-win32.whl", hash = "sha256:6d3ac5c1f48358ffe20086bf065e843c0d0a9fce0d7f0f45d5f2f9fba3609ca5"}, + {file = "tokenizers-0.20.1-cp312-none-win_amd64.whl", hash = "sha256:b0874481aea54a178f2bccc45aa2d0c99cd3f79143a0948af6a9a21dcc49173b"}, + {file = "tokenizers-0.20.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:96af92e833bd44760fb17f23f402e07a66339c1dcbe17d79a9b55bb0cc4f038e"}, + {file = "tokenizers-0.20.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:65f34e5b731a262dfa562820818533c38ce32a45864437f3d9c82f26c139ca7f"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17f98fccb5c12ab1ce1f471731a9cd86df5d4bd2cf2880c5a66b229802d96145"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8c0fc3542cf9370bf92c932eb71bdeb33d2d4aeeb4126d9fd567b60bd04cb30"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b39356df4575d37f9b187bb623aab5abb7b62c8cb702867a1768002f814800c"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfdad27b0e50544f6b838895a373db6114b85112ba5c0cefadffa78d6daae563"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:094663dd0e85ee2e573126918747bdb40044a848fde388efb5b09d57bc74c680"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14e4cf033a2aa207d7ac790e91adca598b679999710a632c4a494aab0fc3a1b2"}, + {file = "tokenizers-0.20.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9310951c92c9fb91660de0c19a923c432f110dbfad1a2d429fbc44fa956bf64f"}, + {file = "tokenizers-0.20.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05e41e302c315bd2ed86c02e917bf03a6cf7d2f652c9cee1a0eb0d0f1ca0d32c"}, + {file = "tokenizers-0.20.1-cp37-none-win32.whl", hash = "sha256:212231ab7dfcdc879baf4892ca87c726259fa7c887e1688e3f3cead384d8c305"}, + {file = "tokenizers-0.20.1-cp37-none-win_amd64.whl", hash = "sha256:896195eb9dfdc85c8c052e29947169c1fcbe75a254c4b5792cdbd451587bce85"}, + {file = "tokenizers-0.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:741fb22788482d09d68e73ece1495cfc6d9b29a06c37b3df90564a9cfa688e6d"}, + {file = "tokenizers-0.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10be14ebd8082086a342d969e17fc2d6edc856c59dbdbddd25f158fa40eaf043"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:514cf279b22fa1ae0bc08e143458c74ad3b56cd078b319464959685a35c53d5e"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a647c5b7cb896d6430cf3e01b4e9a2d77f719c84cefcef825d404830c2071da2"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cdf379219e1e1dd432091058dab325a2e6235ebb23e0aec8d0508567c90cd01"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ba72260449e16c4c2f6f3252823b059fbf2d31b32617e582003f2b18b415c39"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:910b96ed87316e4277b23c7bcaf667ce849c7cc379a453fa179e7e09290eeb25"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e53975a6694428a0586534cc1354b2408d4e010a3103117f617cbb550299797c"}, + {file = "tokenizers-0.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:07c4b7be58da142b0730cc4e5fd66bb7bf6f57f4986ddda73833cd39efef8a01"}, + {file = "tokenizers-0.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b605c540753e62199bf15cf69c333e934077ef2350262af2ccada46026f83d1c"}, + {file = "tokenizers-0.20.1-cp38-none-win32.whl", hash = "sha256:88b3bc76ab4db1ab95ead623d49c95205411e26302cf9f74203e762ac7e85685"}, + {file = "tokenizers-0.20.1-cp38-none-win_amd64.whl", hash = "sha256:d412a74cf5b3f68a90c615611a5aa4478bb303d1c65961d22db45001df68afcb"}, + {file = "tokenizers-0.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a25dcb2f41a0a6aac31999e6c96a75e9152fa0127af8ece46c2f784f23b8197a"}, + {file = "tokenizers-0.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a12c3cebb8c92e9c35a23ab10d3852aee522f385c28d0b4fe48c0b7527d59762"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02e18da58cf115b7c40de973609c35bde95856012ba42a41ee919c77935af251"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f326a1ac51ae909b9760e34671c26cd0dfe15662f447302a9d5bb2d872bab8ab"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b4872647ea6f25224e2833b044b0b19084e39400e8ead3cfe751238b0802140"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce6238a3311bb8e4c15b12600927d35c267b92a52c881ef5717a900ca14793f7"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57b7a8880b208866508b06ce365dc631e7a2472a3faa24daa430d046fb56c885"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a908c69c2897a68f412aa05ba38bfa87a02980df70f5a72fa8490479308b1f2d"}, + {file = "tokenizers-0.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:da1001aa46f4490099c82e2facc4fbc06a6a32bf7de3918ba798010954b775e0"}, + {file = "tokenizers-0.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:42c097390e2f0ed0a5c5d569e6669dd4e9fff7b31c6a5ce6e9c66a61687197de"}, + {file = "tokenizers-0.20.1-cp39-none-win32.whl", hash = "sha256:3d4d218573a3d8b121a1f8c801029d70444ffb6d8f129d4cca1c7b672ee4a24c"}, + {file = "tokenizers-0.20.1-cp39-none-win_amd64.whl", hash = "sha256:37d1e6f616c84fceefa7c6484a01df05caf1e207669121c66213cb5b2911d653"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48689da7a395df41114f516208d6550e3e905e1239cc5ad386686d9358e9cef0"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:712f90ea33f9bd2586b4a90d697c26d56d0a22fd3c91104c5858c4b5b6489a79"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:359eceb6a620c965988fc559cebc0a98db26713758ec4df43fb76d41486a8ed5"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d3caf244ce89d24c87545aafc3448be15870096e796c703a0d68547187192e1"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03b03cf8b9a32254b1bf8a305fb95c6daf1baae0c1f93b27f2b08c9759f41dee"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:218e5a3561561ea0f0ef1559c6d95b825308dbec23fb55b70b92589e7ff2e1e8"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f40df5e0294a95131cc5f0e0eb91fe86d88837abfbee46b9b3610b09860195a7"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:08aaa0d72bb65058e8c4b0455f61b840b156c557e2aca57627056624c3a93976"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:998700177b45f70afeb206ad22c08d9e5f3a80639dae1032bf41e8cbc4dada4b"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62f7fbd3c2c38b179556d879edae442b45f68312019c3a6013e56c3947a4e648"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31e87fca4f6bbf5cc67481b562147fe932f73d5602734de7dd18a8f2eee9c6dd"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:956f21d359ae29dd51ca5726d2c9a44ffafa041c623f5aa33749da87cfa809b9"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1fbbaf17a393c78d8aedb6a334097c91cb4119a9ced4764ab8cfdc8d254dc9f9"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ebe63e31f9c1a970c53866d814e35ec2ec26fda03097c486f82f3891cee60830"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:81970b80b8ac126910295f8aab2d7ef962009ea39e0d86d304769493f69aaa1e"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130e35e76f9337ed6c31be386e75d4925ea807055acf18ca1a9b0eec03d8fe23"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd28a8614f5c82a54ab2463554e84ad79526c5184cf4573bbac2efbbbcead457"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9041ee665d0fa7f5c4ccf0f81f5e6b7087f797f85b143c094126fc2611fec9d0"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:62eb9daea2a2c06bcd8113a5824af8ef8ee7405d3a71123ba4d52c79bb3d9f1a"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f861889707b54a9ab1204030b65fd6c22bdd4a95205deec7994dc22a8baa2ea4"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:89d5c337d74ea6e5e7dc8af124cf177be843bbb9ca6e58c01f75ea103c12c8a9"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:0b7f515c83397e73292accdbbbedc62264e070bae9682f06061e2ddce67cacaf"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e0305fc1ec6b1e5052d30d9c1d5c807081a7bd0cae46a33d03117082e91908c"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dc611e6ac0fa00a41de19c3bf6391a05ea201d2d22b757d63f5491ec0e67faa"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5ffe0d7f7bfcfa3b2585776ecf11da2e01c317027c8573c78ebcb8985279e23"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e7edb8ec12c100d5458d15b1e47c0eb30ad606a05641f19af7563bc3d1608c14"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:de291633fb9303555793cc544d4a86e858da529b7d0b752bcaf721ae1d74b2c9"}, + {file = "tokenizers-0.20.1.tar.gz", hash = "sha256:84edcc7cdeeee45ceedb65d518fffb77aec69311c9c8e30f77ad84da3025f002"}, +] + +[package.dependencies] +huggingface-hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] + +[[package]] +name = "tomli" +version = "2.0.2" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, + {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, +] + +[[package]] +name = "tqdm" +version = "4.66.5" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, + {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "transformers" +version = "4.45.2" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "transformers-4.45.2-py3-none-any.whl", hash = "sha256:c551b33660cfc815bae1f9f097ecfd1e65be623f13c6ee0dda372bd881460210"}, + {file = "transformers-4.45.2.tar.gz", hash = "sha256:72bc390f6b203892561f05f86bbfaa0e234aab8e927a83e62b9d92ea7e3ae101"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.23.2,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.4.1" +sentencepiece = {version = ">=0.1.91,<0.1.92 || >0.1.92", optional = true, markers = "extra == \"sentencepiece\""} +tokenizers = ">=0.20,<0.21" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.26.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +benchmark = ["optimum-benchmark (>=0.3.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.20,<0.21)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6,<0.15.0)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +ruff = ["ruff (==0.5.1)"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +tiktoken = ["blobfile", "tiktoken"] +timm = ["timm (<=0.9.16)"] +tokenizers = ["tokenizers (>=0.20,<0.21)"] +torch = ["accelerate (>=0.26.0)", "torch"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.20,<0.21)", "torch", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + +[[package]] +name = "triton" +version = "3.0.0" +description = "A language and compiler for custom Deep Learning operations" +optional = true +python-versions = "*" +files = [ + {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, + {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, + {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, + {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, + {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + +[[package]] +name = "typer" +version = "0.6.1" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.6.1-py3-none-any.whl", hash = "sha256:54b19e5df18654070a82f8c2aa1da456a4ac16a2a83e6dcd9f170e291c56338e"}, + {file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + +[[package]] +name = "tzdata" +version = "2024.2" +description = "Provider of IANA time zone data" +optional = true +python-versions = ">=2" +files = [ + {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"}, + {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, +] + +[[package]] +name = "urllib3" +version = "2.2.3" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.8" +files = [ + {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, + {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +h2 = ["h2 (>=4,<5)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + +[[package]] +name = "wrapt" +version = "1.16.0" +description = "Module for decorators, wrappers and monkey patching." +optional = false +python-versions = ">=3.6" +files = [ + {file = "wrapt-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4"}, + {file = "wrapt-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136"}, + {file = "wrapt-1.16.0-cp310-cp310-win32.whl", hash = "sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d"}, + {file = "wrapt-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2"}, + {file = "wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09"}, + {file = "wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d"}, + {file = "wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362"}, + {file = "wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89"}, + {file = "wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b"}, + {file = "wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c"}, + {file = "wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc"}, + {file = "wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8"}, + {file = "wrapt-1.16.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465"}, + {file = "wrapt-1.16.0-cp36-cp36m-win32.whl", hash = "sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e"}, + {file = "wrapt-1.16.0-cp36-cp36m-win_amd64.whl", hash = "sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966"}, + {file = "wrapt-1.16.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c"}, + {file = "wrapt-1.16.0-cp37-cp37m-win32.whl", hash = "sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c"}, + {file = "wrapt-1.16.0-cp37-cp37m-win_amd64.whl", hash = "sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00"}, + {file = "wrapt-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0"}, + {file = "wrapt-1.16.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6"}, + {file = "wrapt-1.16.0-cp38-cp38-win32.whl", hash = "sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b"}, + {file = "wrapt-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41"}, + {file = "wrapt-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2"}, + {file = "wrapt-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537"}, + {file = "wrapt-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3"}, + {file = "wrapt-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35"}, + {file = "wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1"}, + {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, +] + +[[package]] +name = "xxhash" +version = "3.5.0" +description = "Python binding for xxHash" +optional = true +python-versions = ">=3.7" +files = [ + {file = "xxhash-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ece616532c499ee9afbb83078b1b952beffef121d989841f7f4b3dc5ac0fd212"}, + {file = "xxhash-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3171f693dbc2cef6477054a665dc255d996646b4023fe56cb4db80e26f4cc520"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c5d3e570ef46adaf93fc81b44aca6002b5a4d8ca11bd0580c07eac537f36680"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7cb29a034301e2982df8b1fe6328a84f4b676106a13e9135a0d7e0c3e9f806da"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0d307d27099bb0cbeea7260eb39ed4fdb99c5542e21e94bb6fd29e49c57a23"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0342aafd421795d740e514bc9858ebddfc705a75a8c5046ac56d85fe97bf196"}, + {file = "xxhash-3.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dbbd9892c5ebffeca1ed620cf0ade13eb55a0d8c84e0751a6653adc6ac40d0c"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4cc2d67fdb4d057730c75a64c5923abfa17775ae234a71b0200346bfb0a7f482"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ec28adb204b759306a3d64358a5e5c07d7b1dd0ccbce04aa76cb9377b7b70296"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1328f6d8cca2b86acb14104e381225a3d7b42c92c4b86ceae814e5c400dbb415"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8d47ebd9f5d9607fd039c1fbf4994e3b071ea23eff42f4ecef246ab2b7334198"}, + {file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b96d559e0fcddd3343c510a0fe2b127fbff16bf346dd76280b82292567523442"}, + {file = "xxhash-3.5.0-cp310-cp310-win32.whl", hash = "sha256:61c722ed8d49ac9bc26c7071eeaa1f6ff24053d553146d5df031802deffd03da"}, + {file = "xxhash-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:9bed5144c6923cc902cd14bb8963f2d5e034def4486ab0bbe1f58f03f042f9a9"}, + {file = "xxhash-3.5.0-cp310-cp310-win_arm64.whl", hash = "sha256:893074d651cf25c1cc14e3bea4fceefd67f2921b1bb8e40fcfeba56820de80c6"}, + {file = "xxhash-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02c2e816896dc6f85922ced60097bcf6f008dedfc5073dcba32f9c8dd786f3c1"}, + {file = "xxhash-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6027dcd885e21581e46d3c7f682cfb2b870942feeed58a21c29583512c3f09f8"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1308fa542bbdbf2fa85e9e66b1077eea3a88bef38ee8a06270b4298a7a62a166"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c28b2fdcee797e1c1961cd3bcd3d545cab22ad202c846235197935e1df2f8ef7"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:924361811732ddad75ff23e90efd9ccfda4f664132feecb90895bade6a1b4623"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89997aa1c4b6a5b1e5b588979d1da048a3c6f15e55c11d117a56b75c84531f5a"}, + {file = "xxhash-3.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:685c4f4e8c59837de103344eb1c8a3851f670309eb5c361f746805c5471b8c88"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbd2ecfbfee70bc1a4acb7461fa6af7748ec2ab08ac0fa298f281c51518f982c"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:25b5a51dc3dfb20a10833c8eee25903fd2e14059e9afcd329c9da20609a307b2"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a8fb786fb754ef6ff8c120cb96629fb518f8eb5a61a16aac3a979a9dbd40a084"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a905ad00ad1e1c34fe4e9d7c1d949ab09c6fa90c919860c1534ff479f40fd12d"}, + {file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:963be41bcd49f53af6d795f65c0da9b4cc518c0dd9c47145c98f61cb464f4839"}, + {file = "xxhash-3.5.0-cp311-cp311-win32.whl", hash = "sha256:109b436096d0a2dd039c355fa3414160ec4d843dfecc64a14077332a00aeb7da"}, + {file = "xxhash-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b702f806693201ad6c0a05ddbbe4c8f359626d0b3305f766077d51388a6bac58"}, + {file = "xxhash-3.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:c4dcb4120d0cc3cc448624147dba64e9021b278c63e34a38789b688fd0da9bf3"}, + {file = "xxhash-3.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:14470ace8bd3b5d51318782cd94e6f94431974f16cb3b8dc15d52f3b69df8e00"}, + {file = "xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:59aa1203de1cb96dbeab595ded0ad0c0056bb2245ae11fac11c0ceea861382b9"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08424f6648526076e28fae6ea2806c0a7d504b9ef05ae61d196d571e5c879c84"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61a1ff00674879725b194695e17f23d3248998b843eb5e933007ca743310f793"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2f2c61bee5844d41c3eb015ac652a0229e901074951ae48581d58bfb2ba01be"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d32a592cac88d18cc09a89172e1c32d7f2a6e516c3dfde1b9adb90ab5df54a6"}, + {file = "xxhash-3.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70dabf941dede727cca579e8c205e61121afc9b28516752fd65724be1355cc90"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e5d0ddaca65ecca9c10dcf01730165fd858533d0be84c75c327487c37a906a27"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e5b5e16c5a480fe5f59f56c30abdeba09ffd75da8d13f6b9b6fd224d0b4d0a2"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149b7914451eb154b3dfaa721315117ea1dac2cc55a01bfbd4df7c68c5dd683d"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:eade977f5c96c677035ff39c56ac74d851b1cca7d607ab3d8f23c6b859379cab"}, + {file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa9f547bd98f5553d03160967866a71056a60960be00356a15ecc44efb40ba8e"}, + {file = "xxhash-3.5.0-cp312-cp312-win32.whl", hash = "sha256:f7b58d1fd3551b8c80a971199543379be1cee3d0d409e1f6d8b01c1a2eebf1f8"}, + {file = "xxhash-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0cafd3a2af231b4e113fba24a65d7922af91aeb23774a8b78228e6cd785e3e"}, + {file = "xxhash-3.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:586886c7e89cb9828bcd8a5686b12e161368e0064d040e225e72607b43858ba2"}, + {file = "xxhash-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:37889a0d13b0b7d739cfc128b1c902f04e32de17b33d74b637ad42f1c55101f6"}, + {file = "xxhash-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97a662338797c660178e682f3bc180277b9569a59abfb5925e8620fba00b9fc5"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f85e0108d51092bdda90672476c7d909c04ada6923c14ff9d913c4f7dc8a3bc"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2fd827b0ba763ac919440042302315c564fdb797294d86e8cdd4578e3bc7f3"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82085c2abec437abebf457c1d12fccb30cc8b3774a0814872511f0f0562c768c"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07fda5de378626e502b42b311b049848c2ef38784d0d67b6f30bb5008642f8eb"}, + {file = "xxhash-3.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c279f0d2b34ef15f922b77966640ade58b4ccdfef1c4d94b20f2a364617a493f"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89e66ceed67b213dec5a773e2f7a9e8c58f64daeb38c7859d8815d2c89f39ad7"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bcd51708a633410737111e998ceb3b45d3dbc98c0931f743d9bb0a209033a326"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3ff2c0a34eae7df88c868be53a8dd56fbdf592109e21d4bfa092a27b0bf4a7bf"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4e28503dccc7d32e0b9817aa0cbfc1f45f563b2c995b7a66c4c8a0d232e840c7"}, + {file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a6c50017518329ed65a9e4829154626f008916d36295b6a3ba336e2458824c8c"}, + {file = "xxhash-3.5.0-cp313-cp313-win32.whl", hash = "sha256:53a068fe70301ec30d868ece566ac90d873e3bb059cf83c32e76012c889b8637"}, + {file = "xxhash-3.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:80babcc30e7a1a484eab952d76a4f4673ff601f54d5142c26826502740e70b43"}, + {file = "xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b"}, + {file = "xxhash-3.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6e5f70f6dca1d3b09bccb7daf4e087075ff776e3da9ac870f86ca316736bb4aa"}, + {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e76e83efc7b443052dd1e585a76201e40b3411fe3da7af4fe434ec51b2f163b"}, + {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33eac61d0796ca0591f94548dcfe37bb193671e0c9bcf065789b5792f2eda644"}, + {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ec70a89be933ea49222fafc3999987d7899fc676f688dd12252509434636622"}, + {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd86b8e7f703ec6ff4f351cfdb9f428955859537125904aa8c963604f2e9d3e7"}, + {file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0adfbd36003d9f86c8c97110039f7539b379f28656a04097e7434d3eaf9aa131"}, + {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:63107013578c8a730419adc05608756c3fa640bdc6abe806c3123a49fb829f43"}, + {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:683b94dbd1ca67557850b86423318a2e323511648f9f3f7b1840408a02b9a48c"}, + {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5d2a01dcce81789cf4b12d478b5464632204f4c834dc2d064902ee27d2d1f0ee"}, + {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:a9d360a792cbcce2fe7b66b8d51274ec297c53cbc423401480e53b26161a290d"}, + {file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:f0b48edbebea1b7421a9c687c304f7b44d0677c46498a046079d445454504737"}, + {file = "xxhash-3.5.0-cp37-cp37m-win32.whl", hash = "sha256:7ccb800c9418e438b44b060a32adeb8393764da7441eb52aa2aa195448935306"}, + {file = "xxhash-3.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c3bc7bf8cb8806f8d1c9bf149c18708cb1c406520097d6b0a73977460ea03602"}, + {file = "xxhash-3.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:74752ecaa544657d88b1d1c94ae68031e364a4d47005a90288f3bab3da3c970f"}, + {file = "xxhash-3.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:dee1316133c9b463aa81aca676bc506d3f80d8f65aeb0bba2b78d0b30c51d7bd"}, + {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:602d339548d35a8579c6b013339fb34aee2df9b4e105f985443d2860e4d7ffaa"}, + {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:695735deeddfb35da1677dbc16a083445360e37ff46d8ac5c6fcd64917ff9ade"}, + {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1030a39ba01b0c519b1a82f80e8802630d16ab95dc3f2b2386a0b5c8ed5cbb10"}, + {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5bc08f33c4966f4eb6590d6ff3ceae76151ad744576b5fc6c4ba8edd459fdec"}, + {file = "xxhash-3.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:160e0c19ee500482ddfb5d5570a0415f565d8ae2b3fd69c5dcfce8a58107b1c3"}, + {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f1abffa122452481a61c3551ab3c89d72238e279e517705b8b03847b1d93d738"}, + {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:d5e9db7ef3ecbfc0b4733579cea45713a76852b002cf605420b12ef3ef1ec148"}, + {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:23241ff6423378a731d84864bf923a41649dc67b144debd1077f02e6249a0d54"}, + {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:82b833d5563fefd6fceafb1aed2f3f3ebe19f84760fdd289f8b926731c2e6e91"}, + {file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0a80ad0ffd78bef9509eee27b4a29e56f5414b87fb01a888353e3d5bda7038bd"}, + {file = "xxhash-3.5.0-cp38-cp38-win32.whl", hash = "sha256:50ac2184ffb1b999e11e27c7e3e70cc1139047e7ebc1aa95ed12f4269abe98d4"}, + {file = "xxhash-3.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:392f52ebbb932db566973693de48f15ce787cabd15cf6334e855ed22ea0be5b3"}, + {file = "xxhash-3.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfc8cdd7f33d57f0468b0614ae634cc38ab9202c6957a60e31d285a71ebe0301"}, + {file = "xxhash-3.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e0c48b6300cd0b0106bf49169c3e0536408dfbeb1ccb53180068a18b03c662ab"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe1a92cfbaa0a1253e339ccec42dbe6db262615e52df591b68726ab10338003f"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33513d6cc3ed3b559134fb307aae9bdd94d7e7c02907b37896a6c45ff9ce51bd"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eefc37f6138f522e771ac6db71a6d4838ec7933939676f3753eafd7d3f4c40bc"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a606c8070ada8aa2a88e181773fa1ef17ba65ce5dd168b9d08038e2a61b33754"}, + {file = "xxhash-3.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42eca420c8fa072cc1dd62597635d140e78e384a79bb4944f825fbef8bfeeef6"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:604253b2143e13218ff1ef0b59ce67f18b8bd1c4205d2ffda22b09b426386898"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:6e93a5ad22f434d7876665444a97e713a8f60b5b1a3521e8df11b98309bff833"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:7a46e1d6d2817ba8024de44c4fd79913a90e5f7265434cef97026215b7d30df6"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:30eb2efe6503c379b7ab99c81ba4a779748e3830241f032ab46bd182bf5873af"}, + {file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c8aa771ff2c13dd9cda8166d685d7333d389fae30a4d2bb39d63ab5775de8606"}, + {file = "xxhash-3.5.0-cp39-cp39-win32.whl", hash = "sha256:5ed9ebc46f24cf91034544b26b131241b699edbfc99ec5e7f8f3d02d6eb7fba4"}, + {file = "xxhash-3.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:220f3f896c6b8d0316f63f16c077d52c412619e475f9372333474ee15133a558"}, + {file = "xxhash-3.5.0-cp39-cp39-win_arm64.whl", hash = "sha256:a7b1d8315d9b5e9f89eb2933b73afae6ec9597a258d52190944437158b49d38e"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2014c5b3ff15e64feecb6b713af12093f75b7926049e26a580e94dcad3c73d8c"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab81ef75003eda96239a23eda4e4543cedc22e34c373edcaf744e721a163986"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e2febf914ace002132aa09169cc572e0d8959d0f305f93d5828c4836f9bc5a6"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d3a10609c51da2a1c0ea0293fc3968ca0a18bd73838455b5bca3069d7f8e32b"}, + {file = "xxhash-3.5.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5a74f23335b9689b66eb6dbe2a931a88fcd7a4c2cc4b1cb0edba8ce381c7a1da"}, + {file = "xxhash-3.5.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2b4154c00eb22e4d543f472cfca430e7962a0f1d0f3778334f2e08a7ba59363c"}, + {file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d30bbc1644f726b825b3278764240f449d75f1a8bdda892e641d4a688b1494ae"}, + {file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa0b72f2423e2aa53077e54a61c28e181d23effeaafd73fcb9c494e60930c8e"}, + {file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:13de2b76c1835399b2e419a296d5b38dc4855385d9e96916299170085ef72f57"}, + {file = "xxhash-3.5.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:0691bfcc4f9c656bcb96cc5db94b4d75980b9d5589f2e59de790091028580837"}, + {file = "xxhash-3.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:297595fe6138d4da2c8ce9e72a04d73e58725bb60f3a19048bc96ab2ff31c692"}, + {file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc1276d369452040cbb943300dc8abeedab14245ea44056a2943183822513a18"}, + {file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2061188a1ba352fc699c82bff722f4baacb4b4b8b2f0c745d2001e56d0dfb514"}, + {file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38c384c434021e4f62b8d9ba0bc9467e14d394893077e2c66d826243025e1f81"}, + {file = "xxhash-3.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e6a4dd644d72ab316b580a1c120b375890e4c52ec392d4aef3c63361ec4d77d1"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:531af8845aaadcadf951b7e0c1345c6b9c68a990eeb74ff9acd8501a0ad6a1c9"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ce379bcaa9fcc00f19affa7773084dd09f5b59947b3fb47a1ceb0179f91aaa1"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd1b2281d01723f076df3c8188f43f2472248a6b63118b036e641243656b1b0f"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c770750cc80e8694492244bca7251385188bc5597b6a39d98a9f30e8da984e0"}, + {file = "xxhash-3.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b150b8467852e1bd844387459aa6fbe11d7f38b56e901f9f3b3e6aba0d660240"}, + {file = "xxhash-3.5.0.tar.gz", hash = "sha256:84f2caddf951c9cbf8dc2e22a89d4ccf5d86391ac6418fe81e3c67d0cf60b45f"}, +] + +[[package]] +name = "yarl" +version = "1.16.0" +description = "Yet another URL library" +optional = false +python-versions = ">=3.9" +files = [ + {file = "yarl-1.16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:32468f41242d72b87ab793a86d92f885355bcf35b3355aa650bfa846a5c60058"}, + {file = "yarl-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:234f3a3032b505b90e65b5bc6652c2329ea7ea8855d8de61e1642b74b4ee65d2"}, + {file = "yarl-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a0296040e5cddf074c7f5af4a60f3fc42c0237440df7bcf5183be5f6c802ed5"}, + {file = "yarl-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de6c14dd7c7c0badba48157474ea1f03ebee991530ba742d381b28d4f314d6f3"}, + {file = "yarl-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b140e532fe0266003c936d017c1ac301e72ee4a3fd51784574c05f53718a55d8"}, + {file = "yarl-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:019f5d58093402aa8f6661e60fd82a28746ad6d156f6c5336a70a39bd7b162b9"}, + {file = "yarl-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c42998fd1cbeb53cd985bff0e4bc25fbe55fd6eb3a545a724c1012d69d5ec84"}, + {file = "yarl-1.16.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c7c30fb38c300fe8140df30a046a01769105e4cf4282567a29b5cdb635b66c4"}, + {file = "yarl-1.16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e49e0fd86c295e743fd5be69b8b0712f70a686bc79a16e5268386c2defacaade"}, + {file = "yarl-1.16.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:b9ca7b9147eb1365c8bab03c003baa1300599575effad765e0b07dd3501ea9af"}, + {file = "yarl-1.16.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:27e11db3f1e6a51081a981509f75617b09810529de508a181319193d320bc5c7"}, + {file = "yarl-1.16.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8994c42f4ca25df5380ddf59f315c518c81df6a68fed5bb0c159c6cb6b92f120"}, + {file = "yarl-1.16.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:542fa8e09a581bcdcbb30607c7224beff3fdfb598c798ccd28a8184ffc18b7eb"}, + {file = "yarl-1.16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2bd6a51010c7284d191b79d3b56e51a87d8e1c03b0902362945f15c3d50ed46b"}, + {file = "yarl-1.16.0-cp310-cp310-win32.whl", hash = "sha256:178ccb856e265174a79f59721031060f885aca428983e75c06f78aa24b91d929"}, + {file = "yarl-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:fe8bba2545427418efc1929c5c42852bdb4143eb8d0a46b09de88d1fe99258e7"}, + {file = "yarl-1.16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d8643975a0080f361639787415a038bfc32d29208a4bf6b783ab3075a20b1ef3"}, + {file = "yarl-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:676d96bafc8c2d0039cea0cd3fd44cee7aa88b8185551a2bb93354668e8315c2"}, + {file = "yarl-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d9525f03269e64310416dbe6c68d3b23e5d34aaa8f47193a1c45ac568cecbc49"}, + {file = "yarl-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b37d5ec034e668b22cf0ce1074d6c21fd2a08b90d11b1b73139b750a8b0dd97"}, + {file = "yarl-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f32c4cb7386b41936894685f6e093c8dfaf0960124d91fe0ec29fe439e201d0"}, + {file = "yarl-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b8e265a0545637492a7e12fd7038370d66c9375a61d88c5567d0e044ded9202"}, + {file = "yarl-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:789a3423f28a5fff46fbd04e339863c169ece97c827b44de16e1a7a42bc915d2"}, + {file = "yarl-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f1d1f45e3e8d37c804dca99ab3cf4ab3ed2e7a62cd82542924b14c0a4f46d243"}, + {file = "yarl-1.16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:621280719c4c5dad4c1391160a9b88925bb8b0ff6a7d5af3224643024871675f"}, + {file = "yarl-1.16.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:ed097b26f18a1f5ff05f661dc36528c5f6735ba4ce8c9645e83b064665131349"}, + {file = "yarl-1.16.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:2f1fe2b2e3ee418862f5ebc0c0083c97f6f6625781382f828f6d4e9b614eba9b"}, + {file = "yarl-1.16.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:87dd10bc0618991c66cee0cc65fa74a45f4ecb13bceec3c62d78ad2e42b27a16"}, + {file = "yarl-1.16.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:4199db024b58a8abb2cfcedac7b1292c3ad421684571aeb622a02f242280e8d6"}, + {file = "yarl-1.16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:99a9dcd4b71dd5f5f949737ab3f356cfc058c709b4f49833aeffedc2652dac56"}, + {file = "yarl-1.16.0-cp311-cp311-win32.whl", hash = "sha256:a9394c65ae0ed95679717d391c862dece9afacd8fa311683fc8b4362ce8a410c"}, + {file = "yarl-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b9101f528ae0f8f65ac9d64dda2bb0627de8a50344b2f582779f32fda747c1d"}, + {file = "yarl-1.16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:4ffb7c129707dd76ced0a4a4128ff452cecf0b0e929f2668ea05a371d9e5c104"}, + {file = "yarl-1.16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:1a5e9d8ce1185723419c487758d81ac2bde693711947032cce600ca7c9cda7d6"}, + {file = "yarl-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d743e3118b2640cef7768ea955378c3536482d95550222f908f392167fe62059"}, + {file = "yarl-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26768342f256e6e3c37533bf9433f5f15f3e59e3c14b2409098291b3efaceacb"}, + {file = "yarl-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1b0796168b953bca6600c5f97f5ed407479889a36ad7d17183366260f29a6b9"}, + {file = "yarl-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:858728086914f3a407aa7979cab743bbda1fe2bdf39ffcd991469a370dd7414d"}, + {file = "yarl-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5570e6d47bcb03215baf4c9ad7bf7c013e56285d9d35013541f9ac2b372593e7"}, + {file = "yarl-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66ea8311422a7ba1fc79b4c42c2baa10566469fe5a78500d4e7754d6e6db8724"}, + {file = "yarl-1.16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:649bddcedee692ee8a9b7b6e38582cb4062dc4253de9711568e5620d8707c2a3"}, + {file = "yarl-1.16.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:3a91654adb7643cb21b46f04244c5a315a440dcad63213033826549fa2435f71"}, + {file = "yarl-1.16.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b439cae82034ade094526a8f692b9a2b5ee936452de5e4c5f0f6c48df23f8604"}, + {file = "yarl-1.16.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:571f781ae8ac463ce30bacebfaef2c6581543776d5970b2372fbe31d7bf31a07"}, + {file = "yarl-1.16.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:aa7943f04f36d6cafc0cf53ea89824ac2c37acbdb4b316a654176ab8ffd0f968"}, + {file = "yarl-1.16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1a5cf32539373ff39d97723e39a9283a7277cbf1224f7aef0c56c9598b6486c3"}, + {file = "yarl-1.16.0-cp312-cp312-win32.whl", hash = "sha256:a5b6c09b9b4253d6a208b0f4a2f9206e511ec68dce9198e0fbec4f160137aa67"}, + {file = "yarl-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:1208ca14eed2fda324042adf8d6c0adf4a31522fa95e0929027cd487875f0240"}, + {file = "yarl-1.16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5ace0177520bd4caa99295a9b6fb831d0e9a57d8e0501a22ffaa61b4c024283"}, + {file = "yarl-1.16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7118bdb5e3ed81acaa2095cba7ec02a0fe74b52a16ab9f9ac8e28e53ee299732"}, + {file = "yarl-1.16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:38fec8a2a94c58bd47c9a50a45d321ab2285ad133adefbbadf3012c054b7e656"}, + {file = "yarl-1.16.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8791d66d81ee45866a7bb15a517b01a2bcf583a18ebf5d72a84e6064c417e64b"}, + {file = "yarl-1.16.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cf936ba67bc6c734f3aa1c01391da74ab7fc046a9f8bbfa230b8393b90cf472"}, + {file = "yarl-1.16.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1aab176dd55b59f77a63b27cffaca67d29987d91a5b615cbead41331e6b7428"}, + {file = "yarl-1.16.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:995d0759004c08abd5d1b81300a91d18c8577c6389300bed1c7c11675105a44d"}, + {file = "yarl-1.16.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1bc22e00edeb068f71967ab99081e9406cd56dbed864fc3a8259442999d71552"}, + {file = "yarl-1.16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:35b4f7842154176523e0a63c9b871168c69b98065d05a4f637fce342a6a2693a"}, + {file = "yarl-1.16.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:7ace71c4b7a0c41f317ae24be62bb61e9d80838d38acb20e70697c625e71f120"}, + {file = "yarl-1.16.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:8f639e3f5795a6568aa4f7d2ac6057c757dcd187593679f035adbf12b892bb00"}, + {file = "yarl-1.16.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:e8be3aff14f0120ad049121322b107f8a759be76a6a62138322d4c8a337a9e2c"}, + {file = "yarl-1.16.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:122d8e7986043d0549e9eb23c7fd23be078be4b70c9eb42a20052b3d3149c6f2"}, + {file = "yarl-1.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0fd9c227990f609c165f56b46107d0bc34553fe0387818c42c02f77974402c36"}, + {file = "yarl-1.16.0-cp313-cp313-win32.whl", hash = "sha256:595ca5e943baed31d56b33b34736461a371c6ea0038d3baec399949dd628560b"}, + {file = "yarl-1.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:921b81b8d78f0e60242fb3db615ea3f368827a76af095d5a69f1c3366db3f596"}, + {file = "yarl-1.16.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab2b2ac232110a1fdb0d3ffcd087783edd3d4a6ced432a1bf75caf7b7be70916"}, + {file = "yarl-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f8713717a09acbfee7c47bfc5777e685539fefdd34fa72faf504c8be2f3df4e"}, + {file = "yarl-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cdcffe1dbcb4477d2b4202f63cd972d5baa155ff5a3d9e35801c46a415b7f71a"}, + {file = "yarl-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a91217208306d82357c67daeef5162a41a28c8352dab7e16daa82e3718852a7"}, + {file = "yarl-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ab3ed42c78275477ea8e917491365e9a9b69bb615cb46169020bd0aa5e2d6d3"}, + {file = "yarl-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:707ae579ccb3262dfaef093e202b4c3fb23c3810e8df544b1111bd2401fd7b09"}, + {file = "yarl-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad7a852d1cd0b8d8b37fc9d7f8581152add917a98cfe2ea6e241878795f917ae"}, + {file = "yarl-1.16.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3f1cc3d3d4dc574bebc9b387f6875e228ace5748a7c24f49d8f01ac1bc6c31b"}, + {file = "yarl-1.16.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5ff96da263740779b0893d02b718293cc03400c3a208fc8d8cd79d9b0993e532"}, + {file = "yarl-1.16.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:3d375a19ba2bfe320b6d873f3fb165313b002cef8b7cc0a368ad8b8a57453837"}, + {file = "yarl-1.16.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:62c7da0ad93a07da048b500514ca47b759459ec41924143e2ddb5d7e20fd3db5"}, + {file = "yarl-1.16.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:147b0fcd0ee33b4b5f6edfea80452d80e419e51b9a3f7a96ce98eaee145c1581"}, + {file = "yarl-1.16.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:504e1fe1cc4f170195320eb033d2b0ccf5c6114ce5bf2f617535c01699479bca"}, + {file = "yarl-1.16.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:bdcf667a5dec12a48f669e485d70c54189f0639c2157b538a4cffd24a853624f"}, + {file = "yarl-1.16.0-cp39-cp39-win32.whl", hash = "sha256:e9951afe6557c75a71045148890052cb942689ee4c9ec29f5436240e1fcc73b7"}, + {file = "yarl-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:7d7aaa8ff95d0840e289423e7dc35696c2b058d635f945bf05b5cd633146b027"}, + {file = "yarl-1.16.0-py3-none-any.whl", hash = "sha256:e6980a558d8461230c457218bd6c92dfc1d10205548215c2c21d79dc8d0a96f3"}, + {file = "yarl-1.16.0.tar.gz", hash = "sha256:b6f687ced5510a9a2474bbae96a4352e5ace5fa34dc44a217b0537fec1db00b4"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" +propcache = ">=0.2.0" + +[[package]] +name = "zipp" +version = "3.20.2" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, + {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +type = ["pytest-mypy"] + +[metadata] +lock-version = "2.0" +python-versions = ">=3.9,<3.13" +content-hash = "d169e52a2daaf41488cf7191e3f4ff209d39125941a6d0e656001aae534b1e8d" diff --git a/backends/gaudi/server/pyproject.toml b/backends/gaudi/server/pyproject.toml new file mode 100644 index 000000000..c61ac030b --- /dev/null +++ b/backends/gaudi/server/pyproject.toml @@ -0,0 +1,42 @@ +[tool.poetry] +name = "text-generation-server" +version = "2.0.4" +description = "Text Generation Inference Python gRPC Server" +authors = ["Olivier Dehaene "] + +[tool.poetry.scripts] +text-generation-server = 'text_generation_server.cli:app' + +[tool.poetry.dependencies] +python = ">=3.9,<3.13" +protobuf = "^3.20.3" +grpcio = "^1.51.1" +grpcio-status = "*" +grpcio-reflection = "*" +grpc-interceptor = "^0.15.0" +typer = "^0.7.0" +loguru = "^0.6.0" +opentelemetry-api = "^1.15.0" +opentelemetry-exporter-otlp = "^1.15.0" +opentelemetry-instrumentation-grpc = "^0.36b0" +hf-transfer = "^0.1.2" +sentencepiece = "^0.1.97" +peft = "^0.10" +optimum-habana = "1.15.0" +transformers = "4.45.2" +numpy = "1.26.4" +accelerate = "0.33.0" +outlines= { version = "^0.0.36", optional = true } +prometheus-client = "^0.20.0" +py-cpuinfo = "^9.0.0" + +[tool.poetry.group.dev.dependencies] +grpcio-tools = "*" +pytest = "^7.3.0" + +[tool.pytest.ini_options] +markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt new file mode 100644 index 000000000..074144909 --- /dev/null +++ b/backends/gaudi/server/requirements.txt @@ -0,0 +1,89 @@ +accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13" +aiohappyeyeballs==2.4.3 ; python_version >= "3.9" and python_version < "3.13" +aiohttp==3.10.10 ; python_version >= "3.9" and python_version < "3.13" +aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" +async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11" +attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13" +backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13" +datasets==3.0.1 ; python_version >= "3.9" and python_version < "3.13" +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13" +dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" +frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.26.1 ; python_version >= "3.9" and python_version < "3.13" +humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13" +idna==3.10 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==8.5.0 ; python_version >= "3.9" and python_version < "3.13" +jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13" +joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13" +mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" +multidict==6.1.0 ; python_version >= "3.9" and python_version < "3.13" +multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13" +networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13" +peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13" +pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13" +prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +propcache==0.2.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13" +psutil==6.1.0 ; python_version >= "3.9" and python_version < "3.13" +py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13" +pyreadline3==3.5.4 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13" +python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13" +pytz==2024.2 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" +regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" +scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" +sentence-transformers[train]==3.2.1 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" +setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13" +six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13" +threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13" +transformers[sentencepiece]==4.45.2 ; python_version >= "3.9" and python_version < "3.13" +triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9" +typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +tzdata==2024.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13" +yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/requirements_cuda.txt b/backends/gaudi/server/requirements_cuda.txt new file mode 100644 index 000000000..5de75b6b3 --- /dev/null +++ b/backends/gaudi/server/requirements_cuda.txt @@ -0,0 +1,54 @@ +certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" +idna==3.10 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" +prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13" +py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" +regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13" +typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/requirements_intel.txt b/backends/gaudi/server/requirements_intel.txt new file mode 100644 index 000000000..5de75b6b3 --- /dev/null +++ b/backends/gaudi/server/requirements_intel.txt @@ -0,0 +1,54 @@ +certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" +idna==3.10 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" +prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13" +py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" +regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13" +typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/requirements_rocm.txt b/backends/gaudi/server/requirements_rocm.txt new file mode 100644 index 000000000..5de75b6b3 --- /dev/null +++ b/backends/gaudi/server/requirements_rocm.txt @@ -0,0 +1,54 @@ +certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" +idna==3.10 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" +prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13" +py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" +regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13" +typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/tests/conftest.py b/backends/gaudi/server/tests/conftest.py new file mode 100644 index 000000000..d99771f8a --- /dev/null +++ b/backends/gaudi/server/tests/conftest.py @@ -0,0 +1,23 @@ +import pytest +import os +from text_generation_server.pb import generate_pb2 + +os.environ["USE_PREFIX_CACHING"] = "1" +os.environ["ATTENTION"] = "flashinfer" + + +@pytest.fixture +def default_pb_parameters(): + return generate_pb2.NextTokenChooserParameters( + temperature=1.0, + repetition_penalty=1.0, + top_k=0, + top_p=1.0, + typical_p=1.0, + do_sample=False, + ) + + +@pytest.fixture +def default_pb_stop_parameters(): + return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10) diff --git a/backends/gaudi/server/tests/models/test_bloom.py b/backends/gaudi/server/tests/models/test_bloom.py new file mode 100644 index 000000000..dff5239b3 --- /dev/null +++ b/backends/gaudi/server/tests/models/test_bloom.py @@ -0,0 +1,363 @@ +import pytest +import torch + +from copy import copy +from transformers import AutoTokenizer + +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.utils import weight_hub_files, download_weights +from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) + + +@pytest.fixture(scope="session") +def default_bloom(): + model_id = "bigscience/bloom-560m" + revision = "main" + filenames = weight_hub_files(model_id, revision, ".safetensors") + download_weights(filenames, model_id, revision) + return BLOOMSharded( + model_id, + model_class=BloomForCausalLM, + ) + + +@pytest.fixture(scope="session") +def bloom_560m_tokenizer(): + return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left") + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer): + return BloomCausalLMBatch.from_pb( + default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu") + ) + + +@pytest.fixture +def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer): + req_0 = copy(default_pb_request) + req_0.id = 1 + req_1 = default_pb_request + req_1.id = 2 + req_1.stopping_parameters.max_new_tokens = 5 + + batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) + return BloomCausalLMBatch.from_pb( + batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu") + ) + + +def test_batch_from_pb(default_pb_batch, default_bloom_batch): + batch = default_bloom_batch + + assert batch.batch_id == default_pb_batch.id + assert batch.requests == default_pb_batch.requests + + assert len(batch.input_ids) == default_pb_batch.size + assert batch.input_ids[0][-1] == 10264 + assert torch.all(batch.input_ids[0][:-1] == 3) + + assert batch.attention_mask[0][0] == 1 + assert torch.all(batch.attention_mask[0][1:] == 0) + + assert batch.past_key_values is None + + assert all( + [ + torch.equal(input_ids, all_input_ids[:, 0]) + for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) + ] + ) + + assert batch.input_lengths == [1] + + assert len(batch) == default_pb_batch.size + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) + + assert batch.max_input_length == batch.input_lengths[0] + + +def test_batch_concatenate_no_prefill(default_bloom_batch): + with pytest.raises(ValueError): + BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch]) + + +def test_causal_lm_batch_type(default_bloom): + assert default_bloom.batch_type == BloomCausalLMBatch + + +def test_causal_lm_generate_token(default_bloom, default_bloom_batch): + sequence_length = len(default_bloom_batch.all_input_ids[0]) + generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch) + + assert len(generations) == len(default_bloom_batch) + assert isinstance(next_batch, CausalLMBatch) + assert not next_batch.keys_head_dim_last + + assert len(next_batch.all_input_ids) == len(next_batch) + assert len(next_batch.all_input_ids[0]) == sequence_length + 1 + assert len(next_batch.attention_mask[0]) == 11 + assert torch.all(next_batch.all_input_ids[0][-2:] == 10264) + assert torch.all(next_batch.all_input_ids[0][:-2] == 3) + + assert torch.all(next_batch.attention_mask[0][:2] == 1) + assert torch.all(next_batch.attention_mask[0][2:] == 0) + + assert next_batch.input_ids.shape == (len(next_batch), 1) + assert next_batch.input_ids[0, 0] == 10264 + + assert next_batch.input_lengths == [2] + assert next_batch.max_input_length == next_batch.input_lengths[0] + + assert next_batch.past_key_values is not None + assert all( + [p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values] + ) + assert all( + [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values] + ) + assert all([generation.generated_text is None for generation in generations]) + assert all([len(generation.prefill_tokens) == 1 for generation in generations]) + assert all( + [ + token_id.item() == 10264 + for generation in generations + for token_id in generation.tokens.token_ids + ] + ) + assert all( + [ + token_text == "Test" + for generation in generations + for token_text in generation.tokens.texts + ] + ) + assert generations[0].request_id == 0 + + +def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch): + next_batch = default_bloom_batch + for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert len(generations) == len(default_bloom_batch) + + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert ( + generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest" + ) + assert generations[0].request_id == default_bloom_batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == default_bloom_batch.stopping_criterias[0].max_new_tokens + ) + + +def test_causal_lm_generate_token_completion_multi( + default_bloom, default_multi_requests_bloom_batch +): + next_batch = default_multi_requests_bloom_batch + + for i in range( + default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1 + ): + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert len(generations) == len(default_multi_requests_bloom_batch) + + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 2 + assert generations[1].generated_text.text == "TestTestTestTestTest" + assert ( + generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id + ) + assert ( + generations[1].generated_text.generated_tokens + == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens + ) + # Copy stopping_criterias before filtering + stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy() + + next_batch = next_batch.filter([next_batch.requests[0].id]) + + for _ in range( + stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 + ): + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert ( + generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest" + ) + assert ( + generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id + ) + assert ( + generations[0].generated_text.generated_tokens + == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens + ) + + +def test_batch_concatenate( + default_bloom, default_bloom_batch, default_multi_requests_bloom_batch +): + next_batch_0 = default_bloom_batch + _, next_batch_0, _ = default_bloom.generate_token(next_batch_0) + _, next_batch_0, _ = default_bloom.generate_token(next_batch_0) + + next_batch_1 = default_multi_requests_bloom_batch + _, next_batch_1, _ = default_bloom.generate_token(next_batch_1) + + # Clone past_key_values before concatenating to compare after, + # because they are removed from the concatenated batches + next_batch_0_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values + ] + next_batch_1_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values + ] + + next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1]) + + assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) + assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) + assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) + + assert torch.all( + next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1 + ) + assert torch.all( + next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1 + ) + assert torch.all(next_batch.attention_mask[1:, 3:] == 0) + + assert next_batch.batch_id == 0 + assert torch.all(next_batch.input_ids == 10264) + + assert next_batch.input_lengths == [3, 2, 2] + assert next_batch.max_input_length == 3 + + assert next_batch.requests[0] == next_batch_0.requests[0] + assert next_batch.requests[1:] == list(next_batch_1.requests) + + assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] + assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers + + assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0] + assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias + + assert next_batch.past_key_values is not None + assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values]) + assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values]) + + for i, past in enumerate(next_batch.past_key_values): + assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0]) + assert torch.equal( + next_batch_1_past_key_values[i][0][:, :, -1:], + past[0][1:, :, :, -1].reshape(-1, 64, 1), + ) + + assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0]) + assert torch.equal( + next_batch_1_past_key_values[i][1][:, -1:, :], + past[1][1:, :, -1, :].reshape(-1, 1, 64), + ) + + for _ in range( + default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 + ): + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 3 + assert generations[2].generated_text.text == "TestTestTestTestTest" + assert ( + generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id + ) + assert ( + generations[2].generated_text.generated_tokens + == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens + ) + + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) + + for _ in range( + default_bloom_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens + - 2 + ): + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 2 + assert ( + generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest" + ) + assert generations[0].request_id == default_bloom_batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == default_bloom_batch.stopping_criterias[0].max_new_tokens + ) + + next_batch = next_batch.filter([next_batch.requests[1].id]) + + for _ in range( + default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens + - default_bloom_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens + - 4 + ): + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_bloom.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert ( + generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest" + ) + assert ( + generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id + ) + assert ( + generations[0].generated_text.generated_tokens + == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens + ) diff --git a/backends/gaudi/server/tests/models/test_causal_lm.py b/backends/gaudi/server/tests/models/test_causal_lm.py new file mode 100644 index 000000000..11ca5efe3 --- /dev/null +++ b/backends/gaudi/server/tests/models/test_causal_lm.py @@ -0,0 +1,354 @@ +import pytest +import torch + +from copy import copy +from transformers import AutoTokenizer + +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch + + +@pytest.fixture(scope="session") +def default_causal_lm(): + return CausalLM.fallback("gpt2") + + +@pytest.fixture(scope="session") +def gpt2_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left") + tokenizer.pad_token_id = 50256 + return tokenizer + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer): + return CausalLMBatch.from_pb( + default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu") + ) + + +@pytest.fixture +def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): + req_0 = copy(default_pb_request) + req_0.id = 1 + req_1 = default_pb_request + req_1.id = 2 + req_1.stopping_parameters.max_new_tokens = 5 + + batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2) + return CausalLMBatch.from_pb( + batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu") + ) + + +def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): + batch = default_causal_lm_batch + + assert batch.batch_id == default_pb_batch.id + assert batch.requests == default_pb_batch.requests + + assert len(batch.input_ids) == default_pb_batch.size + assert batch.input_ids[0][-1] == 14402 + assert torch.all(batch.input_ids[0][:-1] == 50256) + + assert batch.attention_mask[0, 0] == 1 + assert torch.all(batch.attention_mask[0, 1:] == 0) + + assert batch.past_key_values is None + + assert all( + [ + torch.equal(input_ids, all_input_ids[:, 0]) + for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) + ] + ) + + assert batch.input_lengths == [1] + + assert len(batch) == default_pb_batch.size + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) + + assert batch.max_input_length == batch.input_lengths[0] + + +def test_batch_concatenate_no_prefill(default_causal_lm_batch): + with pytest.raises(ValueError): + CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch]) + + +def test_causal_lm_batch_type(default_causal_lm): + assert default_causal_lm.batch_type == CausalLMBatch + + +def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): + sequence_length = len(default_causal_lm_batch.all_input_ids[0]) + generations, next_batch, _ = default_causal_lm.generate_token( + default_causal_lm_batch + ) + + assert len(generations) == len(next_batch) + assert isinstance(next_batch, CausalLMBatch) + + assert len(next_batch.all_input_ids) == len(next_batch) + assert len(next_batch.all_input_ids[0]) == sequence_length + 1 + assert len(next_batch.attention_mask[0]) == 11 + assert next_batch.all_input_ids[0][-1] == 13 + assert next_batch.all_input_ids[0][-2] == 14402 + assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) + + assert torch.all(next_batch.attention_mask[0][0:2] == 1) + assert torch.all(next_batch.attention_mask[0][2:] == 0) + + assert next_batch.input_ids.shape == (len(next_batch), 1) + assert next_batch.input_ids[0, 0] == 13 + + assert next_batch.input_lengths == [2] + assert next_batch.max_input_length == next_batch.input_lengths[0] + + assert next_batch.past_key_values is not None + assert all( + [p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] + ) + assert all( + [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] + ) + assert all([generation.generated_text is None for generation in generations]) + assert all([len(generation.prefill_tokens) == 1 for generation in generations]) + assert all( + [ + token_id.item() == 13 + for generation in generations + for token_id in generation.tokens.token_ids + ] + ) + assert all( + [ + token_text == "." + for generation in generations + for token_text in generation.tokens.texts + ] + ) + assert generations[0].request_id == 0 + + +def test_causal_lm_generate_token_completion( + default_causal_lm, default_causal_lm_batch +): + next_batch = default_causal_lm_batch + for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text == ".java:784) at net.minecraft." + assert generations[0].request_id == default_causal_lm_batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == default_causal_lm_batch.stopping_criterias[0].max_new_tokens + ) + + +def test_causal_lm_generate_token_completion_multi( + default_causal_lm, default_multi_requests_causal_lm_batch +): + next_batch = default_multi_requests_causal_lm_batch + + for i in range( + default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 + ): + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 2 + assert generations[1].generated_text.text == ".java:784)" + assert ( + generations[1].request_id + == default_multi_requests_causal_lm_batch.requests[1].id + ) + assert ( + generations[1].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + ) + # Copy stopping_criterias before filtering + stopping_criterias = ( + default_multi_requests_causal_lm_batch.stopping_criterias.copy() + ) + + next_batch = next_batch.filter([next_batch.requests[0].id]) + + for _ in range( + stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 + ): + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text == ".java:784) at net.minecraft." + assert ( + generations[0].request_id + == default_multi_requests_causal_lm_batch.requests[0].id + ) + assert ( + generations[0].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens + ) + + +def test_batch_concatenate( + default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch +): + next_batch_0 = default_causal_lm_batch + _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0) + _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0) + + next_batch_1 = default_multi_requests_causal_lm_batch + _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1) + + # Clone past_key_values before concatenating to compare after, + # because they are removed from the concatenated batches + next_batch_0_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values + ] + next_batch_1_past_key_values = [ + (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values + ] + + next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1]) + + assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) + assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) + assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) + + assert torch.all( + next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1 + ) + assert torch.all( + next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1 + ) + assert torch.all(next_batch.attention_mask[1:, 3:] == 0) + + assert next_batch.batch_id == 0 + assert next_batch.input_ids[0, 0] == 12355 + assert torch.all(next_batch.input_ids[1:] == 13) + + assert next_batch.input_lengths == [3, 2, 2] + assert next_batch.max_input_length == 3 + + assert next_batch.requests[0] == next_batch_0.requests[0] + assert next_batch.requests[1:] == list(next_batch_1.requests) + + assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] + assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers + + assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0] + assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias + + assert next_batch.past_key_values is not None + assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values]) + assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values]) + + for i, past in enumerate(next_batch.past_key_values): + assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0]) + assert torch.equal( + next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :] + ) + + assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0]) + assert torch.equal( + next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :] + ) + + for _ in range( + default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 + ): + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 3 + assert generations[2].generated_text.text == ".java:784)" + assert ( + generations[2].request_id + == default_multi_requests_causal_lm_batch.requests[1].id + ) + assert ( + generations[2].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + ) + + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) + + for _ in range( + default_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + - 2 + ): + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 2 + assert generations[0].generated_text.text == ".java:784) at net.minecraft." + assert generations[0].request_id == default_causal_lm_batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == default_causal_lm_batch.stopping_criterias[0].max_new_tokens + ) + + next_batch = next_batch.filter([next_batch.requests[1].id]) + + for _ in range( + default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + - 4 + ): + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text == ".java:784) at net.minecraft." + assert ( + generations[0].request_id + == default_multi_requests_causal_lm_batch.requests[0].id + ) + assert ( + generations[0].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens + ) diff --git a/backends/gaudi/server/tests/models/test_model.py b/backends/gaudi/server/tests/models/test_model.py new file mode 100644 index 000000000..8441e8c6e --- /dev/null +++ b/backends/gaudi/server/tests/models/test_model.py @@ -0,0 +1,83 @@ +import pytest +import torch + +from transformers import AutoTokenizer + +from text_generation_server.models import Model + + +def get_test_model(): + class TestModel(Model): + def batch_type(self): + raise NotImplementedError + + def generate_token(self, batch): + raise NotImplementedError + + tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b") + + model = TestModel( + "test_model_id", + torch.nn.Linear(1, 1), + tokenizer, + False, + torch.float32, + torch.device("cpu"), + ) + return model + + +@pytest.mark.private +def test_decode_streaming_english_spaces(): + model = get_test_model() + truth = "Hello here, this is a simple test" + all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243] + assert ( + all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"] + ) + + decoded_text = "" + offset = 0 + token_offset = 0 + for i in range(len(all_input_ids)): + text, offset, token_offset = model.decode_token( + all_input_ids[: i + 1], offset, token_offset + ) + decoded_text += text + + assert decoded_text == truth + + +@pytest.mark.private +def test_decode_streaming_chinese_utf8(): + model = get_test_model() + truth = "我很感谢你的热情" + all_input_ids = [ + 30672, + 232, + 193, + 139, + 233, + 135, + 162, + 235, + 179, + 165, + 30919, + 30210, + 234, + 134, + 176, + 30993, + ] + + decoded_text = "" + offset = 0 + token_offset = 0 + for i in range(len(all_input_ids)): + text, offset, token_offset = model.decode_token( + all_input_ids[: i + 1], offset, token_offset + ) + decoded_text += text + + assert decoded_text == truth diff --git a/backends/gaudi/server/tests/models/test_santacoder.py b/backends/gaudi/server/tests/models/test_santacoder.py new file mode 100644 index 000000000..d5c91bffc --- /dev/null +++ b/backends/gaudi/server/tests/models/test_santacoder.py @@ -0,0 +1,108 @@ +import pytest + +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM + + +@pytest.fixture(scope="session") +def default_santacoder(): + return CausalLM.fallback(model_id="bigcode/santacoder") + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="def", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]), + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="defworld", + input_chunks=generate_pb2.Input( + chunks=[ + generate_pb2.InputChunk( + text="defworld" + ) + ] + ), + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_fim_pb_batch(default_fim_pb_request): + return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1) + + +@pytest.mark.skip +def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch): + batch = CausalLMBatch.from_pb( + default_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch, _ = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text == " test_get_all_users_with_" + assert generations[0].request_id == batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) + + +@pytest.mark.skip +def test_fim_santacoder_generate_token_completion( + default_santacoder, default_fim_pb_batch +): + batch = CausalLMBatch.from_pb( + default_fim_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch, _ = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert ( + generations[0].generated_text.text + == """ineProperty(exports, "__esModule", { value""" + ) + assert generations[0].request_id == batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) diff --git a/backends/gaudi/server/tests/models/test_seq2seq_lm.py b/backends/gaudi/server/tests/models/test_seq2seq_lm.py new file mode 100644 index 000000000..5ba7c64ce --- /dev/null +++ b/backends/gaudi/server/tests/models/test_seq2seq_lm.py @@ -0,0 +1,365 @@ +import pytest +import torch + +from copy import copy + +from transformers import AutoTokenizer + +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch + + +@pytest.fixture(scope="session") +def mt0_small_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained( + "bigscience/mt0-small", padding_side="left" + ) + tokenizer.bos_token_id = 0 + return tokenizer + + +@pytest.fixture(scope="session") +def default_seq2seq_lm(): + return Seq2SeqLM.fallback("bigscience/mt0-small") + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer): + return Seq2SeqLMBatch.from_pb( + default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu") + ) + + +@pytest.fixture +def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer): + req_0 = copy(default_pb_request) + req_0.id = 1 + req_1 = default_pb_request + req_1.id = 2 + req_1.stopping_parameters.max_new_tokens = 5 + + batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) + return Seq2SeqLMBatch.from_pb( + batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu") + ) + + +def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): + batch = default_seq2seq_lm_batch + sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) + + assert batch.batch_id == default_pb_batch.id + assert batch.requests == default_pb_batch.requests + + assert batch.input_ids.shape == (default_pb_batch.size, sequence_length) + assert batch.input_ids[0][-2] == 4268 + assert batch.input_ids[0][-1] == 1 + assert torch.all(batch.input_ids[0][:-2] == 0) + + assert torch.all(batch.attention_mask[0][-2:] == 1) + assert torch.all(batch.attention_mask[0][:-2] == 0) + + assert len(batch.decoder_input_ids) == default_pb_batch.size + assert batch.decoder_attention_mask is None + assert batch.encoder_last_hidden_state is None + + assert batch.past_key_values is None + + assert batch.input_lengths == [2] + assert batch.decoder_input_lengths == [1] + + assert len(batch) == default_pb_batch.size + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) + + assert batch.max_input_length == batch.input_lengths[0] + assert batch.max_decoder_input_length == batch.decoder_input_lengths[0] + + +def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch): + with pytest.raises(ValueError): + Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch]) + + +def test_seq2seq_lm_batch_type(default_seq2seq_lm): + assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch + + +def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch): + sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) + generations, next_batch, _ = default_seq2seq_lm.generate_token( + default_seq2seq_lm_batch + ) + + assert len(generations) == len(next_batch) + assert isinstance(next_batch, Seq2SeqLMBatch) + + assert next_batch.input_ids is None + assert torch.equal( + next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask + ) + assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths + assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length + assert ( + next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers + ) + assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias + + assert len(next_batch.decoder_input_ids) == len(next_batch) + assert next_batch.all_decoder_input_ids[0][0] == 0 + assert next_batch.all_decoder_input_ids[0][1] == 259 + assert next_batch.decoder_attention_mask is None + assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512) + + assert next_batch.decoder_input_lengths == [2] + assert next_batch.max_decoder_input_length == 2 + + assert next_batch.past_key_values is not None + assert all( + [p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values] + ) + assert all( + [p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values] + ) + assert all( + [ + p[2].shape == (len(next_batch), 6, sequence_length, 64) + for p in next_batch.past_key_values + ] + ) + assert all( + [ + p[3].shape == (len(next_batch), 6, sequence_length, 64) + for p in next_batch.past_key_values + ] + ) + assert all([generation.generated_text is None for generation in generations]) + assert all([len(generation.prefill_tokens) == 1 for generation in generations]) + assert all( + [ + token_id.item() == 259 + for generation in generations + for token_id in generation.tokens.token_ids + ] + ) + assert all( + [ + token_text == " " + for generation in generations + for token_text in generation.tokens.texts + ] + ) + assert generations[0].request_id == 0 + + +def test_seq2seq_lm_generate_token_completion( + default_seq2seq_lm, default_seq2seq_lm_batch +): + next_batch = default_seq2seq_lm_batch + for _ in range(6): + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text == "a few weeks" + assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id + assert generations[0].generated_text.generated_tokens == 7 + + +def test_seq2seq_lm_generate_token_completion_multi( + default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch +): + next_batch = default_multi_requests_seq2seq_lm_batch + + for i in range(4): + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 2 + assert generations[1].generated_text.text == "a few " + assert ( + generations[1].request_id + == default_multi_requests_seq2seq_lm_batch.requests[1].id + ) + assert generations[1].generated_text.generated_tokens == 5 + + next_batch = next_batch.filter([next_batch.requests[0].id]) + + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text == "a few weeks" + assert ( + generations[0].request_id + == default_multi_requests_seq2seq_lm_batch.requests[0].id + ) + assert generations[0].generated_text.generated_tokens == 7 + + +def test_batch_concatenate( + default_seq2seq_lm, + default_seq2seq_lm_batch, + default_multi_requests_seq2seq_lm_batch, +): + next_batch_0 = default_seq2seq_lm_batch + _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0) + _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0) + + next_batch_1 = default_multi_requests_seq2seq_lm_batch + _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1) + + # Copy hidden state because it is removed from the concatenated branches + next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state + next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state + + # Clone past_key_values before concatenating to compare after, + # because they are removed from the concatenated batches + next_batch_0_past_key_values = [ + [t.clone() for t in layer] for layer in next_batch_0.past_key_values + ] + next_batch_1_past_key_values = [ + [t.clone() for t in layer] for layer in next_batch_1.past_key_values + ] + + next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1]) + + assert next_batch.batch_id == 0 + + assert torch.equal( + next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0] + ) + assert next_batch.all_decoder_input_ids[1][0] == 0 + assert next_batch.all_decoder_input_ids[2][0] == 0 + assert torch.equal( + next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids + ) + + assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1) + assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0) + assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0) + assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1) + + assert torch.equal( + next_batch.encoder_last_hidden_state[0], + next_batch_0_encoder_last_hidden_state[0, -2:], + ) + assert torch.equal( + next_batch.encoder_last_hidden_state[1:], + next_batch_1_encoder_last_hidden_state[:, -2:], + ) + + assert next_batch.input_lengths == [2, 2, 2] + assert next_batch.decoder_input_lengths == [3, 2, 2] + assert next_batch.max_input_length == 2 + assert next_batch.max_decoder_input_length == 3 + + assert next_batch.requests[0] == next_batch_0.requests[0] + assert next_batch.requests[1:] == list(next_batch_1.requests) + + assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] + assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers + + assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0] + assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias + + assert next_batch.past_key_values is not None + assert all( + [p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] + ) + assert all( + [p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] + ) + assert all( + [p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] + ) + assert all( + [p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values] + ) + + for i, past in enumerate(next_batch.past_key_values): + assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0]) + assert torch.equal( + next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :] + ) + + assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0]) + assert torch.equal( + next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :] + ) + + assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0]) + assert torch.equal( + next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:] + ) + + assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0]) + assert torch.equal( + next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:] + ) + + for _ in range(3): + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 3 + assert generations[2].generated_text.text == "a few " + assert ( + generations[2].request_id + == default_multi_requests_seq2seq_lm_batch.requests[1].id + ) + assert generations[2].generated_text.generated_tokens == 5 + + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) + + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert next_batch is not None + + assert len(generations) == 2 + assert generations[0].generated_text.text == "a few weeks" + assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id + assert generations[0].generated_text.generated_tokens == 7 + + next_batch = next_batch.filter([next_batch.requests[1].id]) + + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text == "a few weeks" + assert ( + generations[0].request_id + == default_multi_requests_seq2seq_lm_batch.requests[0].id + ) + assert generations[0].generated_text.generated_tokens == 7 diff --git a/backends/gaudi/server/tests/utils/test_adapter.py b/backends/gaudi/server/tests/utils/test_adapter.py new file mode 100644 index 000000000..a27c10551 --- /dev/null +++ b/backends/gaudi/server/tests/utils/test_adapter.py @@ -0,0 +1,247 @@ +import pytest +from unittest.mock import Mock +from text_generation_server.utils.adapter import ( + get_attn_weights, + get_mlp_weights, + parse_lora_adapters, + AdapterInfo, +) + + +def test_parse_lora_adapters_empty(): + assert parse_lora_adapters(None) == [] + assert parse_lora_adapters("") == [] + + +def test_parse_lora_adapters_single(): + result = parse_lora_adapters("adapter1") + assert result == [AdapterInfo(id="adapter1", path=None, revision=None)] + + +def test_parse_lora_adapters_with_path(): + result = parse_lora_adapters("adapter1=path/to/adapter1") + assert result == [ + AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None) + ] + + +def test_parse_lora_adapters_with_path_and_revision(): + result = parse_lora_adapters("adapter1=path/to/adapter1@main") + assert result == [ + AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main") + ] + + +def test_parse_lora_adapters_multiple(): + result = parse_lora_adapters( + "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev" + ) + assert result == [ + AdapterInfo(id="adapter1", path=None, revision=None), + AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None), + AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"), + ] + + +def test_parse_lora_adapters_invalid_format(): + try: + parse_lora_adapters("adapter1,invalid=format=test,adapter3") + assert False, "Should have raised ValueError" + except ValueError as e: + assert str(e) == "Invalid LoRA adapter format: invalid=format=test" + + +def test_get_attn_weights(): + # create a mock layer + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + # call the function + result = get_attn_weights(2, mock_layer) + + # assert the result + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_with_gate_up_proj(): + # create a mock layer with gate_up_proj + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # call the function + result = get_mlp_weights(3, mock_layer) + + # assert the result + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_mlp_weights_without_gate_up_proj(): + # create a mock layer without gate_up_proj + mock_layer = Mock() + mock_layer.mlp = Mock(spec=[]) + + # call the function + result = get_mlp_weights(1, mock_layer) + + # assert the result + assert result == {} + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_attn_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(layer_index, mock_layer) + + for k in ["q", "k", "v"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.self_attn.{k}_proj" + ) + + assert (layer_index, "o_proj") in result + assert ( + result[(layer_index, "o_proj")][0] + == f"model.layers.{layer_index}.self_attn.o_proj" + ) + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_mlp_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(layer_index, mock_layer) + + for k in ["gate", "up", "down"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.mlp.{k}_proj" + ) + + +def test_get_attn_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_attn_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_proj = Mock() + mock_layer.mlp.up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist. + # This is necessary because the use of `Mock` automatically creates any + # attributes that are accessed, even if they don't exist in the actual + # implementation. If `gate_up_proj` were created, `get_mlp_weights` might + # follow the wrong execution path and return an incorrect result. + del mock_layer.mlp.gate_up_proj + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected diff --git a/backends/gaudi/server/tests/utils/test_convert.py b/backends/gaudi/server/tests/utils/test_convert.py new file mode 100644 index 000000000..ba6c57023 --- /dev/null +++ b/backends/gaudi/server/tests/utils/test_convert.py @@ -0,0 +1,21 @@ +from text_generation_server.utils.hub import ( + download_weights, + weight_hub_files, + weight_files, +) + +from text_generation_server.utils.convert import convert_files + + +def test_convert_files(): + model_id = "bigscience/bloom-560m" + pt_filenames = weight_hub_files(model_id, extension=".bin") + local_pt_files = download_weights(pt_filenames, model_id) + local_st_files = [ + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files + ] + convert_files(local_pt_files, local_st_files, discard_names=[]) + + found_st_files = weight_files(model_id) + + assert all([p in found_st_files for p in local_st_files]) diff --git a/backends/gaudi/server/tests/utils/test_hub.py b/backends/gaudi/server/tests/utils/test_hub.py new file mode 100644 index 000000000..291a41b05 --- /dev/null +++ b/backends/gaudi/server/tests/utils/test_hub.py @@ -0,0 +1,103 @@ +import os +import tempfile + +import pytest + +import huggingface_hub.constants + +import text_generation_server.utils.hub +from text_generation_server.utils.hub import ( + weight_hub_files, + download_weights, + weight_files, + EntryNotFoundError, + LocalEntryNotFoundError, + RevisionNotFoundError, +) + + +@pytest.fixture() +def offline(): + current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE + text_generation_server.utils.hub.HF_HUB_OFFLINE = True + yield "offline" + text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value + + +@pytest.fixture() +def fresh_cache(): + with tempfile.TemporaryDirectory() as d: + current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d + text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d + os.environ["HUGGINGFACE_HUB_CACHE"] = d + yield + huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value + os.environ["HUGGINGFACE_HUB_CACHE"] = current_value + text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value + + +@pytest.fixture() +def prefetched(): + model_id = "bert-base-uncased" + huggingface_hub.snapshot_download( + repo_id=model_id, + revision="main", + local_files_only=False, + repo_type="model", + allow_patterns=["*.safetensors"], + ) + yield model_id + + +def test_weight_hub_files_offline_error(offline, fresh_cache): + # If the model is not prefetched then it will raise an error + with pytest.raises(EntryNotFoundError): + weight_hub_files("gpt2") + + +def test_weight_hub_files_offline_ok(prefetched, offline): + # If the model is prefetched then we should be able to get the weight files from local cache + filenames = weight_hub_files(prefetched) + root = None + assert len(filenames) == 1 + for f in filenames: + curroot, filename = os.path.split(f) + if root is None: + root = curroot + else: + assert root == curroot + assert filename == "model.safetensors" + + +def test_weight_hub_files(): + filenames = weight_hub_files("bigscience/bloom-560m") + assert filenames == ["model.safetensors"] + + +def test_weight_hub_files_llm(): + filenames = weight_hub_files("bigscience/bloom") + assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)] + + +def test_weight_hub_files_empty(): + with pytest.raises(EntryNotFoundError): + weight_hub_files("bigscience/bloom", extension=".errors") + + +def test_download_weights(): + model_id = "bigscience/bloom-560m" + filenames = weight_hub_files(model_id) + files = download_weights(filenames, model_id) + local_files = weight_files("bigscience/bloom-560m") + assert files == local_files + + +def test_weight_files_revision_error(): + with pytest.raises(RevisionNotFoundError): + weight_files("bigscience/bloom-560m", revision="error") + + +def test_weight_files_not_cached_error(fresh_cache): + with pytest.raises(LocalEntryNotFoundError): + weight_files("bert-base-uncased") diff --git a/backends/gaudi/server/tests/utils/test_layers.py b/backends/gaudi/server/tests/utils/test_layers.py new file mode 100644 index 000000000..118540eed --- /dev/null +++ b/backends/gaudi/server/tests/utils/test_layers.py @@ -0,0 +1,82 @@ +import torch +from text_generation_server.layers import ( + TensorParallelEmbedding, +) + + +class ProcessGroup: + def __init__(self, rank: int, world_size: int): + self._rank = rank + self.world_size = world_size + + def size(self) -> int: + return self.world_size + + def rank(self) -> int: + return self._rank + + +class Weights: + def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int): + self.weight = ( + torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim) + ) + self.process_group = ProcessGroup(rank, world_size) + + def get_partial_sharded(self, name: str, dim: int): + assert dim == 0 + + rank = self.process_group.rank() + world_size = self.process_group.size() + size = self.weight.shape[dim] + + block_size = (size + world_size - 1) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + return self.weight[start:stop] + + def get_shape(self, name: str): + return self.weight.shape + + +def test_weight_hub_files_offline_error(): + + vocab_size = 17 + weights = Weights( + rank=0, + world_size=1, + vocab_size=vocab_size, + hidden_dim=256, + ) + embeddings = TensorParallelEmbedding("", weights) + + input_ids = torch.arange(vocab_size) + output = embeddings.forward(input_ids) + assert embeddings.min_id == 0 + assert embeddings.max_id == 17 + torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256)) + + weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256) + weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256) + embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False) + assert embeddings_0_2.min_id == 0 + assert embeddings_0_2.max_id == 9 + torch.testing.assert_close( + embeddings_0_2.weight, + torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0) + .view(10, 256) + .float(), + ) + embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False) + assert embeddings_1_2.min_id == 9 + assert embeddings_1_2.max_id == 17 + torch.testing.assert_close( + embeddings_1_2.weight, + torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0) + .view(9, 256) + .float(), + ) + output_tp_0 = embeddings_0_2.forward(input_ids) + output_tp_1 = embeddings_1_2.forward(input_ids) + + torch.testing.assert_close(output, output_tp_0 + output_tp_1) diff --git a/backends/gaudi/server/tests/utils/test_tokens.py b/backends/gaudi/server/tests/utils/test_tokens.py new file mode 100644 index 000000000..5db327767 --- /dev/null +++ b/backends/gaudi/server/tests/utils/test_tokens.py @@ -0,0 +1,88 @@ +import torch +from text_generation_server.utils.tokens import ( + StopSequenceCriteria, + StoppingCriteria, + FinishReason, + batch_top_tokens, +) + + +def test_stop_sequence_criteria(): + criteria = StopSequenceCriteria("/test;") + + assert not criteria("/") + assert not criteria("/test") + assert criteria("/test;") + assert not criteria("/test; ") + + +def test_stop_sequence_criteria_escape(): + criteria = StopSequenceCriteria("<|stop|>") + + assert not criteria("<") + assert not criteria("<|stop") + assert criteria("<|stop|>") + assert not criteria("<|stop|> ") + + +def test_stopping_criteria(): + criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) + assert criteria(65827, "/test") == (False, None) + assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE) + + +def test_stopping_criteria_eos(): + criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) + assert criteria(1, "") == (False, None) + assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN) + + +def test_stopping_criteria_max(): + criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) + + +def test_batch_top_tokens(): + top_n_tokens = [0, 2, 3, 4, 5] + top_n_tokens_tensor = torch.tensor(top_n_tokens) + inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5) + accepted_ids = torch.ones_like(top_n_tokens_tensor) + + topn_tok_ids, topn_tok_logprobs = batch_top_tokens( + top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids + ) + + assert topn_tok_ids[0] == [[]] + assert topn_tok_ids[1] == [[0, 3]] + assert topn_tok_ids[2] == [[0, 3, 1, 4]] + assert topn_tok_ids[3] == [[0, 3, 1, 4]] + assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]] + + assert topn_tok_logprobs[0] == [[]] + assert topn_tok_logprobs[1] == [[-1, -2]] + assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]] + assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]] + assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]] + + # Now let's make second member of the batch be speculated + inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2) + accepted_ids[1] = 2 + topn_tok_ids, topn_tok_logprobs = batch_top_tokens( + top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids + ) + + assert topn_tok_ids[0] == [[]] + assert topn_tok_ids[1] == [[0, 3], [0, 3]] + assert topn_tok_ids[2] == [[0, 3, 1, 4]] + assert topn_tok_ids[3] == [[0, 3, 1, 4]] + assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]] + + assert topn_tok_logprobs[0] == [[]] + assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]] + assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]] + assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]] + assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]] diff --git a/backends/gaudi/server/tests/utils/test_watermark.py b/backends/gaudi/server/tests/utils/test_watermark.py new file mode 100644 index 000000000..5d0aaf050 --- /dev/null +++ b/backends/gaudi/server/tests/utils/test_watermark.py @@ -0,0 +1,54 @@ +# test_watermark_logits_processor.py + +import os +import numpy as np +import torch +from text_generation_server.utils.watermark import WatermarkLogitsProcessor + + +GAMMA = os.getenv("WATERMARK_GAMMA", 0.5) +DELTA = os.getenv("WATERMARK_DELTA", 2.0) + + +def test_seed_rng(): + input_ids = [101, 2036, 3731, 102, 2003, 103] + processor = WatermarkLogitsProcessor() + processor._seed_rng(input_ids) + assert isinstance(processor.rng, torch.Generator) + + +def test_get_greenlist_ids(): + input_ids = [101, 2036, 3731, 102, 2003, 103] + processor = WatermarkLogitsProcessor() + result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu")) + assert max(result) <= 10 + assert len(result) == int(10 * 0.5) + + +def test_calc_greenlist_mask(): + processor = WatermarkLogitsProcessor() + scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) + greenlist_token_ids = torch.tensor([2, 3]) + result = processor._calc_greenlist_mask(scores, greenlist_token_ids) + assert result.tolist() == [[False, False, False, False], [False, False, True, True]] + assert result.shape == scores.shape + + +def test_bias_greenlist_logits(): + processor = WatermarkLogitsProcessor() + scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) + green_tokens_mask = torch.tensor( + [[False, False, True, True], [False, False, False, True]] + ) + greenlist_bias = 2.0 + result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias) + assert np.allclose(result.tolist(), [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]]) + assert result.shape == scores.shape + + +def test_call(): + input_ids = [101, 2036, 3731, 102, 2003, 103] + processor = WatermarkLogitsProcessor() + scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) + result = processor(input_ids, scores) + assert result.shape == scores.shape diff --git a/backends/gaudi/server/tests/utils/test_weights.py b/backends/gaudi/server/tests/utils/test_weights.py new file mode 100644 index 000000000..556fcea1e --- /dev/null +++ b/backends/gaudi/server/tests/utils/test_weights.py @@ -0,0 +1,1177 @@ +import pytest +import torch +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + Weights, + WeightsLoader, +) +from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader +from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader +from text_generation_server.layers.marlin.marlin import ( + MarlinWeight, + MarlinWeightsLoader, +) +from types import SimpleNamespace +from typing import List, Optional, Dict, Union +from pathlib import Path + + +@pytest.fixture +def gptq_weights_loader(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="gptq", + quantize="gptq", + sym=True, + ) + + +@pytest.fixture +def gptq_weights_loader_awq(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="awq", + quantize="awq", + sym=True, + ) + + +@pytest.fixture +def marlin_weights_loader(): + return MarlinWeightsLoader(bits=4, is_marlin_24=False) + + +dummy_file_system = { + "test_weights": { + "layer.0.weight": torch.tensor( + [ + [1, 2], + [3, 4], + ], + dtype=torch.float32, + ), + }, + "test_weights_2": { + "layer.1337.weight": torch.tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_weights_col_packed": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_multi_weights_col": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_weights_row": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_weights_col_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + }, + "test_get_weights_row_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_multi_weights_col_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_packed_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_packed_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_weights_row_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_multi_weights_col_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_weights_col_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_weights_row_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, + "test_get_multi_weights_col_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, + "test_get_weights_col_packed_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, +} + + +class MockSlice: + def __init__(self, tensor): + self.tensor = tensor + + def get_shape(self): + return self.tensor.shape + + def __getitem__(self, idx): + return self.tensor[idx] + + +def mock_get_slice(tensor_name, filename): + tensor = dummy_file_system[filename][tensor_name] + return MockSlice(tensor) + + +def mock_handle(filename, device, dtype): + return SimpleNamespace( + get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename) + ) + + +class MockSafeOpen: + def __init__(self, filename, framework, dummy_fs): + self.filename = filename + self.framework = framework + self.dummy_fs = dummy_fs + + def keys(self): + return list(self.dummy_fs[self.filename].keys()) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class MockWeights(Weights): + def __init__( + self, + filenames: List[Union[Path, str]], + device, + dtype, + process_group, + dummy_fs, + aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None, + weights_loader: Optional[WeightsLoader] = None, + ): + routing = {} + self.dummy_fs = dummy_fs + for filename in filenames: + with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self.prefix = prefix + self.weights_loader = ( + # We don't need to get linear layers, so just wrap raw tensors. + DefaultWeightsLoader(lambda x: x) + if weights_loader is None + else weights_loader + ) + self._handles = {} + + def _get_handle(self, filename: Union[Path, str]): + if filename in self._handles: + return self._handles[filename] + else: + handle = mock_handle(filename, self.device, self.dtype) + self._handles[filename] = handle + return handle + + def get_shape(self, tensor_name: str): + filename, _ = self.get_filename(tensor_name) + handle = self._get_handle(filename) + return handle.get_slice(tensor_name).get_shape() + + def get_tensor(self, tensor_name: str): + filename, _ = self.get_filename(tensor_name) + handle = self._get_handle(filename) + return handle.get_slice(tensor_name).tensor + + +dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1) + + +def test_weights(): + weights = MockWeights( + [ + "test_weights", + "test_weights_2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + assert weights.get_shape("layer.0.weight") == (2, 2) + assert weights.get_tensor("layer.1337.weight").shape == (2, 4) + + +def test_get_tensor(): + weights = MockWeights( + [ + "test_weights", + "test_weights_2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + assert torch.allclose( + weights.get_tensor("layer.0.weight"), + torch.tensor( + [ + [1, 2], + [3, 4], + ], + dtype=torch.float32, + ), + ) + assert torch.allclose( + weights.get_tensor("layer.1337.weight"), + torch.tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed_block_size(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + block_sizes = 2 + + w = weights.get_weights_col_packed( + prefix=prefix, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed_block_size_arr(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + block_sizes = [1, 1] + + w = weights.get_weights_col_packed( + prefix=prefix, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_multi_weights_col(): + weights = MockWeights( + [ + "test_get_multi_weights_col", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight", "weight"] + + w = weights.get_multi_weights_col( + prefixes=prefixes, + dim=0, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_row(): + weights = MockWeights( + [ + "test_get_weights_row", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + + w = weights.get_weights_row( + prefix=prefix, + ) + + assert torch.allclose( + w, + torch.tensor( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + dtype=torch.float32, + ), + ) + + +# test_get_weights_col + + +def test_get_weights_col_awq(gptq_weights_loader_awq): + weights = MockWeights( + [ + "test_get_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, + ) + + prefix = "weight" + + w = weights.get_weights_col( + prefix=prefix, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor( + [[100.0, 100.0], [100.0, 100.0]], + dtype=torch.float16, + ), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_awq_kernel=True, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_gtpq(gptq_weights_loader): + weights = MockWeights( + [ + "test_get_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, + ) + + prefix = "weight" + + w = weights.get_weights_col( + prefix=prefix, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_awq_kernel=False, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_exl2(): + weights = MockWeights( + [ + "test_get_weights_col_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), + ) + + prefix = "weight" + + w = weights.get_weights_col( + prefix=prefix, + ) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_weights_col_marlin(marlin_weights_loader): + weights = MockWeights( + [ + "test_get_weights_col_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, + ) + + prefix = "weight" + + w = weights.get_weights_col( + prefix=prefix, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_weights_col_packed + + +def test_get_weights_col_packed_awq(gptq_weights_loader_awq): + weights = MockWeights( + [ + "test_get_weights_col_packed_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, + ) + + prefix = "weight" + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + block_sizes=block_sizes, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_awq_kernel=True, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +@pytest.mark.skip(reason="Review expected functionality") +def test_get_weights_col_packed_exl2(): + weights = MockWeights( + [ + "test_get_weights_col_packed_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), + ) + + prefix = "weight" + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + block_sizes=block_sizes, + ) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_weights_col_packed_gptq(gptq_weights_loader): + weights = MockWeights( + [ + "test_get_weights_col_packed_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, + ) + + prefixes = ["weight"] + + w = weights.get_multi_weights_col( + prefixes=prefixes, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_awq_kernel=False, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_packed_marlin(marlin_weights_loader): + weights = MockWeights( + [ + "test_get_weights_col_packed_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, + ) + + prefix = "weight" + + w = weights.get_multi_weights_col( + prefixes=[prefix], + dim=0, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + print(expected_weight) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_multi_weights_col + + +def test_get_multi_weights_col_awq(gptq_weights_loader_awq): + weights = MockWeights( + [ + "test_get_multi_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, + ) + + prefixes = ["weight"] + + w = weights.get_multi_weights_col( + prefixes=prefixes, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_awq_kernel=True, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_col_exl2(): + weights = MockWeights( + [ + "test_get_multi_weights_col_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), + ) + + prefix = "weight" + + try: + weights.get_multi_weights_col( + prefixes=[prefix], + dim=0, + ) + except ValueError as e: + assert e.args[0] == "get_multi_weights_col is not supported for exl2" + + +def test_get_multi_weights_col_gptq(gptq_weights_loader): + weights = MockWeights( + [ + "test_get_multi_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, + ) + + prefixes = ["weight"] + + w = weights.get_multi_weights_col( + prefixes=prefixes, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_awq_kernel=False, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_col_marlin(marlin_weights_loader): + weights = MockWeights( + [ + "test_get_multi_weights_col_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, + ) + + prefix = "weight" + + w = weights.get_multi_weights_col( + prefixes=[prefix], + dim=0, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_weights_row + + +def test_get_weights_row_awq(gptq_weights_loader_awq): + weights = MockWeights( + [ + "test_get_weights_row_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, + ) + + prefix = "weight" + + w = weights.get_weights_row( + prefix=prefix, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_awq_kernel=True, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_row_exl2(): + weights = MockWeights( + [ + "test_get_weights_row_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), + ) + + prefix = "weight" + + w = weights.get_weights_row( + prefix=prefix, + ) + print(w) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_weights_row_gptq(gptq_weights_loader): + weights = MockWeights( + [ + "test_get_weights_row_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, + ) + + prefix = "weight" + + w = weights.get_weights_row( + prefix=prefix, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_awq_kernel=False, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_row_marlin(marlin_weights_loader): + weights = MockWeights( + [ + "test_get_weights_row_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, + ) + + prefix = "weight" + + w = weights.get_weights_row( + prefix=prefix, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" diff --git a/backends/gaudi/server/text_generation_server/__init__.py b/backends/gaudi/server/text_generation_server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backends/gaudi/server/text_generation_server/adapters/__init__.py b/backends/gaudi/server/text_generation_server/adapters/__init__.py new file mode 100644 index 000000000..8697cb9ee --- /dev/null +++ b/backends/gaudi/server/text_generation_server/adapters/__init__.py @@ -0,0 +1,13 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/__init__.py +# License: Apache License Version 2.0, January 2004 + +from text_generation_server.adapters.weights import ( + AdapterBatchData, + AdapterBatchMetadata, +) + +__all__ = [ + "AdapterBatchData", + "AdapterBatchMetadata", +] diff --git a/backends/gaudi/server/text_generation_server/adapters/config.py b/backends/gaudi/server/text_generation_server/adapters/config.py new file mode 100644 index 000000000..b7e270900 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/adapters/config.py @@ -0,0 +1,30 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/config.py +# License: Apache License Version 2.0, January 2004 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Set, Tuple + +import torch + +from text_generation_server.adapters.weights import AdapterWeights + + +@dataclass +class ModuleMap: + module_name: str + module_weights: Dict[str, Tuple[torch.Tensor, str]] + + +@dataclass +class AdapterConfig(ABC): + base_model_name_or_path: str + + @abstractmethod + def map_weights_for_model( + self, + adapter_weights: Dict[int, AdapterWeights], + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + pass diff --git a/backends/gaudi/server/text_generation_server/adapters/lora.py b/backends/gaudi/server/text_generation_server/adapters/lora.py new file mode 100644 index 000000000..a00338e7c --- /dev/null +++ b/backends/gaudi/server/text_generation_server/adapters/lora.py @@ -0,0 +1,471 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/lora.py +# License: Apache License Version 2.0, January 2004 + +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple, Type, Union + +import torch +from peft import LoraConfig as _LoraConfig +from torch.distributed import ProcessGroup + +from text_generation_server.adapters.config import AdapterConfig, ModuleMap + +from text_generation_server.adapters.weights import ( + AdapterBatchMetadata, + AdapterWeights, + BatchAdapterWeights, +) +from text_generation_server.utils.sgmv import ( + BGMV_MAX_RANK, + MAX_RANK_CUSTOM, + get_tmp_tensors, + orient_for_rank, + pad_rank, + use_cutlass_shrink, +) + + +def get_start_stop_idxs_for_rank(offset, size, rank, world_size): + block_size = size // world_size + start = offset + rank * block_size + stop = offset + (rank + 1) * block_size + return start, stop + + +def shard_on_dim( + t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup +): + world_size = process_group.size() + rank = process_group.rank() + + size = t.shape[dim] + start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size) + + if dim == 0: + tensor = t[start:stop] + elif dim == 1: + tensor = t[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + + return tensor + + +def shard_lora_weights( + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + split_dim: int, + process_group: ProcessGroup, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # [hidden_size, r] + weights_a = [ + shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a + ] + + # [r, hidden_size] + weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b] + + return weights_a, weights_b + + +@dataclass +class LoraConfig(AdapterConfig): + r: int + target_modules: Optional[Union[List[str], str]] + fan_in_fan_out: bool + lora_alpha: int + use_rslora: bool + + def map_weights_for_model( + self, + adapter_weights: Dict[int, AdapterWeights], + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + adapter_weight_names = set() + module_map = {} + for weight_name in weight_names: + lora_a_name = f"base_model.model.{weight_name}.lora_A.weight" + lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" + if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: + continue + + module_map[weight_name] = { + "lora_A": (adapter_weights[lora_a_name], lora_a_name), + "lora_B": (adapter_weights[lora_b_name], lora_b_name), + } + adapter_weight_names.add(lora_a_name) + adapter_weight_names.add(lora_b_name) + return module_map, adapter_weight_names + + @classmethod + def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": + hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) + return cls( + base_model_name_or_path=hf_config.base_model_name_or_path, + r=hf_config.r, + target_modules=hf_config.target_modules, + fan_in_fan_out=hf_config.fan_in_fan_out, + lora_alpha=hf_config.lora_alpha, + use_rslora=( + hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False + ), + ) + + +class LoraWeights(AdapterWeights): + """LoRA weights for a single adapter merged across all layers.""" + + def __init__( + self, + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + adapter_config: LoraConfig, + ): + self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 + self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 + + self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._is_transposed = False + + # [num_layers, hidden_size, r] + weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + self._weights_a = torch.stack(weights_a) + + # [num_layers, r, hidden_size] + self._weights_b = torch.stack(weights_b) + + self.adapter_config = adapter_config + + @property + def weights_a(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_b + + @property + def weights_a_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_b + + def _transpose_weights(self): + if self._use_cutlass_shrink: + # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation + self._weights_a = self._weights_a.transpose(1, 2).contiguous() + self._weights_b = self._weights_b.transpose(1, 2).contiguous() + self._is_transposed = not self._is_transposed + + @classmethod + def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: + return [BatchLoraWeights] + + # prepare pre-loaded lora weights for use in the model. + # + # this method processes and organizes lora weights for a specific layer type across all layers: + # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor. + # - retrieves weights from `module_map` based on the `layer_type`. + # - processes `nlayers` number of layers. + # - converts weights to the specified `dtype`. + # - shards weights across `world_size` number of processes using the `process_group`. + # - maps weights to specific layers using `target_to_layer`. + # - tracks `unused_weight_names` to identify any unused weights. + # + # the method handles weight transposition, scaling, and padding to ensure compatibility + # with SGMV or BGMV operations. + @classmethod + def prepare_weights( + cls, + config: LoraConfig, + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + nlayers: int, + dtype: torch.dtype, + world_size: int, + process_group: ProcessGroup, + target_to_layer: Dict[str, Tuple[str, torch.Tensor]], + ) -> Optional[AdapterWeights]: + lora_a_list = [None] * nlayers + lora_b_list = [None] * nlayers + + for layer_id in range(nlayers): + key = (layer_id, layer_type) + weight_name, layer = target_to_layer[key] + base_weight = layer.base_layer.linear.weight + base_device = base_weight.device + + if weight_name not in module_map: + # There is no LoRA weight for this layer type in the adapter + return None + + lora_a, lora_a_name = module_map[weight_name]["lora_A"] + lora_a = lora_a.to(base_device, dtype) + + lora_b, lora_b_name = module_map[weight_name]["lora_B"] + lora_b = lora_b.to(base_device, dtype) + + scale = get_scaling_factor( + config.lora_alpha, + config.r, + uses_rslora=config.use_rslora, + ) + + unused_weight_names.discard(lora_a_name) + unused_weight_names.discard(lora_b_name) + + # Merge scaling factor into lora_b due to associativity of matrix multiplication: + # (A * B) * C = A * (B * C) + lora_a_list[layer_id] = lora_a.transpose(0, 1) + lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale + + # pad lora ranks to be compatible with sgmv + lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] + lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] + + if lora_a_list: + # update rank if it was padded + padded_rank = lora_a_list[0].size(1) + config.r = padded_rank + + return LoraWeights( + *shard_lora_weights( + weights_a=lora_a_list, + weights_b=lora_b_list, + split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1, + process_group=process_group, + ), + config, + ) + + +@dataclass +class RankSegments: + rank: int + + lora_a_ptr: torch.Tensor + lora_b_ptr: torch.Tensor + + # prefill (sgmv) + tmp_shrink: torch.Tensor + tmp_expand: torch.Tensor + segment_starts: torch.Tensor + segment_ends: torch.Tensor + + # decode (bgmv) + indices: torch.Tensor + + +@dataclass +class BatchLoraWeights(BatchAdapterWeights): + lora_a: Dict[int, torch.Tensor] + lora_b: Dict[int, torch.Tensor] + adapter_index_configs: Dict[int, LoraConfig] + rank_data: Dict[int, RankSegments] + use_sgmv: bool + + def has_adapter(self, adapter_index: int) -> bool: + return adapter_index in self.adapter_index_configs + + def can_vectorize(self, pg: ProcessGroup) -> bool: + return all( + rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + for rank_data in self.rank_data.values() + ) + + @classmethod + def load( + self, + adapter_weights: Dict[int, AdapterWeights], + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Optional["BatchLoraWeights"]: + adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()} + adapter_weights = { + k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights) + } + if not adapter_weights: + return None + + first_weights = next(iter(adapter_weights.values())) + device = first_weights.weights_a.device + segment_indices = meta.segment_indices + + lora_a = { + idx: adapter_weights[idx].weights_a + for idx in segment_indices + if idx in adapter_weights + } + lora_b = { + idx: adapter_weights[idx].weights_b + for idx in segment_indices + if idx in adapter_weights + } + + max_rank = max( + ( + adapter_weights[idx].lora_a_r + for idx in segment_indices + if idx in adapter_weights + ), + default=0, + ) + + if prefill or max_rank > BGMV_MAX_RANK: + use_sgmv = True + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + else: + use_sgmv = False + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a_t.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b_t.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + + adapter_index_configs = { + idx: adapter_weights[idx].adapter_config + for idx in segment_indices + if idx in adapter_weights + } + + adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} + + rank_indices = defaultdict(list) + for segment_idx, adapter_idx in enumerate(segment_indices): + if adapter_idx not in adapter_weights: + continue + rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) + + if prefill_head_indices is not None: + j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] + for head_index in prefill_head_indices: + # j cannot go out of bounds as that would mean there are tokens without corresponding adapters + if head_index < meta.adapter_segments[j]: + prefill_head_segment_ends[-1] += 1 + else: + prefill_head_segment_starts.append(prefill_head_segment_ends[-1]) + prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1) + j += 1 + + rank_data = {} + for rank, indices in rank_indices.items(): + tmp_shrink = None + tmp_expand = None + segment_starts = None + segment_ends = None + batch_indices = None + + if use_sgmv: + lora_a_ptr_indices = lora_a_ptr[indices] + tmp_shrink, tmp_expand = get_tmp_tensors( + lora_a_ptr_indices.size(0), rank, device + ) + segment_starts = meta.adapter_segments[indices] + segment_ends = meta.adapter_segments[[i + 1 for i in indices]] + if prefill_head_indices is not None: + for i, segment_index in enumerate(indices): + segment_starts[i] = prefill_head_segment_starts[segment_index] + segment_ends[i] = prefill_head_segment_ends[segment_index] + else: + rank_indices = set(indices) + batch_indices = [ + adapter_to_segment[idx] for idx in meta.adapter_indices.tolist() + ] + batch_indices = [ + idx if idx in rank_indices else -1 for idx in batch_indices + ] + batch_indices = torch.tensor( + batch_indices, dtype=torch.int64, device=device + ) + + rank_data[rank] = RankSegments( + rank=rank, + tmp_shrink=tmp_shrink, + tmp_expand=tmp_expand, + lora_a_ptr=lora_a_ptr[indices], + lora_b_ptr=lora_b_ptr[indices], + segment_starts=segment_starts, + segment_ends=segment_ends, + indices=batch_indices, + ) + + return BatchLoraWeights( + lora_a=lora_a, + lora_b=lora_b, + adapter_index_configs=adapter_index_configs, + rank_data=rank_data, + use_sgmv=use_sgmv, + ) + + +def get_scaling_factor( + lora_alpha: int, + r: int, + uses_rslora: bool = False, +) -> float: + """Computes the scaling factor for the lora weights.""" + if uses_rslora: + return lora_alpha / (r**0.5) + return lora_alpha / r + + +def _convert_lora(v: AdapterWeights) -> AdapterWeights: + if hasattr(v, "lora_weights"): + return v.lora_weights + return v diff --git a/backends/gaudi/server/text_generation_server/adapters/weights.py b/backends/gaudi/server/text_generation_server/adapters/weights.py new file mode 100644 index 000000000..da75dbcdf --- /dev/null +++ b/backends/gaudi/server/text_generation_server/adapters/weights.py @@ -0,0 +1,146 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/weights.py +# License: Apache License Version 2.0, January 2004 + +from abc import ABC, abstractclassmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Type + +import torch + + +@dataclass +class AdapterBatchMetadata: + # [batch_size] + adapter_indices: torch.Tensor + + # [num_adapters] + adapter_set: Set[int] + + # [num_segments + 1] + adapter_segments: torch.Tensor + + # [num_segments] + # maps from segment index to adapter index, i.e.: + # segment_indices[s] == adapter_indices[i] + segment_indices: List[int] + + +class AdapterWeights(ABC): + @abstractclassmethod + def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]: + pass + + @property + def speculative_tokens(self) -> int: + return 0 + + +class BatchAdapterWeights(ABC): + @abstractclassmethod + def has_adapter(self, adapter_index: int) -> bool: + pass + + @abstractclassmethod + def load( + cls, + adapter_weights: Dict[int, AdapterWeights], + meta: "AdapterBatchMetadata", + prefill: bool, + prefill_head_indices: torch.Tensor, + ) -> Optional["BatchAdapterWeights"]: + pass + + +class LayerAdapterWeights: + """Adapter weights that apply to a particular layer.""" + + def __init__(self): + self.adapter_weights: Dict[int, AdapterWeights] = {} + + def add_adapter(self, adapter_idx: int, weights: AdapterWeights): + self.adapter_weights[adapter_idx] = weights + + def remove_adapter(self, adapter_idx: int): + if adapter_idx not in self.adapter_weights: + return + del self.adapter_weights[adapter_idx] + + def is_empty(self) -> bool: + return len(self.adapter_weights) == 0 + + def get_data( + self, + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Dict[str, BatchAdapterWeights]: + # bucket adapters by batch class + adapter_batch_types: Dict[ + Type[BatchAdapterWeights], Dict[int, AdapterWeights] + ] = defaultdict(dict) + for adapter_index, adapter_weights in self.adapter_weights.items(): + for batch_type in adapter_weights.get_batch_types(): + adapter_batch_types[batch_type][adapter_index] = adapter_weights + + batch_data = {} + for batch_type, adapter_weights in adapter_batch_types.items(): + batched_weights = batch_type.load( + adapter_weights, meta, prefill, prefill_head_indices + ) + if batched_weights is not None: + batch_data = batched_weights + return batch_data + + +@dataclass +class AdapterBatchData: + meta: AdapterBatchMetadata + + # layer type -> adapter type -> batch weight data + data: Dict[str, Dict[str, BatchAdapterWeights]] + + prefill: bool + + @staticmethod + def from_meta( + meta: AdapterBatchMetadata, + weights: Dict[str, LayerAdapterWeights], + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> "AdapterBatchData": + data = {} + for k, v in weights.items(): + if v.is_empty(): + continue + data[k] = v.get_data( + meta, prefill, prefill_head_indices if k == "lm_head" else None + ) + return AdapterBatchData(meta=meta, data=data, prefill=prefill) + + def ranks(self) -> Set[int]: + # TODO(travis): refactor to be less coupled to lora implementation + ranks = set() + for lora_data in self.data.values(): + if lora_data is None: + continue + + for rank_data in lora_data.rank_data.values(): + ranks.add(rank_data.rank) + + return ranks + + def layer_names(self) -> Set[str]: + return set(self.data.keys()) + + def adapter_keys(self) -> Set[str]: + adapter_keys = set() + for layer_data in self.data.values(): + adapter_keys.update(layer_data.keys()) + return adapter_keys + + @property + def max_rank(self) -> int: + ranks = self.ranks() + return max(ranks) if len(ranks) > 0 else 0 diff --git a/backends/gaudi/server/text_generation_server/cache.py b/backends/gaudi/server/text_generation_server/cache.py new file mode 100644 index 000000000..4504733e5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/cache.py @@ -0,0 +1,34 @@ +import torch + +from typing import Dict, Optional, TypeVar + +from text_generation_server.models.types import Batch + +B = TypeVar("B", bound=Batch) + + +class Cache: + def __init__(self): + self.cache: Dict[int, B] = {} + + def pop(self, batch_id: int) -> Optional[B]: + return self.cache.pop(batch_id, None) + + def set(self, entry: B): + if entry is not None: + self.cache[entry.batch_id] = entry + + def delete(self, batch_id: int): + batch = self.pop(batch_id) + if batch is not None: + del batch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def clear(self): + keys = list(self.cache.keys()) + for k in keys: + self.delete(k) + + def __len__(self): + return len(self.cache.keys()) diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py new file mode 100644 index 000000000..f9b4caa98 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -0,0 +1,426 @@ +import os +import psutil +import signal +import sys +import typer + +from pathlib import Path +from loguru import logger +from typing import Optional +from enum import Enum +from huggingface_hub import hf_hub_download +from text_generation_server.utils.adapter import parse_lora_adapters + + +app = typer.Typer() + + +class Quantization(str, Enum): + bitsandbytes = "bitsandbytes" + bitsandbytes_nf4 = "bitsandbytes-nf4" + bitsandbytes_fp4 = "bitsandbytes-fp4" + gptq = "gptq" + awq = "awq" + eetq = "eetq" + exl2 = "exl2" + fp8 = "fp8" + marlin = "marlin" + + +class Dtype(str, Enum): + float16 = "float16" + bloat16 = "bfloat16" + + +@app.command() +def serve( + model_id: str, + revision: Optional[str] = None, + sharded: bool = False, + quantize: Optional[Quantization] = None, + speculate: Optional[int] = None, + dtype: Optional[Dtype] = None, + trust_remote_code: bool = False, + uds_path: Path = "/tmp/text-generation-server", + logger_level: str = "INFO", + json_output: bool = False, + otlp_endpoint: Optional[str] = None, + otlp_service_name: str = "text-generation-inference.server", + max_input_tokens: Optional[int] = None, +): + if sharded: + # assert ( + # os.getenv("RANK", None) is not None + # ), "RANK must be set when sharded is True" + assert ( + os.getenv("WORLD_SIZE", None) is not None + ), "WORLD_SIZE must be set when sharded is True" + assert ( + os.getenv("MASTER_ADDR", None) is not None + ), "MASTER_ADDR must be set when sharded is True" + assert ( + os.getenv("MASTER_PORT", None) is not None + ), "MASTER_PORT must be set when sharded is True" + + # Remove default handler + logger.remove() + logger.add( + sys.stdout, + format="{message}", + filter="text_generation_server", + level=logger_level, + serialize=json_output, + backtrace=True, + diagnose=False, + ) + + # Import here after the logger is added to log potential import exceptions + from text_generation_server import server + from text_generation_server.tracing import setup_tracing + + # Setup OpenTelemetry distributed tracing + if otlp_endpoint is not None: + setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) + + lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) + + # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled + # and warn the user + if lora_adapters: + logger.warning("LoRA adapters enabled (experimental feature).") + + if "CUDA_GRAPHS" in os.environ: + logger.warning( + "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs." + ) + global CUDA_GRAPHS + CUDA_GRAPHS = None + + # Downgrade enum into str for easier management later on + quantize = None if quantize is None else quantize.value + dtype = "bfloat16" if dtype is None else dtype.value + logger.info(f"quantize={quantize}") + if dtype is not None and quantize not in { + None, + "bitsandbytes", + "bitsandbytes-nf4", + "bitsandbytes-fp4", + }: + raise RuntimeError( + "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." + ) + + logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) + + if sharded: + tgi_file = Path(__file__).resolve().parent / "tgi_service.py" + num_shard = int(os.getenv("WORLD_SIZE", "1")) + logger.info("CLI SHARDED = {}".format(num_shard)) + import subprocess + + cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" + cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}" + cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}" + cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}" + if speculate is not None: + cmd += f"--speculate {speculate}" + logger.info("CLI server start deepspeed ={} ".format(cmd)) + sys.stdout.flush() + sys.stderr.flush() + with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: + do_terminate = False + current_handler = signal.getsignal(signal.SIGTERM) + def terminate_handler(sig, frame): + nonlocal do_terminate + do_terminate = True + if callable(current_handler): + current_handler(sig, frame) + + signal.signal(signal.SIGTERM, terminate_handler) + + finished = False + while not finished: + try: + if do_terminate: + parent = psutil.Process(proc.pid) + all_procs = parent.children(recursive=True) + [parent] + for p in all_procs: + try: + p.terminate() + except psutil.NoSuchProcess: + pass + _, alive = psutil.wait_procs(all_procs, timeout=30) + for p in alive: + p.kill() + + do_terminate = False + + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + pass + else: + finished = True + + sys.stdout.flush() + sys.stderr.flush() + if proc.returncode != 0: + logger.error(f"{cmd} exited with status = {proc.returncode}") + return proc.returncode + else: + server.serve( + model_id, + lora_adapters, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + max_input_tokens, + ) + + +@app.command() +def download_weights( + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", + auto_convert: bool = True, + logger_level: str = "INFO", + json_output: bool = False, + trust_remote_code: bool = False, + merge_lora: bool = False, +): + # Remove default handler + logger.remove() + logger.add( + sys.stdout, + format="{message}", + filter="text_generation_server", + level=logger_level, + serialize=json_output, + backtrace=True, + diagnose=False, + ) + + # Import here after the logger is added to log potential import exceptions + from text_generation_server import utils + + # Test if files were already download + try: + utils.weight_files(model_id, revision, extension) + logger.info("Files are already present on the host. " "Skipping download.") + return + # Local files not found + except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError): + pass + + is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( + "WEIGHTS_CACHE_OVERRIDE", None + ) is not None + + if not is_local_model: + # TODO: maybe reverse the default value of merge_lora? + # currently by default we don't merge the weights with the base model + if merge_lora: + try: + hf_hub_download( + model_id, revision=revision, filename="adapter_config.json" + ) + utils.download_and_unload_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + is_local_model = True + utils.weight_files(model_id, revision, extension) + return + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + else: + try: + utils.peft.download_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + except Exception: + pass + + try: + import json + + config = hf_hub_download( + model_id, revision=revision, filename="config.json" + ) + with open(config, "r") as f: + config = json.load(f) + + base_model_id = config.get("base_model_name_or_path", None) + if base_model_id and base_model_id != model_id: + try: + logger.info(f"Downloading parent model {base_model_id}") + download_weights( + model_id=base_model_id, + revision="main", + extension=extension, + auto_convert=auto_convert, + logger_level=logger_level, + json_output=json_output, + trust_remote_code=trust_remote_code, + ) + except Exception: + pass + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + + # Try to download weights from the hub + try: + filenames = utils.weight_hub_files(model_id, revision, extension) + utils.download_weights(filenames, model_id, revision) + # Successfully downloaded weights + return + + # No weights found on the hub with this extension + except utils.EntryNotFoundError as e: + # Check if we want to automatically convert to safetensors or if we can use .bin weights instead + if not extension == ".safetensors" or not auto_convert: + raise e + + elif (Path(model_id) / "adapter_config.json").exists(): + # Try to load as a local PEFT model + try: + utils.download_and_unload_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + utils.weight_files(model_id, revision, extension) + return + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + elif (Path(model_id) / "config.json").exists(): + # Try to load as a local Medusa model + try: + import json + + config = Path(model_id) / "config.json" + with open(config, "r") as f: + config = json.load(f) + + base_model_id = config.get("base_model_name_or_path", None) + if base_model_id: + try: + logger.info(f"Downloading parent model {base_model_id}") + download_weights( + model_id=base_model_id, + revision="main", + extension=extension, + auto_convert=auto_convert, + logger_level=logger_level, + json_output=json_output, + trust_remote_code=trust_remote_code, + ) + except Exception: + pass + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + + # Try to see if there are local pytorch weights + try: + # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE + try: + local_pt_files = utils.weight_files(model_id, revision, ".bin") + except Exception: + local_pt_files = utils.weight_files(model_id, revision, ".pt") + + # No local pytorch weights + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + if extension == ".safetensors": + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Downloading PyTorch weights." + ) + + # Try to see if there are pytorch weights on the hub + pt_filenames = utils.weight_hub_files(model_id, revision, ".bin") + # Download pytorch weights + local_pt_files = utils.download_weights(pt_filenames, model_id, revision) + + if auto_convert: + if not trust_remote_code: + logger.warning( + "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because " + "Pickle files are unsafe and can essentially contain remote code execution!" + "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety", + ) + + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Converting PyTorch weights to safetensors." + ) + + # Safetensors final filenames + local_st_files = [ + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" + for p in local_pt_files + ] + try: + import transformers + import json + + if is_local_model: + config_filename = os.path.join(model_id, "config.json") + else: + config_filename = hf_hub_download( + model_id, revision=revision, filename="config.json" + ) + with open(config_filename, "r") as f: + config = json.load(f) + architecture = config["architectures"][0] + + class_ = getattr(transformers, architecture) + + # Name for this varible depends on transformers version. + discard_names = getattr(class_, "_tied_weights_keys", []) + + except Exception: + discard_names = [] + # Convert pytorch weights to safetensors + utils.convert_files(local_pt_files, local_st_files, discard_names) + + +@app.command() +def quantize( + model_id: str, + output_dir: str, + revision: Optional[str] = None, + logger_level: str = "INFO", + json_output: bool = False, + trust_remote_code: bool = False, + upload_to_model_id: Optional[str] = None, + percdamp: float = 0.01, + act_order: bool = False, + groupsize: int = 128, +): + if revision is None: + revision = "main" + download_weights( + model_id=model_id, + revision=revision, + logger_level=logger_level, + json_output=json_output, + ) + from text_generation_server.layers.gptq.quantize import quantize + + quantize( + model_id=model_id, + bits=4, + groupsize=groupsize, + output_dir=output_dir, + revision=revision, + trust_remote_code=trust_remote_code, + upload_to_model_id=upload_to_model_id, + percdamp=percdamp, + act_order=act_order, + sym=True, + ) + + +if __name__ == "__main__": + app() diff --git a/backends/gaudi/server/text_generation_server/habana_quantization_env.py b/backends/gaudi/server/text_generation_server/habana_quantization_env.py new file mode 100644 index 000000000..e942fdcf0 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/habana_quantization_env.py @@ -0,0 +1,45 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import os +import habana_frameworks.torch as htorch + +quant_config = os.getenv("QUANT_CONFIG", "") +is_quantization_enabled = quant_config != "" + +if is_quantization_enabled: + os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true") + os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true") + os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false") + os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false") + os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av") + os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") + + +def patch_scoped_linear_all_reduce(model): + from deepspeed.module_inject.layers import LinearAllreduce + from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce + + for name, module in model.named_children(): + if type(module) is LinearAllreduce: + SL = ScopedLinearAllReduce(mod=module) + setattr(model, name, SL) + patch_scoped_linear_all_reduce(module) + + +def setup_quantization(model): + if is_quantization_enabled: + htorch.core.quantization._mark_params_as_const(model) + htorch.core.quantization._check_params_as_const(model) + htorch.core.hpu_initialize(model) + return model + + +def prepare_model_for_quantization(model): + if is_quantization_enabled: + if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]: + patch_scoped_linear_all_reduce(model) + from neural_compressor.torch.quantization import FP8Config, convert + + config = FP8Config.from_json_file(quant_config) + model = convert(model, config) + return model diff --git a/backends/gaudi/server/text_generation_server/interceptor.py b/backends/gaudi/server/text_generation_server/interceptor.py new file mode 100644 index 000000000..05339282b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/interceptor.py @@ -0,0 +1,44 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import torch +import grpc + +from google.rpc import status_pb2, code_pb2 +from grpc_status import rpc_status +from grpc_interceptor.server import AsyncServerInterceptor +from loguru import logger +from typing import Callable, Any +import traceback +import os + + +class ExceptionInterceptor(AsyncServerInterceptor): + async def intercept( + self, + method: Callable, + request_or_iterator: Any, + context: grpc.ServicerContext, + method_name: str, + ) -> Any: + try: + response = method(request_or_iterator, context) + return await response + except Exception as err: + trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else '' + method_name = method_name.split("/")[-1] + logger.exception(f"Method {method_name} encountered an error.") + + # Runtime Error cannot be recovered from + if isinstance(err, RuntimeError): + exit(1) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + from .utils.debug import dbg_trace + dbg_trace('EXCEPTION', traceback.format_exc()) + await context.abort_with_status( + rpc_status.to_status( + status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace) + ) + ) diff --git a/backends/gaudi/server/text_generation_server/layers/__init__.py b/backends/gaudi/server/text_generation_server/layers/__init__.py new file mode 100644 index 000000000..0000ca915 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/__init__.py @@ -0,0 +1,34 @@ +from text_generation_server.layers.tensor_parallel import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, +) +from text_generation_server.layers.linear import ( + get_linear, + FastLinear, +) +from text_generation_server.layers.speculative import SpeculativeHead + +# Just to add the `load` methods. +from text_generation_server.layers.layernorm import load_layer_norm +from text_generation_server.layers.conv import load_conv2d + +from text_generation_server.layers.lora import ( + LoraLinear, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) + +__all__ = [ + "get_linear", + "FastLinear", + "TensorParallelColumnLinear", + "TensorParallelRowLinear", + "TensorParallelEmbedding", + "SpeculativeHead", + "LoraLinear", + "TensorParallelMultiAdapterLinear", + "TensorParallelAdapterRowLinear", + "load_layer_norm", + "load_conv2d", +] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py new file mode 100644 index 000000000..4f2b98075 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -0,0 +1,43 @@ +from text_generation_server.utils.import_utils import SYSTEM +import os + +from .common import Seqlen + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") +if SYSTEM == "cuda": + from .cuda import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, + ) +elif SYSTEM == "rocm": + from .rocm import ( + attention, + paged_attention, + reshape_and_cache, + PREFILL_IN_KV_CACHE, + SUPPORTS_WINDOWING, + ) +elif SYSTEM == "ipex": + from .ipex import ( + attention, + paged_attention, + reshape_and_cache, + PREFILL_IN_KV_CACHE, + SUPPORTS_WINDOWING, + ) +else: + raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") + + +__all__ = [ + "attention", + "paged_attention", + "reshape_and_cache", + "PREFILL_IN_KV_CACHE", + "SUPPORTS_WINDOWING", + "Seqlen", +] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py new file mode 100644 index 000000000..d6e512c01 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import ATTENTION +import torch +from typing import Optional + + +if ATTENTION in {"flashinfer", "flashdecoding"}: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + prefix_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + max_q: int + max_k: int + + def __init__( + self, + input_lengths, + prefix_lengths, + cu_seqlen_q=None, + max_q=None, + max_k=None, + ): + self.input_lengths = input_lengths + self.prefix_lengths = prefix_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + max_q = 1 + else: + assert max_q is not None + assert max_k is not None + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + total = self.input_lengths + self.prefix_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + self.max_q = max_q + self.max_k = max_k + + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self + +else: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + prefix_lengths: torch.Tensor + cu_seqlen_q: torch.Tensor + max_q: int + max_k: int + + def clamp(self, max): + if SYSTEM == "rocm": + return self + raise NotImplementedError("Not implemented seqlen for paged") + return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/cuda.py b/backends/gaudi/server/text_generation_server/layers/attention/cuda.py new file mode 100644 index 000000000..51af928d5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/cuda.py @@ -0,0 +1,357 @@ +import torch +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import ( + ATTENTION, + BLOCK_SIZE, +) +from text_generation_server.layers.attention import Seqlen +from typing import Optional + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +_PARTITION_SIZE = 512 + +try: + from vllm._C import cache_ops +except Exception as e: + raise ImportError( + f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + if ATTENTION in {"flashdecoding", "flashinfer"}: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + seqlen: Seqlen, + max_s: int, + softcap: Optional[float] = None, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + + # value_cache => [num_blocks, num_heads, head_size, block_size] + # block_size = value_cache.shape[3] + block_size = BLOCK_SIZE + num_seqs, num_heads, head_size = query.shape + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import decode_state + + return decode_state.get().forward( + query.contiguous(), + paged_kv_cache=(key_cache, value_cache), + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + elif ATTENTION == "flashdecoding": + max_q = 1 + max_k = max_s + import flash_attn_2_cuda + + # TODO fixme when flash contains the fix. + # Number of splits is not correctly handled + # by the current path + # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 + # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. + if softcap is None: + softcap = 0.0 + out = flash_attn_2_cuda.varlen_fwd( + query, + key_cache, + value_cache, + None, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, # pad_k + None, + block_tables, + None, + max_q, + max_k, + 0.0, # dropout + softmax_scale, + False, # zero_tensors + True, # causal + -1, # Window_left + -1, # Window right + softcap, + False, # return softmax + None, # generator + ) + return out[0] + else: + if softcap is not None: + raise RuntimeError("Paged attention doesn't support softcapping") + input_lengths = seqlen.input_lengths + from vllm._C import ops + + out = torch.empty_like(query) + + use_v1 = max_s <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 + ) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + return out + + +try: + is_ampere_or_newer = major >= 8 and minor >= 0 + if not is_ampere_or_newer: + raise ImportError("FlashAttention only supports Ampere GPUs or newer.") + + import flash_attn_2_cuda + + V2 = True +except ImportError: + try: + import flash_attn_cuda + + V2 = False + except ImportError as e: + if major >= 8: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + elif is_sm75: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + else: + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + + +SUPPORTS_WINDOWING = V2 + +if ATTENTION == "flashinfer": + + def attention( + q: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale, + window_size_left=-1, + causal=True, + softcap=0.0, + ): + from text_generation_server.layers.attention.flashinfer import ( + prefill_with_paged_kv_state, + ) + + return prefill_with_paged_kv_state.get().forward( + q.contiguous(), + causal=causal, + paged_kv_cache=(key_cache, value_cache), + logits_soft_cap=softcap, + sm_scale=softmax_scale, + window_left=window_size_left, + ) + +elif V2: + + def attention( + q, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale, + window_size_left=-1, + causal=True, + softcap=0.0, + ): + out = torch.empty_like(q) + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + return flash_attn_2_cuda.varlen_fwd( + q, + key_cache, + value_cache, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, + None, + block_tables, + None, + seqlen.max_q, + seqlen.max_k, + 0.0, + softmax_scale, + False, + causal, + window_size_left, + 0, + softcap, + False, + None, + )[0] + +else: + + def attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap=None, + ): + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash attn v2" + ) + if softcap is not None: + raise NotImplementedError("softcap is only available with flash attn v2") + + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + out = torch.empty_like(q) + flash_attn_cuda.fwd( + q, + k, + v, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, + 0.0, + softmax_scale, + False, + causal, + False, + 0, + None, + ) + return out + + +# Prefill in the cache with every kind of attention, unless we +# have a configuration that requires flash-attention v1, which +# does not support block tables. +PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py b/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py new file mode 100644 index 000000000..3a6f9a730 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets( + philox_seed, philox_offset, dropout_p, m, n, stride + ).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn( + bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" + ) + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = ( + batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + + start_n + - BLOCK_N + ) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, BLOCK_N) + ) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N + ) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = ( + off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + ) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k + ) + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, n_full_blocks) + ) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full( + (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 + ) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + + def grid(META): + return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py b/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py new file mode 100644 index 000000000..d603c6f5f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py @@ -0,0 +1,251 @@ +from typing import Optional +from contextvars import ContextVar +from contextlib import contextmanager + +import flashinfer +import torch + +prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( + "prefill_state" +) + +prefill_with_paged_kv_state: ContextVar[ + flashinfer.BatchPrefillWithPagedKVCacheWrapper +] = ContextVar("prefill_with_paged_kv_state") + +decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( + "decode_state" +) + +workspace: Optional[torch.Tensor] = None + + +def get_workspace(device): + """Get shared flashinfer workspace.""" + global workspace + if workspace is None: + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + return workspace + + +def create_prefill_with_paged_kv_state( + *, + device: torch.device, +): + """Create a prefill state that uses the KV cache.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_with_paged_kv_state( + *, + state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, + block_tables: torch.Tensor, + cu_seqlens: torch.Tensor, + input_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + dtype: torch.dtype, + window_left: int, +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + if page_size == 1: + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + else: + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = prefill_with_paged_kv_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + paged_kv_indptr=indptr, + paged_kv_indices=block_tables, + paged_kv_last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=dtype, + page_size=page_size, + window_left=window_left, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_with_paged_kv_state.reset(token) + + +def create_prefill_state( + *, + device: torch.device, +): + """Create a prefill state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_state( + *, + state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, + cu_seqlens: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + window_left: int, +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + token = prefill_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + kv_indptr=cu_seqlens, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=dtype, + window_left=window_left, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_state.reset(token) + + +def create_decode_state( + *, + device: torch.device, + num_heads: int, + num_kv_heads: int, +): + """Create a decode state.""" + workspace_buffer = get_workspace(device) + num_groups = num_heads // num_kv_heads + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=False, + # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 + use_tensor_cores=num_groups not in [1, 2, 4, 8], + ) + + +def create_decode_state_cuda_graphs( + *, + device: torch.device, + block_tables: torch.Tensor, + block_tables_ptr: torch.Tensor, + last_page_len: torch.Tensor, + num_heads: int, + num_kv_heads: int, +): + """ + Create a decode state for use with CUDA Graphs. `block_tables`, + `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are + therefore stored as part of the state. + """ + workspace_buffer = get_workspace(device) + num_groups = num_heads // num_kv_heads + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=True, + paged_kv_indices_buffer=block_tables, + paged_kv_indptr_buffer=block_tables_ptr, + paged_kv_last_page_len_buffer=last_page_len, + # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 + use_tensor_cores=num_groups not in [1, 2, 4, 8], + ) + + +@contextmanager +def use_decode_state( + *, + state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, + input_lengths: torch.Tensor, + block_tables: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + dtype: torch.dtype, + window_left: int, +): + """ + Context manager to set the active flashinfer decoding state to the given + `state` and parameters. This state will be used by all calls to the + `paged_attention` function while the context manager is active. + """ + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = decode_state.set(state) + + try: + state.begin_forward( + indptr=indptr, + indices=block_tables, + last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + page_size=page_size, + data_type=dtype, + q_data_type=dtype, + window_left=window_left, + ) + yield + finally: + state.end_forward() + if token is not None: + decode_state.reset(token) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/ipex.py b/backends/gaudi/server/text_generation_server/layers/attention/ipex.py new file mode 100644 index 000000000..657c90af4 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/ipex.py @@ -0,0 +1,82 @@ +import intel_extension_for_pytorch as ipex +import torch +from text_generation_server.models.flash_causal_lm import BLOCK_SIZE +from text_generation_server.layers.attention import Seqlen +from typing import Optional + +SUPPORTS_WINDOWING = False +PREFILL_IN_KV_CACHE = False + + +def attention( + q: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale, + window_size_left=-1, + causal=True, + softcap: Optional[float] = None, +): + out = torch.empty_like(q) + + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. + ipex.llm.functional.varlen_attention( + q.contiguous() if q.device.type == "xpu" else q, + key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, + value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_q, + 0.0, + softmax_scale, + False, + causal, + False, + None, + ) + + return out + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slots + ) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + seqlen: Seqlen, + max_s: int, + softcap: Optional[float] = None, +): + out = torch.empty_like(query) + ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + seqlen.input_lengths, + BLOCK_SIZE, + max_s, + None, + ) + return out diff --git a/backends/gaudi/server/text_generation_server/layers/attention/rocm.py b/backends/gaudi/server/text_generation_server/layers/attention/rocm.py new file mode 100644 index 000000000..646a763d3 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/rocm.py @@ -0,0 +1,308 @@ +import os +from typing import Optional +import torch +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import ATTENTION +from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils.log import log_master +from loguru import logger + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 + +_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_CUSTOM = 256 + +use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} +ENGINE = "triton" if use_triton else "ck" + +PREFILL_IN_KV_CACHE = False + +use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" +try: + if use_rocm_custom_paged_attn: + from vllm._custom_C import paged_attention_custom +except ImportError as e: + log_master( + logger.info, + f"Custom Paged Attention not available. Complete error: {e}", + ) + use_rocm_custom_paged_attn = False + +try: + import vllm._custom_ops as ops +except Exception as e: + raise ImportError( + f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + if ATTENTION == "flashdecoding": + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + seqlen: Seqlen, + max_s: int, + softcap: Optional[float] = None, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + + if softcap is not None: + raise RuntimeError("Paged attention doesn't support softcapping") + + # value_cache => [num_blocks, num_heads, head_size, block_size] + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + + num_kv_heads = key_cache.shape[1] + gqa_ratio = num_heads // num_kv_heads + use_custom = ( + use_rocm_custom_paged_attn + and (query.dtype == torch.half or query.dtype == torch.bfloat16) + and (head_size == 128 or head_size == 64) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_s <= 32768 + ) + + if not use_custom: + _PARTITION_SIZE = _PARTITION_SIZE_V1V2 + else: + _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM + + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + input_lengths = seqlen.input_lengths + + out = torch.empty_like(query) + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + import vllm._custom_ops as ops + + use_v1 = ( + max_s <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512) + and not use_custom + ) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + if not use_custom: + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + paged_attention_custom( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + ) + + return out + + +if ENGINE != "triton": + try: + import flash_attn_2_cuda + + log_master( + logger.info, + "ROCm: using Flash Attention 2 Composable Kernel implementation.", + ) + except ImportError as e: + if major >= 8: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + elif is_sm75: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + else: + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e + + +SUPPORTS_WINDOWING = False +if ENGINE == "ck": + + def attention( + q, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: float = 0.0, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + + out = torch.empty_like(q) + + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. + return flash_attn_2_cuda.varlen_fwd( + q, + key_cache, + value_cache, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + None, + None, + None, + None, + seqlen.max_q, + seqlen.max_k, + 0.0, + softmax_scale, + False, + causal, + window_size_left, + 0, + softcap, + False, + None, + )[0] + +elif ENGINE == "triton": + from .flash_attn_triton import triton_attention + + def attention( + q, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: Optional[float] = None, + ): + if softcap is not None: + raise NotImplementedError("softcap is only available with CK flash attn") + + out = torch.empty_like(q) + + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. + output, _ = triton_attention( + q, + key_cache, + value_cache, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, + causal, + softmax_scale, + ) + return output + +else: + raise RuntimeError(f"Unknown attention engine {ENGINE}") diff --git a/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py b/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py new file mode 100644 index 000000000..b19eafbbe --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py @@ -0,0 +1,97 @@ +import torch +from typing import List + + +AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] +REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def pack(imatrix: torch.Tensor, direction: str = "column"): + """ + Packs a 4-bit integer matrix into a packed 32-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of packing, either "column" or "row" + Returns: + qmatrix (torch.Tensor): packed matrix of integers + """ + shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + if direction == "column": + imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1) + + elif direction == "row": + imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1) + + qmatrix = qmatrix.to(torch.int32) + + return qmatrix + + +def unpack(qmatrix: torch.Tensor, direction: str = "column"): + """ + Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix. + Args: + qmatrix (torch.Tensor): matrix of packed integers + direction (str): direction of unpacking, either "column" or "row" + Returns: + imatrix (torch.Tensor): matrix of integers + """ + shifts = torch.arange(0, 32, 4, device=qmatrix.device) + + if direction == "column": + imatrix = torch.bitwise_right_shift( + qmatrix[:, :, None], shifts[None, None, :] + ).view(qmatrix.shape[0], -1) + + elif direction == "row": + imatrix = torch.bitwise_right_shift( + qmatrix[:, None, :], shifts[None, :, None] + ).view(-1, qmatrix.shape[-1]) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + return imatrix + + +def apply_order( + imatrix: torch.Tensor, + direction: str = "column", + order: List[int] = AWQ_PACK_ORDER, +): + """ + Applies the order to a 4-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of applying order, either "column" or "row" + order (List[int]): order to apply, default is AWQ_PACK_ORDER + Returns: + imatrix (torch.Tensor): matrix of integers + """ + if direction == "column": + imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape) + elif direction == "row": + imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape) + + return imatrix + + +def fast_awq_to_gptq(qweight, qzeros): + # awq uses column packing for both weights and zeros + izeros = unpack(qzeros, direction="column") + iweights = unpack(qweight, direction="column") + + # Reverse the order of the iweight and izeros tensors + izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) + iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) + # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros) + izeros = izeros - 1 + # exllama uses row packing for weights and column packing for zeros + qzeros = pack(izeros, direction="column") + qweight = pack(iweights, direction="row") + + return qweight, qzeros diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py new file mode 100644 index 000000000..391371a55 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py @@ -0,0 +1,49 @@ +# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py + +from typing import Optional +import torch +import torch.nn as nn +import awq_inference_engine # with CUDA kernels + + +# class ScaledActivation(nn.Module): +# def __init__(self, module, scales): +# super().__init__() +# self.act = module +# self.scales = nn.Parameter(scales.data) +# +# def forward(self, x): +# return self.act(x) / self.scales.view(1, 1, -1).to(x.device) + + +class WQLinear(nn.Module): + def __init__( + self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor] + ): + super().__init__() + + if w_bit not in [4]: + raise NotImplementedError("Only 4-bit are supported for now.") + + self.in_features = qweight.shape[0] + self.out_features = qweight.shape[1] * 32 // w_bit + + self.w_bit = w_bit + self.group_size = group_size if group_size != -1 else self.in_features + # quick sanity check (make sure aligment) + assert self.in_features % self.group_size == 0 + assert self.out_features % (32 // self.w_bit) == 0 + + self.qweight = qweight + self.qzeros = qzeros + self.scales = scales + self.bias = bias + + @torch.no_grad() + def forward(self, x): + out_shape = x.shape[:-1] + (self.out_features,) + out = awq_inference_engine.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 + ) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/bnb.py b/backends/gaudi/server/text_generation_server/layers/bnb.py new file mode 100644 index 000000000..791d9b6d8 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/bnb.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass + +import bitsandbytes as bnb +import torch +from bitsandbytes.nn import Int8Params, Params4bit +from text_generation_server.utils.weights import UnquantizedWeight + + +@dataclass +class BNBWeight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0) + + +class Linear8bitLt(torch.nn.Module): + def __init__( + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__() + assert ( + not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) + self.weight.cuda(weight.device) + self.bias = bias + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out + + +@dataclass +class BNBFP4Weight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="fp4") + + +@dataclass +class BNBNF4Weight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="nf4") + + +class Linear4bit(torch.nn.Module): + def __init__(self, weight, bias, quant_type): + super().__init__() + self.weight = Params4bit( + weight.data, + requires_grad=False, + compress_statistics=True, + quant_type=quant_type, + ) + self.compute_dtype = None + self.weight.cuda(weight.device) + self.bias = bias + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, "quant_state", None) is None: + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." + ) + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) + + out = out.to(inp_dtype) + + return out diff --git a/backends/gaudi/server/text_generation_server/layers/conv.py b/backends/gaudi/server/text_generation_server/layers/conv.py new file mode 100644 index 000000000..7fb18ab3f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/conv.py @@ -0,0 +1,41 @@ +from accelerate import init_empty_weights +import torch + + +@classmethod +def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = torch.nn.Parameter(weight) + conv2d.bias = torch.nn.Parameter(bias) + return conv2d + + +@classmethod +def load_conv2d_no_bias( + cls, prefix, weights, in_channels, out_channels, kernel_size, stride +): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = torch.nn.Parameter(weight) + conv2d.bias = None + return conv2d + + +torch.nn.Conv2d.load = load_conv2d +torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias diff --git a/backends/gaudi/server/text_generation_server/layers/eetq.py b/backends/gaudi/server/text_generation_server/layers/eetq.py new file mode 100644 index 000000000..b1e5235a0 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/eetq.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass + +import torch +from EETQ import quant_weights, w8_a16_gemm +from text_generation_server.utils.weights import UnquantizedWeight + + +@dataclass +class EETQWeight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + try: + from text_generation_server.layers.eetq import EETQLinear + + return EETQLinear(self.weight, bias) + except ImportError: + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) + + +class EETQLinear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + device = weight.device + if weight.dtype != torch.float16: + weight = weight.to(dtype=torch.float16) + weight = torch.t(weight).contiguous().cpu() + weight, scale = quant_weights(weight, torch.int8, False) + + self.weight = weight.cuda(device) + self.scale = scale.cuda(device) + self.bias = bias.cuda(device) if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = w8_a16_gemm(input, self.weight, self.scale) + output = output + self.bias if self.bias is not None else output + return output diff --git a/backends/gaudi/server/text_generation_server/layers/exl2.py b/backends/gaudi/server/text_generation_server/layers/exl2.py new file mode 100644 index 000000000..a6e07f453 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/exl2.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass +from typing import List, Union + +import torch +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + + +@dataclass +class Exl2Weight(Weight): + """ + Exllama2 exl2 quantized weights. + """ + + q_weight: torch.Tensor + q_scale: torch.Tensor + q_invperm: torch.Tensor + q_scale_max: torch.Tensor + q_groups: torch.Tensor + + def __post_init__(self): + self.q_scale_max /= 256 + self.q_invperm = self.q_invperm.short() + + @property + def device(self) -> torch.device: + return self.q_weight.device + + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.gptq import ExllamaQuantLinear + + return ExllamaQuantLinear(self, bias) + + +class Exl2WeightsLoader(WeightsLoader): + """Loader for exl2-quantized weights.""" + + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + raise RuntimeError("Column-packed weights are not supported for exl") + + def get_weights_col(self, weights: Weights, prefix: str): + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + raise ValueError("get_multi_weights_col is not supported for exl2") + + def get_weights_row(self, weights: Weights, prefix: str): + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py new file mode 100644 index 000000000..61dd51151 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -0,0 +1,285 @@ +import torch + +from dataclasses import dataclass +from typing import Optional, Union, List +from loguru import logger + +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import ( + Weight, + WeightsLoader, + UnquantizedWeight, + Weights, +) +from text_generation_server.utils.log import log_master, log_once +import importlib.util + + +FBGEMM_MM_AVAILABLE = False +FBGEMM_DYN_AVAILABLE = False + + +def is_fbgemm_gpu_available(): + try: + return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None + except ModuleNotFoundError: + return False + + +if is_fbgemm_gpu_available(): + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + FBGEMM_MM_AVAILABLE = major == 9 + FBGEMM_DYN_AVAILABLE = major >= 8 +else: + log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") + + +def get_fp8_linear() -> torch.nn.Module: + """ + Return an FP8 linear `Module` that is compatible with the current system. + """ + + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + if major == 8: + from text_generation_server.layers.marlin import GPTQMarlinFP8Linear + + return GPTQMarlinFP8Linear + + # On other systems let Torch decide if the hardware supports FP8. + return Fp8Linear + + +def fp8_quantize( + weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False +): + if FBGEMM_DYN_AVAILABLE and not scalar: + qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( + weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype + ) + return qweight, scale + + # weight, scale = quant_weights(weight, torch.int8, False) + finfo = torch.finfo(qdtype) + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(qdtype) + scale = scale.float().reciprocal() + return qweight, scale + + +class HybridFP8UnquantLoader(WeightsLoader): + """Weight loader that loads FP8 and unquantized Torch tensors.""" + + def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + self.activation_scale_ub = activation_scale_ub + self.to_fp8 = to_fp8 + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if scale.numel() > 1: + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + scale = scale.reshape(-1).expand(w.shape[0]) + + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + scale = torch.cat(scale, dim=0).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + +@dataclass +class Fp8Weight(Weight): + weight: torch.Tensor + dtype: torch.dtype + weight_scale: Optional[torch.Tensor] = None + activation_scale_ub: Optional[float] = None + + def get_linear(self, bias: torch.Tensor): + if self.weight_scale is None: + return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + # This is not checked by the fbgemm kernels, but they require contiguous + # memory. Can be non-contiguous when we e.g. expand from scalars. + self.weight_scale = self.weight_scale.contiguous() + return get_fp8_linear().from_fp8( + self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + ) + + +class Fp8Linear(torch.nn.Module): + def __init__( + self, + qweight, + scale, + scale_upper_bound, + bias, + dtype, + ) -> None: + super().__init__() + if FBGEMM_MM_AVAILABLE: + log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + + self.dtype = dtype + self.qweight = qweight + self.scale = scale + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) + + self.bias = bias if bias is not None else None + + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) + return cls( + qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + ) + + @classmethod + def from_fp8(cls, weight, scale, input_scale, bias, dtype): + if FBGEMM_DYN_AVAILABLE: + # fbgemm needs float32 scales. + scale = scale.float() + return cls( + qweight=weight, + scale=scale, + scale_upper_bound=input_scale, + bias=bias, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if FBGEMM_MM_AVAILABLE: + qinput, scale = fp8_quantize( + input, scale_upper_bound=self.scale_upper_bound + ) + + y = torch.ops.fbgemm.f8f8bf16_rowwise( + qinput, + self.qweight, + scale, + self.scale, + use_fast_accum=True, + bias=self.bias, + ) + return y.to(self.dtype) + + qinput, scale = fp8_quantize(input, scalar=True) + output, _ = torch._scaled_mm( + qinput, + self.qweight.t(), + out_dtype=self.dtype, + scale_a=scale, + scale_b=self.scale, + bias=self.bias, + ) + return output + + +def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): + scale = weights.get_tensor(prefix, to_dtype=False) + if scale.numel() > 1: + scale = weights.get_sharded(prefix, dim=0, to_dtype=False) + return scale.reshape(-1).expand(shape[0]) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py new file mode 100644 index 000000000..505caa59a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -0,0 +1,440 @@ +import os +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +from loguru import logger +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + + +@dataclass +class GPTQWeight(Weight): + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: Optional[torch.Tensor] + bits: int + groupsize: int + use_awq_kernel: bool + use_exllama: bool + + def __post_init__(self): + if self.scales.dtype == torch.float: + self.scales = self.scales.half() + + @property + def device(self) -> torch.device: + return self.qweight.device + + def get_linear(self, bias: torch.Tensor): + if self.use_awq_kernel: + if SYSTEM == "rocm": + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." + ) + try: + from text_generation_server.layers.awq.quantize.qmodule import WQLinear + + return WQLinear( + w_bit=self.bits, + group_size=self.groupsize, + qweight=self.qweight, + qzeros=self.qzeros, + scales=self.scales, + bias=bias, + ) + except ImportError: + raise NotImplementedError( + "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" + ) + elif self.use_exllama: + try: + from text_generation_server.layers.gptq import ExllamaQuantLinear + except ImportError: + raise NotImplementedError( + "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + return ExllamaQuantLinear(self, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + return QuantLinear( + self.qweight, + self.qzeros, + self.scales, + self.g_idx, + bias, + self.bits, + self.groupsize, + ) + + +class GPTQWeightsLoader(WeightsLoader): + """ + Loader for GPTQ- and AWQ-quantized weights. + """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights(self, weights: Weights, prefix: str): + self._get_gptq_params(weights) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + self._get_gptq_params(weights) + + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + self._get_gptq_params(weights) + + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + + from text_generation_server.layers.gptq import HAS_EXLLAMA + + use_exllama = ( + self.bits == 4 + and HAS_EXLLAMA + and self.quantize == "gptq" + and not self.desc_act + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + self._get_gptq_params(weights) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + else: + g_idx = None + + if weights.process_group.size() > 1: + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // self.groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + from text_generation_server.layers.gptq import ( + CAN_EXLLAMA, + HAS_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + if use_exllama and self.groupsize != -1: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + + def _get_gptq_params(self, weights: Weights): + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) + self.quant_method = "gptq" + + +# Needs to be at the end because circular import. +try: + major, _minor = torch.cuda.get_device_capability() +except Exception: + major = 1 + +HAS_EXLLAMA = False +CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" +V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA = False +elif CAN_EXLLAMA: + try: + if V2: + from text_generation_server.layers.gptq.exllamav2 import ( + QuantLinear as ExllamaQuantLinear, # noqa: F401 + create_exllama_buffers, # noqa: F401 + set_device, # noqa: F401 + ) + + HAS_EXLLAMA = "2" + else: + from text_generation_server.layers.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 + create_exllama_buffers, # noqa: F401 + set_device, # noqa: F401 + ) + + HAS_EXLLAMA = "1" + + except ImportError: + pass diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py b/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py new file mode 100644 index 000000000..0388ef20b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py @@ -0,0 +1,261 @@ +# https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + +import builtins +import math +import time +from typing import Dict + +import triton + + +class Autotuner(triton.KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + """ + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = ( + prune_configs_by["perf_model"], + prune_configs_by["top_k"], + ) + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current, + ) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench( + kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40 + ) + except triton.OutOfResources: + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune( + configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + prune_configs_by, + nearest_power_of_two, + ) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. + """ + m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) + n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) + k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) + block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) + block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) + group_size_m = config.kwargs["GROUP_SIZE_M"] + + if ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) in used: + continue + + used.add( + ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) + ) + yield triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py new file mode 100644 index 000000000..f27666b77 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py @@ -0,0 +1,134 @@ +from text_generation_server.layers.gptq import GPTQWeight +import torch +from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +none_tensor = torch.empty((1, 1), device="meta") + + +def ext_make_q4(qweight, qzeros, scales, g_idx, device): + """Construct Q4Matrix, return handle""" + return make_q4( + qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device + ) + + +def ext_q4_matmul(x, q4, q4_width): + """Matrix multiplication, returns x @ q4""" + outshape = x.shape[:-1] + (q4_width,) + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device) + + q4_matmul(x, q4, output) + + return output.view(outshape) + + +MAX_DQ = 1 +MAX_INNER = 1 +ACT_ORDER = False +DEVICE = None + +TEMP_STATE = None +TEMP_DQ = None + + +def set_device(device): + global DEVICE + DEVICE = device + + +def create_exllama_buffers(max_total_tokens: int): + global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ + + assert DEVICE is not None, "call set_device first" + + if not ACT_ORDER: + max_total_tokens = 1 + + # This temp_state buffer is required to reorder X in the act-order case. + temp_state = torch.zeros( + (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE + ) + temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE) + + # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + prepare_buffers(DEVICE, temp_state, temp_dq) + + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + TEMP_STATE, TEMP_DQ = temp_state, temp_dq + + +class Ex4bitLinear(torch.nn.Module): + """Linear layer implementation with per-group 4-bit quantization of the weights""" + + def __init__(self, weight: GPTQWeight, bias): + super().__init__() + global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE + assert weight.bits == 4 + + self.device = weight.qweight.device + self.qweight = weight.qweight + self.qzeros = weight.qzeros + self.scales = weight.scales + self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None + self.bias = bias if bias is not None else None + + if self.g_idx is not None and ( + (self.g_idx == 0).all() + or torch.equal( + weight.g_idx.cpu(), + torch.tensor( + [i // weight.groupsize for i in range(weight.g_idx.shape[0])], + dtype=torch.int32, + ), + ) + ): + self.empty_g_idx = True + self.g_idx = None + + assert self.device.type == "cuda" + assert self.device.index is not None + + self.q4 = ext_make_q4( + self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index + ) + + self.height = weight.qweight.shape[0] * 8 + self.width = weight.qweight.shape[1] + + # Infer groupsize from height of qzeros + self.groupsize = None + if self.qzeros.shape[0] > 1: + self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) + + if self.groupsize is not None: + assert weight.groupsize == self.groupsize + + # Handle act-order matrix + if self.g_idx is not None: + if self.groupsize is None: + raise ValueError("Found group index but no groupsize. What do?") + self.act_order = True + else: + self.act_order = False + + DEVICE = self.qweight.device + + MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8) + + if self.act_order: + MAX_INNER = max(MAX_INNER, self.height, self.width) + + ACT_ORDER = True + + def forward(self, x): + out = ext_q4_matmul(x, self.q4, self.width) + + if self.bias is not None: + out.add_(self.bias) + return out diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py new file mode 100644 index 000000000..920a6adf4 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py @@ -0,0 +1,267 @@ +# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 + +from dataclasses import dataclass +from typing import Optional +import torch +import torch.nn as nn + +from loguru import logger + +from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.gptq import GPTQWeight +from text_generation_server.utils.log import log_master + +try: + from exllamav2.ext import exllamav2_ext + + make_q_matrix = exllamav2_ext.make_q_matrix + gemm_half_q_half = exllamav2_ext.gemm_half_q_half +except ImportError: + log_master(logger.warning, "exllamav2_kernels not installed.") + raise + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +none_tensor = torch.empty((1, 1), device="meta") + + +@dataclass +class _ExtraTensors: + """Additional generated quantizer tensors.""" + + q_group_map: Optional[torch.Tensor] = None + q_invperm: Optional[torch.Tensor] = None + q_perm: Optional[torch.Tensor] = None + + +def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): + """Matrix multiplication, returns x @ q4""" + output_shape = x.shape[:-1] + (q4_width,) + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) + gemm_half_q_half(x, q_handle, output, force_cuda) + return output.view(output_shape) + + +def make_group_map(q_groups: torch.Tensor, num_qrows: int): + gr = q_groups.tolist() + group_map = [] + num_groups = len(gr) // 2 + + for i in range(num_groups): + bits = gr[i * 2] + if i < num_groups - 1: + qrows = gr[i * 2 + 3] - gr[i * 2 + 1] + else: + qrows = num_qrows - gr[i * 2 + 1] + rows = qrows * 32 // bits + for j in range(rows): + group_map += [i] + group_map += [rows - j] + + return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) + + +# Create Q matrix + + +def ext_make_q_matrix( + w: Exl2Weight | GPTQWeight, + extra: _ExtraTensors, + temp_dq, + key: Optional[str] = None, +): + """ + Create Q matrix + """ + # max_dq_size = 512*(1024**2) + # max_dq_rows = max_dq_size // out_features[0] + max_dq_rows = 0 + + # EXL2 + if isinstance(w, Exl2Weight): + extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) + extra.q_perm = torch.argsort(w.q_invperm).short() + + return make_q_matrix( + w.q_weight, + extra.q_perm, + w.q_invperm, + w.q_scale, + w.q_scale_max, + w.q_groups, + extra.q_group_map, + none_tensor, # zeros + none_tensor, # scales + none_tensor, # g_idx + none_tensor, # bias + temp_dq, + max_dq_rows, + ) + # GPTQ + elif isinstance(w, GPTQWeight): + if w.scales.dtype == torch.float: + w.scales = w.scales.half() + + # GPTQ with g_idx (act_order) + if w.g_idx is not None and not (w.g_idx == 0).all().item(): + extra.q_perm = torch.empty( + (w.qweight.shape[0] * 8,), + dtype=torch.short, + device=w.qweight.device, + ) + extra.q_invperm = torch.empty_like(extra.q_perm) + # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. + return make_q_matrix( + w.qweight, + extra.q_perm, + extra.q_invperm, + none_tensor, # q_scale + none_tensor, # q_scale_max + none_tensor, # q_groups + none_tensor, # q_group_map + w.qzeros, + w.scales, + w.g_idx.cpu(), + none_tensor, # bias + temp_dq, + max_dq_rows, + ) + # GPTQ without g_idx + else: + return make_q_matrix( + w.qweight, + none_tensor, # q_perm + none_tensor, # q_invperm + none_tensor, # q_scale + none_tensor, # q_scale_max + none_tensor, # q_groups + none_tensor, # q_group_map + w.qzeros, + w.scales, + none_tensor, # g_idx + none_tensor, # bias + temp_dq, + max_dq_rows, + ) + else: + RuntimeError("Cannot create handle") + + +DEVICE = None +LAYERS = [] + + +def set_device(device): + global DEVICE + DEVICE = device + + +def create_exllama_buffers(max_total_tokens: int): + global LAYERS, DEVICE + + # No need to initialize scratch space if there are no layers + # that use ExLLamav2. + if len(LAYERS) == 0: + return + + # Find the size of the scratch space. + scratch_bytes = max( + layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1) + for layer in LAYERS + ) + temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) + + for layer in LAYERS: + layer.post_init(temp_dq) + + +class QuantLinear(nn.Module): + QUANT_TYPE = "exllamav2" + + """Linear layer implementation with per-group 4-bit quantization of the weights""" + + def __init__( + self, + weight: Exl2Weight | GPTQWeight, + bias: torch.Tensor, + ): + super().__init__() + + self.q_handle = None + self.q_tensors = weight + self.extra_tensors = _ExtraTensors() + + if isinstance(weight, Exl2Weight): + self.infeatures = weight.q_invperm.shape[0] + self.outfeatures = weight.q_weight.shape[1] + elif isinstance(weight, GPTQWeight): + if weight.bits != 4: + raise ValueError( + f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." + ) + + self.infeatures = weight.qweight.shape[0] // weight.bits * 32 + self.outfeatures = weight.qweight.shape[1] + + self.padding = -self.outfeatures % 32 + self.outfeatures = self.outfeatures + self.padding + + self.device = weight.device + self.bias = bias if bias is not None else None + + global LAYERS + LAYERS.append(self) + + def post_init(self, temp_dq): + device = self.q_tensors.device + assert device.type == "cuda" + assert device.index is not None + temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) + + # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, + # and `Memory access fault by GPU node-2` will EAT you. + self.temp_dq = temp_dq + self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) + + def forward(self, x, force_cuda=False): + output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) + + if self.bias is not None: + output.add_(self.bias) + return output + + def temp_dq_size(self): + return self.infeatures * self.outfeatures * 2 + 128 + + def temp_fwd_size(self, max_input_len, max_batch_size): + return self.outfeatures * max_input_len * max_batch_size * 4 + 128 + + def scratch_space_fixed(self, max_input_len, max_batch_size): + return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) + + +class ExLlamaV2DeviceTensors: + + device_idx: int + scratch_bytes: int + scratch_idx: int + scratch: torch.tensor = None + + def __init__(self, device, scratch_bytes): + self.device = device + self.scratch_bytes = scratch_bytes + + def prepare(self): + self.scratch = torch.empty( + (self.scratch_bytes // 2,), dtype=torch.half, device=self.device + ) + + def get_scratch_slice(self, size_bytes): + + if self.scratch is None: + self.prepare() + + size_bytes = ((size_bytes + 127) // 128) * 128 + size_half = size_bytes // 2 + scratch_slice = self.scratch.narrow(0, 0, size_half) + return scratch_slice diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py b/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py new file mode 100644 index 000000000..736c357b0 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py @@ -0,0 +1,359 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_fwd + +import triton +import triton.language as tl +from . import custom_autotune + + +# code based https://github.com/fpgaminer/GPTQ-triton +@custom_autotune.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load( + scales_ptrs + g_idx[:, None] * stride_scales + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load( + zeros_ptrs + g_idx[:, None] * stride_zeros + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) & maxq # eventually avoid overflow + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty( + (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 + ) + + def grid(META): + return ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + + matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) + return output + + +class QuantLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) + return output + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + out = QuantLinearFunction.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.maxq, + ) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py new file mode 100644 index 000000000..b0086ea08 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py @@ -0,0 +1,1017 @@ +import time +import torch.nn as nn +import math +import json +import os +import torch +import transformers + +from texttable import Texttable +from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer +from huggingface_hub import HfApi +from accelerate import init_empty_weights +from text_generation_server.utils import initialize_torch_distributed, Weights +from text_generation_server.utils.hub import weight_files +from text_generation_server.layers.gptq.quant_linear import QuantLinear +from loguru import logger +from typing import Optional +from text_generation_server.layers.gptq.utils import torch_snr_error + +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + +DEV = torch.device("cuda:0") + + +class Quantizer(nn.Module): + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) + + def configure( + self, + bits, + perchannel=False, + sym=True, + mse=False, + norm=2.4, + grid=100, + maxshrink=0.8, + trits=False, + ): + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + self.scale = torch.zeros_like(self.scale) + + def _quantize(self, x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = self._quantize( + x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq + ) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return self._quantize(x, self.scale, self.zero, self.maxq) + + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +class GPTQ: + def __init__(self, layer, observe=False): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.quantizer = Quantizer() + self.observe = observe + + def add_batch(self, inp, out): + # Hessian H = 2 X XT + λ I + if self.observe: + self.inp1 = inp + self.out1 = out + else: + self.inp1 = None + self.out1 = None + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def print_loss(self, name, q_weight, weight_error, timecost): + table = Texttable() + length = 28 + name = ( + (name + " " * (length - len(name))) + if len(name) <= length + else name[:length] + ) + + table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) + + # assign weight + self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) + + if self.inp1 is not None: + # quantize input to int8 + quantizer = Quantizer() + quantizer.configure(8, perchannel=False, sym=True, mse=False) + quantizer.find_params(self.inp1) + q_in = quantizer.quantize(self.inp1).type(torch.float16) + q_out = self.layer(q_in) + + # get kinds of SNR + q_SNR = torch_snr_error(q_out, self.out1).item() + fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() + else: + q_SNR = "-" + fp_SNR = "-" + + table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) + print(table.draw().split("\n")[-2]) + + def fasterquant( + self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name="" + ): + self.layer.to(self.dev) + + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + if not self.observe: + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if act_order: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + try: + H = torch.linalg.cholesky(H, upper=True) + except Exception: + # Addition because Falcon fails on h_to_4h + H = torch.linalg.cholesky( + H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True + ) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params( + W[:, (i1 + i) : (i1 + i + groupsize)], weight=True + ) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = self.quantizer.quantize(w.unsqueeze(1)).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + error = torch.sum(Losses).item() + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if act_order: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + + self.print_loss( + name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) + ) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + return scale, zero, g_idx, error + + def free(self): + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() + + +def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except Exception: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except Exception: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + use_auth_token=False, + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + use_auth_token=False, + ) + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except Exception: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]["text"], return_tensors="pt") + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except Exception: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + ) + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except Exception: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") + valenc = valenc.input_ids[:, : (256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders( + name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False +): + if "wikitext2" in name: + return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code) + if "ptb" in name: + if "new" in name: + return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code) + return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code) + if "c4" in name: + if "new" in name: + return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code) + return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code) + + +def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""): + # Skip last lm_head linear + # Need isintance Falcon is inheriting Linear. + if isinstance(module, layers) and "lm_head" not in name: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) + return res + + +@torch.no_grad() +def sequential( + model, + dataloader, + dev, + nsamples, + bits, + groupsize, + *, + hooks, + percdamp=0.01, + sym: bool = False, + act_order: bool = False, +): + print("Starting ...") + + use_cache = model.config.use_cache + model.config.use_cache = False + try: + layers = model.model.layers + prefix = "model.layers" + except Exception: + layers = model.transformer.h + prefix = "transformer.h" + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + + cache = {"i": 0} + extra = {} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache["i"]] = inp + cache["i"] += 1 + extra.update(kwargs.copy()) + raise ValueError + + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].cuda()) + except ValueError: + pass + layers[0] = layers[0].module + + # layers[0] = layers[0].cpu() + # model.model.embed_tokens = model.model.embed_tokens.cpu() + # model.model.norm = model.model.norm.cpu() + torch.cuda.empty_cache() + for hook in hooks: + hook.remove() + + outs = torch.zeros_like(inps) + + extra = { + k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items() + } + + print("Ready.") + + quantizers = {} + for i in range(len(layers)): + print(f"Quantizing layer {i+1}/{len(layers)}..") + print("+------------------+--------------+------------+-----------+-------+") + print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") + print("+==================+==============+============+===========+=======+") + + layer = layers[i] + layer.load() + full = find_layers(layer) + sequential = [list(full.keys())] + + for names in sequential: + subset = {n: full[n] for n in names} + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer.configure( + bits, perchannel=True, sym=sym, mse=False + ) + pass + + def add_batch(name): + nonlocal gptq + + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] + for h in handles: + h.remove() + + for name in subset: + scale, zero, g_idx, error = gptq[name].fasterquant( + percdamp=percdamp, + groupsize=groupsize, + act_order=act_order, + name=name, + ) + quantizers[f"{prefix}.{i}.{name}"] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + bits, + groupsize, + ) + + gptq[name].free() + + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] + + layer.unload() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + print("+------------------+--------------+------------+-----------+-------+") + print("\n") + + model.config.use_cache = use_cache + + return quantizers + + +def make_quant_linear(module, names, bits, groupsize, name=""): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + "." + attr if name != "" else attr + if name1 in names: + delattr(module, attr) + setattr( + module, + attr, + QuantLinear.new( + bits, + groupsize, + tmp.in_features, + tmp.out_features, + tmp.bias is not None, + ), + ) + for name1, child in module.named_children(): + make_quant_linear( + child, names, bits, groupsize, name + "." + name1 if name != "" else name1 + ) + + +# TODO: perform packing on GPU +def pack(model, quantizers, bits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant_linear(model, quantizers, bits, groupsize) + qlayers = find_layers(model, (QuantLinear,)) + print("Packing ...") + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print("Done.") + return model + + +def setdeepattr(module, full_name, tensor): + current = module + tokens = full_name.split(".") + for token in tokens[:-1]: + current = getattr(current, token) + setattr(current, tokens[-1], tensor) + + +def getdeepattr(module, full_name): + current = module + tokens = full_name.split(".") + for token in tokens: + current = getattr(current, token) + return current + + +def load_weights_pre_hook(module_name, weights, recursive=False): + def inner(module, args): + print(f"Pre hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + + for local_param in local_params: + current_tensor = getdeepattr(module, local_param) + if current_tensor.device == torch.device("meta"): + # print(f"Loading {local_param}") + if module_name: + tensor_name = f"{module_name}.{local_param}" + else: + tensor_name = local_param + tensor = weights.get_tensor(tensor_name) + setdeepattr(module, local_param, nn.Parameter(tensor)) + else: + tensor = current_tensor.to(device=torch.device("cuda:0")) + if current_tensor.requires_grad: + tensor = nn.Parameter(tensor) + setdeepattr(module, local_param, tensor) + + return inner + + +def load_weights_post_hook(module_name, weights, recursive=False): + def inner(module, args, output): + print(f"Post hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for local_param in local_params: + # print(f"Unloading {local_param}") + current_tensor = getdeepattr(module, local_param) + setdeepattr( + module, + local_param, + nn.Parameter(current_tensor.to(device=torch.device("cpu"))), + ) + return output + + return inner + + +def quantize( + model_id: str, + bits: int, + groupsize: int, + output_dir: str, + revision: str, + trust_remote_code: bool, + upload_to_model_id: Optional[str], + percdamp: float, + act_order: bool, + sym: bool, +): + print("loading model") + config = AutoConfig.from_pretrained( + model_id, + trust_remote_code=trust_remote_code, + ) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config( + config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) + model = model.eval() + + print("LOADED model") + files = weight_files(model_id, revision, extension=".safetensors") + process_group, _, _ = initialize_torch_distributed() + weights = Weights( + files, + device=torch.device("cuda:0"), + dtype=torch.float16, + process_group=process_group, + aliases={"embed_tokens.weight": ["lm_head.weight"]}, + weights_loader=DefaultWeightsLoader(UnquantizedWeight), + ) + hooks = [] + for name, module in model.named_modules(): + + def load(module, name): + def _load(): + load_weights_pre_hook(name, weights, recursive=True)(module, None) + + return _load + + def unload(module, name): + def _unload(): + load_weights_post_hook(name, weights, recursive=True)( + module, None, None + ) + + return _unload + + module.load = load(module, name) + module.unload = unload(module, name) + hooks.append( + module.register_forward_pre_hook(load_weights_pre_hook(name, weights)) + ) + hooks.append( + module.register_forward_hook(load_weights_post_hook(name, weights)) + ) + model.seqlen = 2048 + + dataset = "wikitext2" + nsamples = 128 + seed = None + + dataloader, testloader = get_loaders( + dataset, + nsamples=nsamples, + seed=seed, + model_id=model_id, + seqlen=model.seqlen, + trust_remote_code=trust_remote_code, + ) + + tick = time.time() + quantizers = sequential( + model, + dataloader, + DEV, + nsamples, + bits, + groupsize, + percdamp=percdamp, + act_order=act_order, + hooks=hooks, + sym=sym, + ) + print(time.time() - tick) + + pack(model, quantizers, bits, groupsize) + from safetensors.torch import save_file + from transformers.modeling_utils import shard_checkpoint + + state_dict = model.state_dict() + state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} + + max_shard_size = "10GB" + shards, index = shard_checkpoint( + state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" + ) + os.makedirs(output_dir, exist_ok=True) + for shard_file, shard in shards.items(): + save_file( + shard, + os.path.join(output_dir, shard_file), + metadata={ + "format": "pt", + "quantized": "gptq", + "origin": "text-generation-inference", + }, + ) + if index is None: + path_to_weights = os.path.join(output_dir, "model.safetensors") + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = "model.safetensors.index.json" + save_index_file = os.path.join(output_dir, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + config.quantization_config = { + "bits": bits, + "group_size": groupsize, + "damp_percent": percdamp, + "desc_act": act_order, + "static_groups": False, + "sym": sym, + "quant_method": "gptq", + } + config.save_pretrained(output_dir) + logger.info("Saved config") + logger.info("Saving tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + tokenizer.save_pretrained(output_dir) + logger.info("Saved tokenizer") + + if upload_to_model_id: + api = HfApi() + + api.upload_folder( + folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model" + ) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/utils.py b/backends/gaudi/server/text_generation_server/layers/gptq/utils.py new file mode 100644 index 000000000..cbc0f391f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/utils.py @@ -0,0 +1,56 @@ +import torch + + +# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py +def torch_snr_error( + y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean" +) -> torch.Tensor: + """ + Compute SNR between y_pred(tensor) and y_real(tensor) + + SNR can be calcualted as following equation: + + SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 + + if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. + + SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) + + Args: + y_pred (torch.Tensor): _description_ + y_real (torch.Tensor): _description_ + reduction (str, optional): _description_. Defaults to 'mean'. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + torch.Tensor: _description_ + """ + if y_pred.shape != y_real.shape: + raise ValueError( + f"Can not compute snr loss for tensors with different shape. " + f"({y_pred.shape} and {y_real.shape})" + ) + reduction = str(reduction).lower() + + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(0) + y_real = y_real.unsqueeze(0) + + y_pred = y_pred.flatten(start_dim=1) + y_real = y_real.flatten(start_dim=1) + + noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) + signal_power = torch.pow(y_real, 2).sum(dim=-1) + snr = (noise_power) / (signal_power + 1e-7) + + if reduction == "mean": + return torch.mean(snr) + elif reduction == "sum": + return torch.sum(snr) + elif reduction == "none": + return snr + else: + raise ValueError("Unsupported reduction method.") diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py new file mode 100644 index 000000000..ce5289f93 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py @@ -0,0 +1,184 @@ +import torch +from torch import nn +from accelerate import init_empty_weights +from text_generation_server.utils.import_utils import ( + SYSTEM, +) + + +# Monkey patching +@classmethod +def load_layer_norm(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = torch.nn.Parameter(weight) + ln.bias = torch.nn.Parameter(bias) + return ln + + +@classmethod +def load_layer_norm_no_bias(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = torch.nn.Parameter(weight) + ln.bias = None + return ln + + +torch.nn.LayerNorm.load = load_layer_norm +torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias + +if SYSTEM == "cuda": + import dropout_layer_norm + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": + from vllm._C import ops + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super().forward(hidden_states), residual + +elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + out = ipex.llm.functional.add_layer_norm( + residual, + hidden_states, + self.weight, + self.bias, + self.eps, + residual is not None, + ) + return out, residual if residual is not None else hidden_states + + +class FastRMSNorm(nn.Module): + def __init__(self, weight: torch.Tensor, eps: float): + super().__init__() + + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + @classmethod + def load(cls, prefix, weights, eps=1e-6): + weight = weights.get_tensor(f"{prefix}.weight") + return cls(weight, eps) + + def forward(self, hidden_states, residual=None): + if SYSTEM == "ipex": + out = ipex.llm.functional.add_rms_norm( + residual, + hidden_states, + self.weight, + None, + self.variance_epsilon, + residual is not None, + ) + return out, residual if residual is not None else hidden_states + elif hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + elif SYSTEM == "cuda": + # faster post attention rms norm + ( + normed_hidden_states, + res, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + elif SYSTEM == "rocm": + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + out = torch.empty_like(hidden_states) + ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) diff --git a/backends/gaudi/server/text_generation_server/layers/linear.py b/backends/gaudi/server/text_generation_server/layers/linear.py new file mode 100644 index 000000000..08306d579 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/linear.py @@ -0,0 +1,126 @@ +import torch +from text_generation_server.utils.import_utils import SYSTEM +from torch.nn import functional as F +import os + +if SYSTEM == "rocm": + ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( + "true", + "1", + ) + + if ROCM_USE_SKINNY_GEMM: + try: + from vllm import _custom_C + except Exception as e: + raise ImportError( + f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}" + ) + + +class FastLinear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(weight, requires_grad=False) + if bias is not None: + self.bias = torch.nn.Parameter(bias, requires_grad=False) + else: + self.bias = None + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight, self.bias) + + +class FastLinearROCm(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(weight) + if bias is not None: + self.bias = torch.nn.Parameter(bias) + else: + self.bias = None + + self.cu_count = torch.cuda.get_device_properties( + device="cuda" + ).multi_processor_count + self.use_skinny_gemm = ( + ROCM_USE_SKINNY_GEMM + and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName + ) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + weight = self.weight + bias = self.bias + + if ( + self.use_skinny_gemm + and inp.dtype == torch.float16 + and inp.shape[-1] % 8 == 0 + ): + batched = False + inp_shape = inp.shape + + if inp.dim() == 3: + inp = inp.view(-1, inp_shape[-1]) + batched = True + + m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] + if m > 8 and n <= 4: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) + _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) + _custom_C.LLMM1(weight, inp, out, 4) + else: + out = F.linear(inp, weight) + + if batched: + out.view(*inp_shape[:-1], out.shape[-1]) + + if bias is not None: + out = out + bias + return out + return F.linear(inp, self.weight, self.bias) + + +def get_linear(weight, bias): + # Weights that are loaded through methods that are not + # quantization-aware are still bare tensors. We may want + # to change this in the future. + if isinstance(weight, torch.Tensor): + if SYSTEM == "rocm": + return FastLinearROCm(weight, bias) + else: + return FastLinear(weight, bias) + + return weight.get_linear(bias) diff --git a/backends/gaudi/server/text_generation_server/layers/lora.py b/backends/gaudi/server/text_generation_server/layers/lora.py new file mode 100644 index 000000000..a4537b55b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/lora.py @@ -0,0 +1,279 @@ +from typing import TYPE_CHECKING, Optional, List + +import torch +import torch.distributed +from torch import nn +from torch.distributed import ProcessGroup + +from text_generation_server.utils.sgmv import ( + add_lora_a_bgmv, + add_lora_b_bgmv, + has_sgmv, + lora_a_sgmv_cutlass, + lora_b_sgmv_cutlass, + orient_for_rank, +) + +if TYPE_CHECKING: + from text_generation_server.adapters import AdapterBatchData + from text_generation_server.adapters.lora import BatchLoraWeights + + +class LoraLinear(nn.Module): + def __init__( + self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup + ): + super().__init__() + self.base_layer = base_layer + self.layer_id = layer_id + self.process_group = process_group + + def forward_layer_type( + self, + result: torch.Tensor, + input: torch.Tensor, + adapter_data: "AdapterBatchData", + layer_type: str, + start_idx: int, + end_idx: int, + ) -> torch.Tensor: + if adapter_data is None: + return result + data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) + + if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + # In tensor-parallel configurations, each GPU processes a specific segment of the output. + # The 'result' tensor represents the full output, which can vary in size based on + # the layer type (e.g., attention vs. feed-forward layers). We define the current + # segment using start_idx and end_idx. If the segment size doesn't match this GPU's + # slice of 'result', we create a zero tensor of the correct size for LoRA computation. + # This approach ensures accurate LoRA application across various layer sizes and + # configurations, adapting to different model architectures and parallelization strategies. + # + # Example scenarios where this is necessary: + # 1. The adapter's size doesn't evenly divide across GPUs. + # 2. We're processing the last segment which might be smaller. + # 3. Different projection layers (q, k, v) have different sizes. + if end_idx - start_idx != result.shape[1]: + proj = torch.zeros_like(result[:, start_idx:end_idx]) + else: + proj = result + + for r, rank_segments in data.rank_data.items(): + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr + + if lora_a_ptr is None or lora_b_ptr is None: + raise ValueError("LoRA data is missing") + + if data.use_sgmv: + # Use SGMV for prefill + v = lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + r, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) + else: + # Use BGMV for decode + v = torch.zeros( + (input.size(0), r), dtype=input.dtype, device=input.device + ) + # TODO: error with [-1, 0], but not [0, -1] + add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + self.layer_id, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) + + if end_idx - start_idx != result.shape[1]: + result[:, start_idx:end_idx] += proj + else: + for adapter_index in adapter_data.meta.adapter_set: + if data is not None and data.has_adapter(adapter_index): + adapter_mask = ( + (adapter_data.meta.adapter_indices == adapter_index) + .to(input.dtype) + .view(-1, 1) + ) + layer_result = self.forward_lora( + input, data, adapter_index, adapter_mask + ) + result[:, start_idx:end_idx] += layer_result + + return result + + def forward_lora( + self, + input: torch.Tensor, + data: "BatchLoraWeights", + adapter_index: int, + adapter_mask: torch.Tensor, + ) -> torch.Tensor: + lora_a = data.lora_a[adapter_index][self.layer_id, :, :] + lora_b = data.lora_b[adapter_index][self.layer_id, :, :] + + lora_a = orient_for_rank(lora_a, lora_b.size(0)) + + a_out = input @ lora_a + if self.process_group.size() > 1: + a_out = self.collect_lora_a(a_out) + + result = (a_out @ lora_b) * adapter_mask + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Implemented in subclasses") + + +class TensorParallelMultiAdapterLinear(LoraLinear): + def __init__( + self, + base_layer: nn.Module, + layer_id: int, + layer_names: List[str], + sizes: List[int], + process_group: ProcessGroup, + ): + super().__init__(base_layer, layer_id, process_group) + self.layer_names = layer_names + self.sizes = sizes + + @classmethod + def load( + cls, + base_layer: nn.Module, + layer_id: int, + layer_names: List[str], + sizes: List[int], + process_group: ProcessGroup, + ): + return TensorParallelMultiAdapterLinear( + base_layer, layer_id, layer_names, sizes, process_group + ) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + # noop if no layer names are provided (e.g. for models without adapters) + if self.layer_names is None: + return result + + # handle models like Bloom that have inputs of shape + # (batch_size, sequence_length, hidden_size) + # we need to reshape them to (batch_size * sequence_length, hidden_size) + # for the LoRA computation, then reshape back + prev_shape = result.shape + is_3d = len(input.shape) >= 3 + if is_3d: + input = input.reshape(-1, input.shape[-1]) + result = result.reshape(-1, result.shape[-1]) + + offset = 0 + for i, layer_name in enumerate(self.layer_names): + start_idx = offset // self.process_group.size() + # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple + # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It + # ensures correct slicing of the result tensor, accommodating variations like grouped-query + # attention where k_proj and v_proj differ from q_proj. This allows precise application of + # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the + # different projection sizes across layers and model architectures. + if self.sizes is not None: + offset += self.sizes[i] + end_idx = offset // self.process_group.size() + else: + end_idx = result.shape[1] + + result = self.forward_layer_type( + result, input, adapter_data, layer_name, start_idx, end_idx + ) + + if is_3d: + result = result.reshape(prev_shape) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise. + # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-gather for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + gathered_tensors = [ + torch.empty_like(a_out) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(gathered_tensors, a_out) + return torch.cat(gathered_tensors, dim=1) + + +class TensorParallelAdapterRowLinear(LoraLinear): + def __init__(self, base_layer, layer_id, layer_name, process_group): + super().__init__(base_layer, layer_id, process_group) + self.layer_name = layer_name + + @classmethod + def load(cls, base_layer, layer_id, layer_name, process_group): + return cls(base_layer, layer_id, layer_name, process_group) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + if self.layer_name is None: + return result + + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 + stride = result.shape[-1] // self.process_group.size() + start_idx = self.process_group.rank() * stride + end_idx = (self.process_group.rank() + 1) * stride + + self.forward_layer_type( + result, input, adapter_data, self.layer_name, start_idx, end_idx + ) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. + # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + torch.distributed.all_reduce(a_out, group=self.process_group) + return a_out diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py b/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py new file mode 100644 index 000000000..3ff3ed58f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py @@ -0,0 +1,15 @@ +from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear +from text_generation_server.layers.marlin.gptq import ( + GPTQMarlinWeightsLoader, + can_use_gptq_marlin, + repack_gptq_for_marlin, +) +from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader + +__all__ = [ + "GPTQMarlinFP8Linear", + "GPTQMarlinWeightsLoader", + "MarlinWeightsLoader", + "can_use_gptq_marlin", + "repack_gptq_for_marlin", +] diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py new file mode 100644 index 000000000..fe55a58a3 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py @@ -0,0 +1,140 @@ +from typing import Optional + +import torch +import torch.nn as nn +from loguru import logger +from text_generation_server.layers.fp8 import fp8_quantize +from text_generation_server.layers.marlin.gptq import _check_valid_shape +from text_generation_server.layers.marlin.util import ( + _check_marlin_kernels, + permute_scales, +) +from text_generation_server.utils.log import log_once + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + +MARLIN_TILE_SIZE = 16 + + +class GPTQMarlinFP8Linear(nn.Module): + """ + FP8 GPTQ-Marlin linear layer. + """ + + def __init__( + self, + qweight: torch.Tensor, + scales: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + + scales = scales.unsqueeze(0) + if scales.shape[1] == 1: + out_features, in_features = qweight.shape + scales = scales.repeat(1, out_features) + qweight, scales = repack_fp8_for_marlin(qweight, scales) + + in_features = qweight.shape[0] * MARLIN_TILE_SIZE + out_features = scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.qweight = qweight + self.scales = scales + self.bias = bias if bias is not None else None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=qweight.device + ) + + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scales = fp8_quantize(weight) + return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) + + @classmethod + def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + return cls(qweight=weight, scales=scale.to(dtype), bias=bias) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.fp8_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.workspace, + 8, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements). + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + + if fp8_tensor.shape[0] % 4 != 0: + raise ValueError( + f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" + ) + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = torch.zeros( + fp8_tensor.shape[0] // 4, + fp8_tensor.shape[1], + dtype=torch.int32, + device=fp8_tensor.device, + ) + + for i in range(4): + packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) + + return packed + + +def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): + """ + Repack FP8 tensor for GPTQ-Marlin. + """ + + out_features, in_features = weight.shape + + # Torch linear layers weights with shape [out_features, in_features], + # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], + # so transpose before packing. + qweight = pack_fp8_as_int32(weight.t()) + + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, 8 + ) + + scales = permute_scales(scales) + + return repacked, scales diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py new file mode 100644 index 000000000..0a785d944 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py @@ -0,0 +1,464 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy +import torch +import torch.nn as nn +from loguru import logger +from text_generation_server.layers.marlin.util import ( + _check_marlin_kernels, + marlin_zero_points, + permute_scales, + unpack_cols, +) +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + + +GPTQ_MARLIN_BITS = [4, 8] +GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] +MARLIN_TILE_SIZE = 16 + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: + return ( + SYSTEM == "cuda" + and marlin_kernels is not None + and has_sm_8_0 + and quantize in {"awq", "gptq"} + and quant_method in {"awq", "gptq"} + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + # We only suppord asymmetric quantization for AWQ. + and (sym or quant_method == "awq") + ) + + +class GPTQMarlinWeightsLoader(WeightsLoader): + """ + Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels. + """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + if not self.sym: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + scales = weights.get_tensor(f"{prefix}.scales") + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + if not self.sym: + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + if not self.sym: + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + if not self.sym: + if self.desc_act or self.groupsize == -1: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + def _get_gptq_params(self, weights: Weights): + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) + self.quant_method = "gptq" + + +@dataclass +class GPTQMarlinWeight(Weight): + """ + Repacked GPTQ Marlin weights. + """ + + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: torch.Tensor + perm: torch.Tensor + bits: int + is_full_k: bool + + def __post_init__(self): + assert self.qweight.dtype == torch.int32 + assert self.scales.dtype == torch.float16 + assert self.g_idx.dtype == torch.int32 + assert self.perm.dtype == torch.int32 + + def get_linear(self, bias: torch.Tensor): + return GPTQMarlinLinear( + weight=self, + bias=bias, + ) + + +def repack_gptq_for_marlin( + *, + qweight: torch.Tensor, + qzeros: Optional[torch.Tensor], + scales: torch.Tensor, + g_idx: Optional[torch.Tensor], + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + sym: bool, + sharded_infeatures: bool, +) -> GPTQMarlinWeight: + """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" + _check_marlin_kernels() + assert marlin_kernels is not None + + if bits not in GPTQ_MARLIN_BITS: + supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) + raise RuntimeError( + f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" + ) + + if groupsize not in GPTQ_MARLIN_GROUP_SIZES: + supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) + raise RuntimeError( + f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" + ) + if not (sym or quant_method == "awq"): + raise RuntimeError( + "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." + ) + + log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") + + weights_per_int = 32 // bits + in_features = qweight.shape[0] + out_features = qweight.shape[1] + + # AWQ uses column packing, GPTQ uses row packing + if quant_method == "awq": + out_features *= weights_per_int + else: + in_features *= weights_per_int + + if in_features % groupsize != 0: + raise ValueError( + f"Number of input features ({in_features}) not divisible by group size ({groupsize})" + ) + + if g_idx is not None and desc_act and groupsize != -1: + perm = torch.argsort(g_idx).to(torch.int) + g_idx = g_idx[perm] + else: + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) + + if quant_method == "awq": + repacked = marlin_kernels.awq_marlin_repack( + qweight, in_features, out_features, bits + ) + if qzeros is not None: + qzeros = awq_to_marlin_zero_points( + qzeros, + in_features // groupsize, + out_features, + bits, + ) + + else: + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, bits + ) + + if qzeros is None: + qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) + + scales = permute_scales(scales) + + is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures) + + return GPTQMarlinWeight( + qweight=repacked, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + perm=perm, + bits=bits, + is_full_k=is_full_k, + ) + + +class GPTQMarlinLinear(nn.Module): + """ + Linear layer for GPTQ weights that were converted for the GPTQ-Marlin + kernels. + """ + + def __init__( + self, + *, + weight: GPTQMarlinWeight, + bias: Optional[torch.Tensor], + ): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE + out_features = weight.scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.bits = weight.bits + self.is_full_k = weight.is_full_k + + self.qweight = weight.qweight + self.qzeros = weight.qzeros + self.scales = weight.scales + self.g_idx = weight.g_idx + self.perm = weight.perm + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.gptq_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.perm, + self.workspace, + self.bits, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + self.is_full_k, + self.qzeros.numel() > 0, + True, + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def _check_valid_shape(in_features: int, out_features: int): + if (in_features % 128 != 0 or out_features % 64 != 0) and ( + in_features % 64 != 0 or out_features % 128 != 0 + ): + raise ValueError( + f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." + " The shape elements must be divisible by (128, 64) or (64, 128)." + ) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py new file mode 100644 index 000000000..89ebaca62 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py @@ -0,0 +1,346 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from text_generation_server.layers.marlin.util import _check_marlin_kernels +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + +class MarlinWeightsLoader(WeightsLoader): + """Loader for Marlin-quantized weights.""" + + def __init__(self, *, bits: int, is_marlin_24: bool): + self.bits = bits + self.is_marlin_24 = is_marlin_24 + + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_tensor(f"{prefix}.B_24") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_tensor(f"{prefix}.B_meta") + s = weights.get_tensor(f"{prefix}.s") + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_tensor(f"{prefix}.B") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + s = weights.get_tensor(f"{prefix}.s") + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + if self.is_marlin_24: + B = weights.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = weights.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + B = weights.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_marlin_24: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized" + ) + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_row(self, weights: Weights, prefix: str): + if self.is_marlin_24: + try: + B = weights.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + + return weight + + +@dataclass +class MarlinWeight(Weight): + """ + Marlin weights. + + Attributes: + B (torch.Tensor): int4-quantized weights packed into int32. + s (torch.Tensor): bfloat16/float16 scales. + """ + + B: torch.Tensor + s: torch.Tensor + + def __post_init__(self): + assert self.B.dtype == torch.int32 + assert self.s.dtype in [torch.float16, torch.bfloat16] + + def get_linear(self, bias: torch.Tensor): + return MarlinLinear(weight=self, bias=bias) + + +class MarlinLinear(nn.Module): + def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + in_features = weight.B.shape[0] * MARLIN_TILE_SIZE + out_features = weight.s.shape[1] + assert ( + in_features % 128 == 0 + ), f"Number of input features ({in_features}) not divisable by 128" + assert ( + out_features % 256 == 0 + ), f"Number of output features ({out_features}) not divisable by 256" + + groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + assert groupsize in { + -1, + 128, + }, f"Group size must be -1 or 128, was {groupsize}" + + self.B = weight.B + self.s = weight.s + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=weight.B.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + C = marlin_kernels.marlin_gemm( + A.view(-1, A.shape[-1]), + self.B, + self.s, + self.workspace, + A.shape[0], + self.s.shape[1], + A.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 +GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_TILE_SIZE = 16 + + +@dataclass +class GPTQMarlin24Weight: + """ + GPTQ-Marlin 2:4 weights. + + Attributes: + B (torch.Tensor): int4-quantized weights packed into int32. + B_meta (torch.Tensor): metadata for 2:4 sparsity. + s (torch.Tensor): float16 scales. + bits: quantized weight size. + """ + + B: torch.Tensor + B_meta: torch.Tensor + s: torch.Tensor + bits: int + + def __post_init__(self): + assert self.B.dtype == torch.int32 + assert self.B_meta.dtype == torch.int16 + assert self.s.dtype == torch.float16 + + def get_linear(self, bias: torch.Tensor): + return GPTQMarlin24Linear( + weight=self, + bias=bias, + ) + + +class GPTQMarlin24Linear(nn.Module): + def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + supported_bits = ", ".join( + str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + ) + raise RuntimeError( + f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" + ) + + in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 + out_features = weight.s.shape[1] + groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + + if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: + supported_sizes = ", ".join( + str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ) + raise RuntimeError( + f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" + ) + + self.bits = weight.bits + weights_per_int32 = 32 // self.bits + + assert ( + out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 + ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" + assert ( + out_features % weights_per_int32 == 0 + ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" + + assert ( + in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 + ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" + if groupsize != -1 and in_features % groupsize != 0: + raise ValueError( + f"Number of input features ({in_features}) not divisable by group size ({groupsize})" + ) + + self.B = weight.B + self.B_meta = weight.B_meta + self.s = weight.s + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, + dtype=torch.int, + device=weight.B.device, + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + C = marlin_kernels.gptq_marlin_24_gemm( + A.view(-1, A.shape[-1]), + self.B, + self.B_meta, + self.s, + self.workspace, + self.bits, + A.shape[0], + self.s.shape[1], + A.shape[1], + ) + + C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/util.py b/backends/gaudi/server/text_generation_server/layers/marlin/util.py new file mode 100644 index 000000000..250d17141 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/marlin/util.py @@ -0,0 +1,141 @@ +import functools +from typing import List, Tuple + +import numpy +import torch +from text_generation_server.utils.import_utils import SYSTEM + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + + +def _check_marlin_kernels(): + if not (SYSTEM == "cuda" and has_sm_8_0): + raise NotImplementedError( + "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." + ) + + if marlin_kernels is None: + raise NotImplementedError( + "marlin is not installed, install it with: pip install server/marlin" + ) + + +# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 +@functools.cache +def get_perms() -> Tuple[List[int], List[int]]: + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def permute_scales(scales: torch.Tensor): + scale_perm, scale_perm_single = get_perms() + out_features = scales.shape[1] + if scales.shape[0] == 1: + scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + else: + scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm] + return scales.reshape((-1, out_features)).contiguous() + + +# Functions below are from vLLM + + +def get_pack_factor(bits: int) -> int: + if 32 % bits != 0: + raise ValueError(f"Cannot {bits} bit values into uint32") + return 32 // bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + scale_perm, _ = get_perms() + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp diff --git a/backends/gaudi/server/text_generation_server/layers/medusa.py b/backends/gaudi/server/text_generation_server/layers/medusa.py new file mode 100644 index 000000000..139c4dc25 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/medusa.py @@ -0,0 +1,191 @@ +import torch +from torch import nn +from typing import Tuple, Optional +from text_generation_server.utils.speculate import get_speculate +from text_generation_server.layers.linear import FastLinear +from text_generation_server.layers.tensor_parallel import ( + TensorParallelHead, + TensorParallelColumnLinear, +) + + +class ResBlock(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.linear = FastLinear.load( + config, prefix=f"{prefix}.linear", weights=weights, bias=True + ) + self.act = torch.nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(torch.nn.Module): + def __init__(self, config, medusa_config, weights): + super().__init__() + self.heads = torch.nn.ModuleList( + [ + MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights) + for i in range(get_speculate()) + ] + ) + + def forward(self, x): + if not self.heads: + return None + speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) + return speculative_logits + + +class MedusaHead(torch.nn.Module): + def __init__(self, config, medusa_config, prefix, weights): + super().__init__() + self.blocks = torch.nn.ModuleList( + [ + ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) + for i in range(medusa_config["medusa_num_layers"]) + ] + ) + n = len(self.blocks) + self.out = FastLinear.load( + config, prefix=f"{prefix}.{n}", weights=weights, bias=False + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.out(x) + return x + + +class MedusaHeadV1(nn.Module): + def __init__(self, lm_head, medusa): + super().__init__() + self.lm_head = lm_head + self.medusa = medusa + + @staticmethod + def load(config, prefix: str, weights): + from pathlib import Path + from safetensors import safe_open + import json + + speculator = config.speculator + + path = speculator["path"] + medusa_config = str(Path(path) / "config.json") + + for fname in speculator["model_paths"]: + filename = str(Path(path) / fname) + + with open(medusa_config, "r") as f: + medusa_config = json.load(f) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + + medusa = MedusaModel(config, medusa_config, weights) + lm_head = TensorParallelHead.load(config, prefix, weights) + return MedusaHeadV1(lm_head, medusa) + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + logits = self.lm_head(input) + # If we have too many tokens, we skip speculative logits + if input.shape[0] > 128: + return logits, None + + speculative_logits = self.medusa(input) + return logits, speculative_logits + + +class MedusaHeadV2(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + from pathlib import Path + from safetensors import safe_open + import json + + speculator_path = config.speculator["path"] + + medusa_config = str(Path(speculator_path) / "config.json") + filename = str(Path(speculator_path) / "medusa_lm_head.safetensors") + + with open(medusa_config, "r") as f: + medusa_config = json.load(f) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + + self.n_medusa_heads = get_speculate() + + assert medusa_config["medusa_num_layers"] == 1 + self.linear = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)], + dim=0, + weights=weights, + bias=True, + ) + self.process_group = weights.process_group + self.world_size = self.process_group.size() + self.rank = self.process_group.rank() + + self.act = torch.nn.SiLU() + + self.lm_head = TensorParallelHead.load(config, prefix, weights) + + def forward(self, x): + # If we have too many tokens, we skip speculative logits + if x.shape[0] > 128: + logits = self.lm_head(x) + return logits, None + + size = x.shape[-1] + block_size = (size + self.world_size - 1) // self.world_size + start = self.rank * block_size + stop = (self.rank + 1) * block_size + + x_block = x[:, start:stop] + + # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1 + medusa_res = self.act(self.linear(x)).reshape( + *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1] + ) + + # Apply all residual medusa heads + output = x[:, start:stop].unsqueeze(-2) + medusa_res + + # Gather medusa heads + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + + # Stack x and medusa residual x + stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2) + + # Compute lm head on x + medusa residual x + logits = self.lm_head(stacked_x) + + # Finally, split logits from speculative logits + logits, speculative_logits = torch.split( + logits, [1, self.n_medusa_heads], dim=-2 + ) + # Squeeze added dimension + logits = logits.squeeze(-2) + + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/layers/mlp.py b/backends/gaudi/server/text_generation_server/layers/mlp.py new file mode 100644 index 000000000..d33b41f32 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/mlp.py @@ -0,0 +1,282 @@ +import torch +import math +from torch import nn +from torch.nn import functional as F +from typing import Optional, Tuple +from text_generation_server.layers import TensorParallelEmbedding, FastLinear +from text_generation_server.layers.tensor_parallel import TensorParallelHead +from text_generation_server.utils.speculate import get_speculate + + +class MLPSpeculatorLayerNorm(nn.Module): + """ + A L2 normalization implementation + ... + Args + ---- + normalized_shape : int + Dimensionality of input data (size of final tensor axis) + elementwise_scale_weight : torch.Tensor + learned scaling term after normalization? + elementwise_shift_bias : torch.Tensor + learned bias term after normalization? + eps : float + Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8). + """ + + def __init__( + self, + prefix, + config, + weights, + eps=1e-06, + ): + super(MLPSpeculatorLayerNorm, self).__init__() + self.weight = weights.get_tensor(f"{prefix}.weight") + self.bias = weights.get_tensor(f"{prefix}.bias") + self.eps = eps + + def forward(self, x): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + x = xf.type_as(x) + x = self.weight * x + x = x + self.bias + return x + + +INV_SQRT2 = 2**-0.5 + + +def simple_norm(x: torch.Tensor, eps=1e-06): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps) + x = xf.type_as(x) + return x * INV_SQRT2 + + +class MLPSpeculatorModelTied(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.config = config + self.n_predict = get_speculate() + self.hidden_size = config.hidden_size + + self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights) + self.proj0 = FastLinear.load( + config, + prefix=f"{prefix}.proj.0", + weights=weights, + bias=False, + ) + self.proj1 = FastLinear.load( + config, + prefix=f"{prefix}.proj.1", + weights=weights, + bias=False, + ) + self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False) + self.ln = MLPSpeculatorLayerNorm( + prefix=f"{prefix}.ln.0", + config=config, + weights=weights, + ) + + # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation + self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1 + self.activation = nn.GELU() + self.vsize = config.vocab_size + self.inner_dim = config.speculator_config["inner_dim"] + self.top_k_tokens_per_head = [1] * self.n_predict + self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt( + self.inner_dim / 2 + ) + self.emb.weight *= self.emb_weight + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + ): + top_k_tokens_per_head = self.top_k_tokens_per_head + + # k indicates # of candidates + # h indicates # of generated tokens + state = hidden_states + b = state.size(0) + ind = input_ids.unsqueeze(0) + all_probs = torch.empty( + b, self.n_predict, self.vsize, device=state.device + ) # b k h v + assert ( + len(top_k_tokens_per_head) == self.n_predict + ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)" + for i in range(self.n_predict): + # Project and predict + z = self.emb(ind) + # z = z.mul(self.emb_weight) # b k d + if i == 0: + state = self.proj0(state) * self.state_weight + z + else: + state = self.proj1(state) * self.state_weight + z + state = self.activation(self.ln(state)) # b k d + probs = F.log_softmax(self.head(state), dim=-1) # b k v + _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k' + + # Update candidate set with new predictions + + # Update distribution set with new logits + all_probs[:, i] = probs.exp() + + # Update state, log_probs and ind for new predictions + state = state.unsqueeze(2).expand( + -1, -1, top_k_tokens_per_head[i], -1 + ) # b k k' d + state = state.reshape(-1, b, state.size(3)) # b kk' d + ind = preds.view(-1, b) # b kk' + + speculative_logits = all_probs + return speculative_logits + + +class MLPSpeculatorModel(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.config = config + self.n_predict = get_speculate() + self.hidden_size = config.hidden_size + + self.emb = nn.ModuleList( + [ + TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) + for i in range(self.n_predict) + ] + ) + self.proj = [ + FastLinear.load( + config, + prefix=f"{prefix}.proj.{i}", + weights=weights, + bias=False, + ) + for i in range(self.n_predict) + ] + self.head = nn.ModuleList( + [ + FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) + for i in range(self.n_predict) + ] + ) + self.ln = nn.ModuleList( + [ + MLPSpeculatorLayerNorm( + prefix=f"{prefix}.ln.{i}", + config=config, + weights=weights, + ) + for i in range(self.n_predict) + ] + ) + + # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation + self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1 + self.activation = nn.GELU() + self.vsize = config.vocab_size + self.inner_dim = config.speculator_config["inner_dim"] + self.top_k_tokens_per_head = [1] * self.n_predict + self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt( + self.inner_dim / 2 + ) + self.emb.weight *= self.emb_weight + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + ): + top_k_tokens_per_head = self.top_k_tokens_per_head + + # k indicates # of candidates + # h indicates # of generated tokens + state = hidden_states + b = state.size(0) + ind = input_ids.unsqueeze(0) + all_probs = torch.empty( + b, self.n_predict, self.vsize, device=state.device + ) # b k h v + assert ( + len(top_k_tokens_per_head) == self.n_predict + ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)" + for i in range(self.n_predict): + # Project and predict + z = self.emb[i](ind) + # z = z.mul(self.emb_weight) # b k d + state = self.proj[i](state) * self.state_weight + z + state = self.activation(self.ln[i](state)) # b k d + probs = F.log_softmax(self.head[i](state), dim=-1) # b k v + _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k' + + # Update candidate set with new predictions + + # Update distribution set with new logits + all_probs[:, i] = probs.exp() + + # Update state, log_probs and ind for new predictions + state = state.unsqueeze(2).expand( + -1, -1, top_k_tokens_per_head[i], -1 + ) # b k k' d + state = state.reshape(-1, b, state.size(3)) # b kk' d + ind = preds.view(-1, b) # b kk' + + speculative_logits = all_probs + return speculative_logits + + +class MLPSpeculatorHead(nn.Module): + def __init__(self, lm_head, mlp_speculator, scale_input: bool): + super().__init__() + self.lm_head = lm_head + self.mlp_speculator = mlp_speculator + self.scale_input = scale_input + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + logits = self.lm_head(input) + # If we have too many tokens, we skip speculative logits + if input.shape[0] > 128: + return logits, None + + input_ids = logits.argmax(dim=-1) + if self.scale_input: + input = simple_norm(input) + speculative_logits = self.mlp_speculator(input, input_ids) + return logits, speculative_logits + + @staticmethod + def load(config, prefix: str, weights): + from pathlib import Path + from safetensors import safe_open + + speculator_path = config.speculator["path"] + + for fname in config.speculator["model_paths"]: + filename = str(Path(speculator_path) / fname) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + + tie_weights = config.speculator_config.get("tie_weights", False) + if tie_weights: + mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights) + else: + mlp_speculator = MLPSpeculatorModel(config, "speculator", weights) + # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator + scale_input = config.speculator_config.get("scale_input", False) + lm_head = TensorParallelHead.load(config, prefix, weights) + return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py new file mode 100644 index 000000000..2c46ca02a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py @@ -0,0 +1,257 @@ +from typing import Optional, Protocol, runtime_checkable + +import torch +import torch.nn as nn +from loguru import logger +from transformers.activations import ACT2FN + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader +from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader +from text_generation_server.layers.moe.gptq_marlin import ( + GPTQMarlinSparseMoELayer, + can_use_marlin_moe_gemm, +) +from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + Weights, + UnquantizedWeight, +) + +if SYSTEM == "rocm": + from .fused_moe_rocm import grouped_topk + from vllm.model_executor.layers.fused_moe import fused_topk +elif SYSTEM != "ipex": + from moe_kernels.fused_moe import fused_topk, grouped_topk + + +# NOTE: we are using a protocol here, because multiple inherance is not nice. +# We need `Module`, and `Module` -> some abstract class -> some concrete +# class inheritance is whacky. + + +@runtime_checkable +class MoELayer(Protocol): + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + hidden_act: str = "silu", + ): ... + + def forward( + self, x: torch.Tensor, *, gating_output: torch.Tensor + ) -> torch.Tensor: ... + + +class DenseMoELayer(nn.Module): + """ + Layer for MoE that applies *all* experts to each tokens and then weights + their outputs based on the calculated routing. This layer is much slower + than `SparseMoELayer` and should only be used when no fused kernels are + available (e.g. for unsupported quantizers). + """ + + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + hidden_act: str = "silu", + ): + super().__init__() + + log_once( + logger.info, + "No fused layers are available for this model type, using (slower) dense MoE layer", + ) + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + self.n_expert_group = n_expert_group + self.n_experts = n_experts + self.renormalize = renormalize + self.topk = topk + self.topk_group = topk_group + + if "gelu" in hidden_act: + self.act = lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" + if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none" + ), + ) + elif "silu" in hidden_act: + self.act = torch.nn.functional.silu + else: + self.act = ACT2FN[hidden_act] + + self.gate_proj = [ + TensorParallelColumnLinear.load( + None, + prefix=f"{prefix}.{i}.{gate_proj_name}", + weights=weights, + bias=False, + ) + for i in range(self.n_experts) + ] + self.up_proj = [ + TensorParallelColumnLinear.load( + None, + prefix=f"{prefix}.{i}.{up_proj_name}", + weights=weights, + bias=False, + ) + for i in range(self.n_experts) + ] + self.down_proj = [ + TensorParallelRowLinear.load( + None, + prefix=f"{prefix}.{i}.{down_proj_name}", + weights=weights, + bias=False, + ) + for i in range(self.n_experts) + ] + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gating_output: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + if self.n_expert_group is not None and self.topk_group is not None: + topk_weights, topk_ids = grouped_topk( + x, + gating_output, + self.topk, + renormalize=self.renormalize, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + else: + topk_weights, topk_ids = fused_topk( + x, gating_output, self.topk, self.renormalize + ) + topk_weights = topk_weights.to(x.dtype) + + weights = torch.zeros( + topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device + ) + + weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype)) + + out = torch.zeros_like(x) + for i in range(self.n_experts): + h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x) + h = self.down_proj[i](h, reduce=False) + out += h * weights[:, i].view(-1, 1) + + return out + + +class SparseMoELayer(nn.Module): + """ + Layer for MoE that uses fused kernels to only apply the active experts + for each token (rather than applying all experts and selecting the + outputs of active experts). + """ + + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + if ( + isinstance(weights.loader, DefaultWeightsLoader) + and isinstance(weights.loader.weight_class, UnquantizedWeight) + ) or isinstance(weights.loader, HybridFP8UnquantLoader): + cls = UnquantizedSparseMoELayer + elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: + cls = GPTQMarlinSparseMoELayer + else: + raise ValueError( + f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" + ) + + log_once( + logger.info, + "Using MoE layer wih fused gemm", + ) + + self.moe = cls( + n_expert_group=n_expert_group, + n_experts=n_experts, + prefix=prefix, + renormalize=renormalize, + topk=topk, + topk_group=topk_group, + weights=weights, + gate_proj_name=gate_proj_name, + up_proj_name=up_proj_name, + down_proj_name=down_proj_name, + ) + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return self.moe(x, gating_output=gating_output) + + @staticmethod + def is_supported(weights: Weights) -> bool: + return ( + ( + isinstance(weights.loader, DefaultWeightsLoader) + and isinstance(weights.loader.weight_class, UnquantizedWeight) + ) + or isinstance(weights.loader, HybridFP8UnquantLoader) + or ( + isinstance(weights.loader, GPTQMarlinWeightsLoader) + and can_use_marlin_moe_gemm( + quant_method=weights.loader.quant_method, + quantize=weights.loader.quantize, + sym=weights.loader.sym, + ) + ) + ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py new file mode 100644 index 000000000..68accb990 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.distributed + + +# TODO: Remove the functions once moe_kernel are built for ROCM +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py b/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py new file mode 100644 index 000000000..3217cdc22 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py @@ -0,0 +1,215 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.nn as nn + +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import Weights +from text_generation_server.layers.marlin.gptq import ( + GPTQMarlinWeight, + GPTQMarlinWeightsLoader, +) + +if SYSTEM == "cuda": + from moe_kernels.fused_marlin_moe import fused_marlin_moe +else: + fused_marlin_moe = None + + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + + +def can_use_marlin_moe_gemm( + *, + quant_method: str, + quantize: str, + sym: bool, +): + return ( + SYSTEM == "cuda" + and fused_marlin_moe is not None + and has_sm_8_0 + and quantize == "gptq" + and quant_method == "gptq" + and sym + ) + + +@dataclass +class GPTQMarlinMoEWeight: + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: torch.Tensor + perm: torch.Tensor + is_full_k: bool + + +class GPTQMarlinSparseMoELayer(nn.Module): + """ + MoE layer that uses a fused GPTQ-Marlin kernel. + """ + + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + if not ( + isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym + ): + raise ValueError( + f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported" + ) + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + self.n_expert_group = n_expert_group + self.topk = topk + self.topk_group = topk_group + self.renormalize = renormalize + + self.gate_up_proj = _load_expert_multi_weights_col( + prefix=prefix, + n_experts=n_experts, + names=[gate_proj_name, up_proj_name], + weights=weights, + ) + + self.down_proj = _load_expert_weights_row( + prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights + ) + + self.bits = weights.loader.bits + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return fused_marlin_moe( + x, + w1=self.gate_up_proj.qweight, + w2=self.down_proj.qweight, + g_idx1=self.gate_up_proj.g_idx, + g_idx2=self.down_proj.g_idx, + perm1=self.gate_up_proj.perm, + perm2=self.down_proj.perm, + w1_scale=self.gate_up_proj.scales, + w2_scale=self.down_proj.scales, + is_full_k1=self.gate_up_proj.is_full_k, + is_full_k2=self.down_proj.is_full_k, + gating_output=gating_output, + topk=self.topk, + renormalize=self.renormalize, + use_grouped_topk=self.n_expert_group is not None, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + num_bits=self.bits, + ) + + +def _load_expert_multi_weights_col( + *, + prefix: str, + n_experts: int, + names: List[str], + weights: Weights, +) -> GPTQMarlinMoEWeight: + moe_weight = None + for i in range(n_experts): + weight = weights.get_multi_weights_col( + [f"{prefix}.{i}.{name}" for name in names], 0 + ) + assert isinstance(weight, GPTQMarlinWeight) + moe_weight = _pack_weight( + n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight + ) + assert moe_weight is not None + return moe_weight + + +def _load_expert_weights_row( + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> GPTQMarlinMoEWeight: + moe_weight = None + for i in range(n_experts): + weight = weights.get_weights_row( + f"{prefix}.{i}.{name}", + ) + assert isinstance(weight, GPTQMarlinWeight) + moe_weight = _pack_weight( + n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight + ) + assert moe_weight is not None + return moe_weight + + +def _pack_weight( + *, + n_experts: int, + expert: int, + moe_weight: Optional[GPTQMarlinMoEWeight], + weight: GPTQMarlinWeight, +) -> GPTQMarlinMoEWeight: + if moe_weight is None: + qweight = torch.empty( + (n_experts,) + weight.qweight.shape, + dtype=weight.qweight.dtype, + device=weight.qweight.device, + ) + qzeros = torch.empty( + (n_experts,) + weight.qzeros.shape, + dtype=weight.qzeros.dtype, + device=weight.qzeros.device, + ) + scales = torch.empty( + (n_experts,) + weight.scales.shape, + dtype=weight.scales.dtype, + device=weight.scales.device, + ) + g_idx = torch.empty( + (n_experts,) + weight.g_idx.shape, + dtype=weight.g_idx.dtype, + device=weight.g_idx.device, + ) + perm = torch.empty( + (n_experts,) + weight.perm.shape, + dtype=weight.perm.dtype, + device=weight.perm.device, + ) + + moe_weight = GPTQMarlinMoEWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + perm=perm, + is_full_k=weight.is_full_k, + ) + + moe_weight.qweight[expert] = weight.qweight + moe_weight.qzeros[expert] = weight.qzeros + moe_weight.scales[expert] = weight.scales + moe_weight.g_idx[expert] = weight.g_idx + moe_weight.perm[expert] = weight.perm + + return moe_weight diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py new file mode 100644 index 000000000..d9d62c0ef --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -0,0 +1,138 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import UnquantizedWeight, Weights + +if SYSTEM == "rocm": + from vllm.model_executor.layers.fused_moe import fused_moe +elif SYSTEM != "ipex": + from moe_kernels.fused_moe import fused_moe + + +class UnquantizedSparseMoELayer(nn.Module): + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + self.n_expert_group = n_expert_group + self.topk = topk + self.topk_group = topk_group + self.renormalize = renormalize + + self.gate_up_proj = _load_expert_multi_weights_col( + prefix=prefix, + n_experts=n_experts, + gate_proj_name=gate_proj_name, + up_proj_name=up_proj_name, + weights=weights, + ) + + self.down_proj = _load_expert_weights_row( + prefix=prefix, + n_experts=n_experts, + name=down_proj_name, + weights=weights, + ) + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + if SYSTEM == "rocm": + return fused_moe( + x, + self.gate_up_proj, + self.down_proj, + gating_output, + self.topk, + renormalize=self.renormalize, + inplace=True, + ) + + return fused_moe( + x, + w1=self.gate_up_proj, + w2=self.down_proj, + gating_output=gating_output, + topk=self.topk, + renormalize=self.renormalize, + inplace=True, + use_grouped_topk=self.n_expert_group is not None, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + + +def _load_expert_multi_weights_col( + *, + prefix: str, + n_experts: int, + gate_proj_name: str, + up_proj_name: str, + weights: Weights, +) -> torch.Tensor: + all_weight = None + for i in range(n_experts): + weight = weights.get_multi_weights_col( + [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 + ) + + assert isinstance(weight, UnquantizedWeight) + + if all_weight is None: + all_weight = torch.empty( + (n_experts,) + weight.weight.shape, + dtype=weight.weight.dtype, + device=weight.weight.device, + ) + + all_weight[i] = weight.weight + + assert all_weight is not None + + return all_weight + + +def _load_expert_weights_row( + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> torch.Tensor: + all_weight = None + for i in range(n_experts): + weight = weights.get_weights_row( + f"{prefix}.{i}.{name}", + ) + + assert isinstance(weight, UnquantizedWeight) + + if all_weight is None: + all_weight = torch.empty( + (n_experts,) + weight.weight.shape, + dtype=weight.weight.dtype, + device=weight.weight.device, + ) + + all_weight[i] = weight.weight + + assert all_weight is not None + + return all_weight diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py new file mode 100644 index 000000000..a2076bb20 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -0,0 +1,548 @@ +import os +import math +import torch +from torch import nn +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "cuda": + import rotary_emb +elif SYSTEM == "rocm": + from vllm._C import ops +elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + +def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + return inv_freq + + +def _get_rope_config(config): + if os.getenv("ROPE_SCALING", None) is not None: + rope_scaling = { + "type": os.environ["ROPE_SCALING"], + "factor": float(os.environ["ROPE_FACTOR"]), + } + return rope_scaling + return getattr(config, "rope_scaling", None) + + +class PositionRotaryEmbedding(nn.Module): + def __init__(self, inv_freq, scaling_factor): + super().__init__() + self.inv_freq = inv_freq + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # Such controlflows may add some overhead. + if SYSTEM == "cuda": + rotary_dim = cos.shape[-1] + q1 = query[..., :rotary_dim] + q2 = query[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., :rotary_dim] + k2 = key[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif SYSTEM == "rocm": + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + ops.rotary_embedding(query, key, head_size, cos, sin, True) + elif SYSTEM == "ipex": + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), True + ) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) + + @classmethod + def static(cls, config, dim, base, device): + inv_freq = _create_inv_freq(dim, base, device) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + # `rope_type` is now standard in transformers, but some existing models + # have `type` instead. + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) + + if rope_type == "linear": + pass + elif rope_type == "dynamic": + scaling_factor = rope_scaling["factor"] + return DynamicPositionRotaryEmbedding( + dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_type == "llama3": + inv_freq = apply_llama3_scaling( + inv_freq, + scaling_factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + ) + + return cls(inv_freq, scaling_factor) + + elif rope_type == "yarn": + scaling_factor = rope_scaling["factor"] + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + base=base, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, + ) + elif rope_type in ["su", "longrope"]: + short_factor = torch.tensor( + rope_scaling["short_factor"], dtype=torch.float32, device=device + ) + short_inv_freq = 1.0 / ( + short_factor + * base + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32) + / dim + ) + ) + long_factor = torch.tensor( + rope_scaling["long_factor"], dtype=torch.float32, device=device + ) + long_inv_freq = 1.0 / ( + long_factor + * base + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32) + / dim + ) + ) + + original_max_position_embeddings = ( + config.original_max_position_embeddings + ) + max_position_embeddings = config.max_position_embeddings + if max_position_embeddings <= original_max_position_embeddings: + scaling_factor = 1.0 + else: + scale = max_position_embeddings / original_max_position_embeddings + scaling_factor = math.sqrt( + 1 + math.log(scale) / math.log(original_max_position_embeddings) + ) + + # if short_mscale and long_mscale are provided we need to scale the freqs + # using the Phi3LongRoPEScaledRotaryEmbedding + if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling): + short_mscale = rope_scaling["short_mscale"] + long_mscale = rope_scaling["long_mscale"] + return Phi3LongRoPEScaledRotaryEmbedding( + short_inv_freq=short_inv_freq, + long_inv_freq=long_inv_freq, + max_position_embeddings=config.max_position_embeddings, + short_mscale=short_mscale, + long_mscale=long_mscale, + original_max_position_embeddings=original_max_position_embeddings, + ) + + return SuRotaryEmbedding( + short_inv_freq=short_inv_freq, + long_inv_freq=long_inv_freq, + scaling_factor=scaling_factor, + original_max_position_embeddings=original_max_position_embeddings, + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + @classmethod + def load(cls, config, prefix, weights): + # XXX: Always load this in float32 ! + dtype = weights.dtype + weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") + weights.dtype = dtype + + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=config.max_position_embeddings, + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_scaling["type"] == "yarn": + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + if SYSTEM == "rocm": + # For RoCm, we always use float cos/sin to avoid a cast. + # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 + # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. + dtype = torch.float32 + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + + # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. + return cos.unsqueeze(1), sin.unsqueeze(1) + + +class SuRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + short_inv_freq, + long_inv_freq, + scaling_factor, + original_max_position_embeddings, + ): + super(PositionRotaryEmbedding, self).__init__() + self.short_inv_freq = short_inv_freq + self.long_inv_freq = long_inv_freq + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.dynamic_args = None + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + + t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype) + short_freqs = torch.outer( + t[: self.original_max_position_embeddings], + self.short_inv_freq.to(device=t.device), + ) + long_freqs = torch.outer( + t[self.original_max_position_embeddings :], + self.long_inv_freq.to(device=t.device), + ) + + freqs = torch.cat([short_freqs, long_freqs]) + + self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype) + + +class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + short_inv_freq: torch.Tensor, + long_inv_freq: torch.Tensor, + max_position_embeddings: int, + short_mscale: float, + long_mscale: float, + original_max_position_embeddings: int, + ): + super(PositionRotaryEmbedding, self).__init__() + self.short_inv_freq = short_inv_freq + self.long_inv_freq = long_inv_freq + self.max_position_embeddings = max_position_embeddings + self.short_mscale = short_mscale + self.long_mscale = long_mscale + self.original_max_position_embeddings = original_max_position_embeddings + + # cache + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.dynamic_args = None + + def _update_cos_sin_cache(self, dtype, device, seqlen): + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype) + + short_freqs = torch.outer( + t[: self.original_max_position_embeddings], + self.short_inv_freq.to(device=t.device), + ) + + long_freqs = torch.outer( + t[self.original_max_position_embeddings :], + self.long_inv_freq.to(device=t.device), + ) + + short_freqs = short_freqs * self.short_mscale + long_freqs = long_freqs * self.long_mscale + + freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device) + freqs[: self.original_max_position_embeddings] = short_freqs + freqs[self.original_max_position_embeddings :] = long_freqs + + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + newbase = self.base * ( + (self.scaling_factor * seqlen / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq( + self.dim, newbase, self.inv_freq.device + ) + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def get_mscale(scale: float = 1.0, mscale: float = 1.0): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings, + base, + device, + scaling_factor, + *, + extrapolation_factor, + attn_factor, + beta_fast, + beta_slow, + mscale: float, + mscale_all_dim: float, + ): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale_all_dim = mscale_all_dim + self.scaling_factor = scaling_factor + self.mscale = float( + get_mscale(self.scaling_factor, mscale) + / get_mscale(self.scaling_factor, mscale_all_dim) + * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings or True: + inv_freq_extrapolation = _create_inv_freq( + self.dim, self.base, self.inv_freq.device + ) + freqs = 1.0 / inv_freq_extrapolation + inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) + low, high = find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.max_position_embeddings, + ) + + inv_freq_mask = ( + 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + + self.inv_freq = inv_freq + + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) + + +def apply_llama3_scaling( + freqs: torch.Tensor, + *, + scaling_factor: int, + low_freq_factor: int, + high_freq_factor: int, + original_max_position_embeddings: int, +): + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + new_freqs = [] + + for freq in freqs: + wavelen = 2 * math.pi / freq + + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scaling_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) + + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) diff --git a/backends/gaudi/server/text_generation_server/layers/speculative.py b/backends/gaudi/server/text_generation_server/layers/speculative.py new file mode 100644 index 000000000..cf8469b53 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/speculative.py @@ -0,0 +1,52 @@ +import torch +import json +from typing import Tuple, Optional +from text_generation_server.layers.tensor_parallel import TensorParallelHead +from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 +from text_generation_server.layers.mlp import MLPSpeculatorHead + + +class SpeculativeHead(torch.nn.Module): + def __init__(self, lm_head, speculator): + super().__init__() + self.head = lm_head + self.speculator = speculator + + @staticmethod + def load(config, prefix: str, weights): + speculator = config.speculator + if speculator: + speculator_path = config.speculator["path"] + speculator_config = str(speculator_path / "config.json") + + with open(speculator_config, "r") as f: + speculator_config = json.load(f) + + config.speculator_config = speculator_config + try: + architecture = speculator_config["architectures"][0] + + if architecture == "MLPSpeculatorPreTrainedModel": + speculator = MLPSpeculatorHead.load(config, prefix, weights) + else: + speculator = None + except KeyError: + try: + speculator = MedusaHeadV1.load(config, prefix, weights) + except Exception: + speculator = MedusaHeadV2(config, prefix, weights) + lm_head = None + else: + lm_head = TensorParallelHead.load(config, prefix, weights) + speculator = None + return SpeculativeHead(lm_head, speculator) + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.speculator is not None: + return self.speculator(input) + + assert self.head is not None + logits = self.head(input) + return logits, None diff --git a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py new file mode 100644 index 000000000..13f12ef1e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py @@ -0,0 +1,249 @@ +import torch +from torch.nn import functional as F +from typing import Iterable, List +from text_generation_server.layers.linear import get_linear, FastLinear +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + +class LayerConcat(torch.nn.Module): + """ + Apply multiple layers to the input and concatenate their + outputs. + """ + + def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1): + """ + `dim` is the dimension along which layer outputs are concatenated. + """ + super().__init__() + self.layers = layers + self.dim = dim + + def forward(self, x: torch.Tensor): + outputs = [layer(x) for layer in self.layers] + return torch.cat(outputs, self.dim) + + +class SuperLayer(torch.nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group, should_gather: bool): + super().__init__(linear) + self.process_group = process_group + self.should_gather = should_gather + + @staticmethod + def load(config, prefix: str, weights): + if config.quantize == "exl2": + try: + # If the piece and LM head embeddings are shared, we have + # non-quantized weights... + weight = weights.get_tensor(f"{prefix}.weight") + except Exception: + # ...otherwise they are quantized. + weight = weights.get_weights_col(prefix) + should_gather = weights.process_group.size() > 1 + elif weights.process_group.size() > 1: + try: + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + + return TensorParallelHead( + get_linear(weight, bias=None), + process_group=weights.process_group, + should_gather=should_gather, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.should_gather: + return super().forward(input) + + world_size = self.process_group.size() + if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + out_dim = self.linear.weight.shape[0] + + if input.shape[0] == 1: + world_out = input.new_empty(1, out_dim * world_size) + local_out = input.new_empty(1, out_dim) + gather_input = local_out + else: + world_out = input.new_empty(out_dim * world_size, input.shape[0]) + gather_input = input.new_empty(out_dim, input.shape[0]) + local_out = gather_input.T + + torch.mm(input, self.linear.weight.T, out=local_out) + if SYSTEM == "ipex": + ipex.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + else: + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + + if input.shape[0] == 1: + return world_out + return world_out.T + + output = super().forward(input) + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + if SYSTEM == "ipex": + ipex.distributed.all_gather(world_output, output, group=self.process_group) + else: + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load_gate_up(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_gate_up(prefix) + if bias: + raise NotImplementedError("packed_gate_up only implemented without bias") + else: + bias = None + linear = get_linear(weight, bias) + return cls(linear) + + @classmethod + def load_qkv( + cls, + config, + prefix: str, + weights, + bias: bool, + num_heads: int, + num_key_value_heads: int, + ): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_qkv( + prefix, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + ) + if bias: + raise NotImplementedError("packed_qkv only implemented for baichuan") + else: + bias = None + linear = get_linear(weight, bias) + return cls(linear) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_weights_col(prefix) + if bias: + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + else: + bias = None + linear = get_linear(weight, bias) + return cls(linear) + + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + if config.quantize == "exl2": + linears = [] + for prefix in prefixes: + weight = weights.get_weights_col(prefix) + b = weights.get_tensor(f"{prefix}.bias") if bias else None + linears.append(get_linear(weight, b)) + linear = LayerConcat(linears) + else: + weight = weights.get_multi_weights_col(prefixes, dim=dim) + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias) + return cls(linear) + + +class TensorParallelRowLinear(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) + self.process_group = process_group + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_weights_row(prefix) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls( + get_linear(weight, bias), + process_group=weights.process_group, + ) + + def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: + out = super().forward(input) + if self.process_group.size() > 1 and reduce: + if SYSTEM == "ipex": + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class TensorParallelEmbedding(torch.nn.Module): + def __init__(self, prefix: str, weights, reduce=True): + super().__init__() + weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) + num_embeddings = weights.get_shape(f"{prefix}.weight")[0] + + process_group = weights.process_group + + world_size = process_group.size() + rank = process_group.rank() + + block_size = (num_embeddings + world_size - 1) // world_size + self.min_id = rank * block_size + self.max_id = min(num_embeddings, (rank + 1) * block_size) + self.null_idx = weight.shape[ + 0 + ] # Usually block_size, might be less in non even vocab_size. + self.process_group = weights.process_group + self.reduce = reduce + + """Additional 0 entry used for masking""" + self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = torch.nn.functional.embedding(input, self.weight) + if self.reduce and self.process_group.size() > 1: + if SYSTEM == "ipex": + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) + return out diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py new file mode 100644 index 000000000..dc2cb21cb --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -0,0 +1,323 @@ +import torch +import os + +from loguru import logger +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import modeling_auto +from huggingface_hub import hf_hub_download, HfApi +from typing import Optional +from pathlib import Path +from typing import Optional, List, Dict +# Needed to properly setup habana_frameworks +import text_generation_server.habana_quantization_env as hq_env + +from text_generation_server.utils.speculate import get_speculate, set_speculate +from text_generation_server.models.model import Model +from text_generation_server.models.causal_lm import CausalLM +from text_generation_server.models.bloom import BLOOM +from text_generation_server.models.starcoder import StarCoder +from text_generation_server.models.vlm_causal_lm import VlmCausalLM +#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM +from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, +) +# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch +# from text_generation_server.models.custom_modeling.mllama import ( +# MllamaForConditionalGeneration, +# ) +from text_generation_server.utils.adapter import ( + AdapterParameters, + build_layer_weight_lookup, + load_and_merge_adapters, + AdapterInfo, +) + + +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + + +# Disable gradients +torch.set_grad_enabled(False) + + +def get_model( + model_id: str, + lora_adapter_ids: Optional[List[str]], + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + speculate: Optional[int], + dtype: Optional[torch.dtype], + trust_remote_code: bool, + max_input_tokens: int, +) -> Model: + adapt_transformers_to_gaudi() + + if speculate is not None: + set_speculate(speculate) + else: + set_speculate(0) + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + model_type = config_dict.get("model_type", None) + + speculator = None + if "medusa_num_heads" in config_dict: + medusa_model_id = model_id + medusa_revision = revision + model_id = config_dict["base_model_name_or_path"] + revision = "main" + speculate_medusa = config_dict["medusa_num_heads"] + if speculate is not None: + if speculate > speculate_medusa: + raise RuntimeError( + f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match" + ) + else: + set_speculate(speculate) + else: + set_speculate(speculate_medusa) + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + # Reload model type from parent. + model_type = config_dict.get("model_type", None) + is_local = Path(medusa_model_id).exists() + if not is_local: + medusa_config = hf_hub_download( + medusa_model_id, revision=medusa_revision, filename="config.json" + ) + hf_hub_download( + medusa_model_id, + revision=medusa_revision, + filename="medusa_lm_head.safetensors", + ) + speculator = { + "path": Path(medusa_config).parent, + "model_paths": ["medusa_lm_head.safetensors"], + } + else: + speculator = { + "path": Path(medusa_model_id), + "model_paths": ["medusa_lm_head.safetensors"], + } + + method = "medusa" + elif model_type == "mlp_speculator": + mlp_model_id = model_id + mlp_revision = revision + model_id = config_dict["base_model_name_or_path"] + revision = "main" + speculate_mlp = config_dict["n_predict"] + if speculate is not None: + if speculate > speculate_mlp: + raise RuntimeError( + f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match" + ) + else: + set_speculate(speculate) + else: + set_speculate(speculate_mlp) + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + # Reload model type from parent. + model_type = config_dict.get("model_type", None) + is_local = Path(mlp_model_id).exists() + extension = ".safetensors" + if not is_local: + mlp_speculator_config = hf_hub_download( + mlp_model_id, revision=mlp_revision, filename="config.json" + ) + api = HfApi() + info = api.model_info(mlp_model_id, revision=mlp_revision) + filenames = [ + s.rfilename + for s in info.siblings + if s.rfilename.endswith(extension) + and len(s.rfilename.split("/")) == 1 + and "arguments" not in s.rfilename + and "args" not in s.rfilename + and "training" not in s.rfilename + ] + for filename in filenames: + hf_hub_download( + mlp_model_id, + revision=mlp_revision, + filename=filename, + ) + speculator_dir_path = Path(mlp_speculator_config).parent + # if these are downloaded, they get converted to safetensors + filenames.extend( + [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)] + ) + speculator = { + "path": Path(mlp_speculator_config).parent, + "model_paths": filenames, + } + else: + speculator = Path(mlp_model_id) + filenames = [p for p in os.listdir(speculator) if p.endswith(extension)] + speculator = {"path": speculator, "model_paths": filenames} + method = "mlp_speculator" + else: + method = "n-gram" + + speculate = get_speculate() + if speculate > 0: + logger.info(f"Using speculation {method} with {speculate} input ids.") + + model_type = config_dict["model_type"] + + if model_type == "gpt_bigcode": + return StarCoder(model_id=model_id, revision=revision, dtype=dtype) + + if model_type == "bloom": + return BLOOM( + model_id=model_id, + revision=revision, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == "llava_next": + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=None, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + return CausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + raise ValueError(f"Unsupported model type {model_type}") + + +# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters +# this provides a post model loading hook to load adapters into the model after the model has been loaded +def get_model_with_lora_adapters( + model_id: str, + lora_adapters: Optional[List[AdapterInfo]], + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + speculate: Optional[int], + dtype: Optional[torch.dtype], + trust_remote_code: bool, + max_input_tokens: int, + adapter_to_index: Dict[str, int], +): + lora_adapter_ids = [adapter.id for adapter in lora_adapters] + model = get_model( + model_id, + lora_adapter_ids, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + max_input_tokens, + ) + + if len(lora_adapters) > 0: + target_to_layer = build_layer_weight_lookup(model.model) + + for index, adapter in enumerate(lora_adapters): + # The AdapterParameters object allows for merging multiple adapters into a single adapter. + # At the moment, we only support loading a single adapter into the model, but we keep the + # AdapterParameters object for easier extension in the future. + adapter_parameters = AdapterParameters( + adapter_info=[adapter], + # when merging multiple adapters we can weight them differently + # if this is not set, all adapters will be weighted equally + # see: text_generation_server.utils.merges.strategies for impl + weights=None, + merge_strategy=0, + density=1.0, + majority_sign_method=0, + ) + + adapter_index = index + 1 + adapter_to_index[adapter.id] = adapter_index + + logger.info( + f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}" + ) + weight_names = tuple([v[0] for v in target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + model.model_id, + adapter_parameters, + adapter_index, + weight_names, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + + adapter_layers = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + "qkv_proj", + ] + + for layer_name in adapter_layers: + nlayers = ( + 1 if layer_name == "lm_head" else len(model.model.model.layers) + ) + adapter_weights = LoraWeights.prepare_weights( + config=adapter_config, + module_map=module_map, + layer_type=layer_name, + unused_weight_names=unused_weight_names, + nlayers=nlayers, + dtype=model.dtype, + world_size=model.world_size, + process_group=model.process_group, + target_to_layer=target_to_layer, + ) + + if adapter_weights is None: + continue + + model.layer_to_adapter_weights[layer_name].add_adapter( + adapter_index, adapter_weights + ) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + model.loaded_adapters.add(adapter_index) + + return model diff --git a/backends/gaudi/server/text_generation_server/models/bloom.py b/backends/gaudi/server/text_generation_server/models/bloom.py new file mode 100644 index 000000000..6fe643748 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/bloom.py @@ -0,0 +1,52 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import torch + +from typing import Optional, Type + +from transformers import PreTrainedTokenizerBase + +from text_generation_server.models import CausalLM +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.pb import generate_pb2 + + +class BloomCausalLMBatch(CausalLMBatch): + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "CausalLMBatch": + batch = super().from_pb( + pb=pb, + tokenizer=tokenizer, + dtype=dtype, + device=device, + ) + batch.keys_head_dim_last = False + return batch + + +class BLOOM(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + super(BLOOM, self).__init__( + model_id=model_id, + revision=revision, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return BloomCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py new file mode 100644 index 000000000..5b105224e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py @@ -0,0 +1,1399 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import bisect +from dataclasses import dataclass +from functools import wraps +import itertools +import math +import os +import tempfile +import time +import copy +from typing import Dict, List, Optional, Tuple, Type + +import torch +import torch._dynamo +from loguru import logger +from opentelemetry import trace + +import text_generation_server.habana_quantization_env as hq_env +import habana_frameworks.torch as htorch +from optimum.habana.utils import HabanaProfile +from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES +from text_generation_server.utils.chunks import concat_text_chunks +from optimum.habana.checkpoint_utils import ( + get_repo_root, + model_on_meta, + write_checkpoints_json, +) +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizerBase, + AutoConfig, +) + +from text_generation_server.utils.tokens import batch_top_tokens +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + HeterogeneousNextTokenChooser, + StoppingCriteria, + is_tokenizer_transparent, + pad_next_token_chooser_parameters, +) +from optimum.habana.utils import get_hpu_memory_stats +from text_generation_server.utils.debug import dbg_trace +from text_generation_server.utils.speculate import get_speculate + +tracer = trace.get_tracer(__name__) +MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048)) +PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256)) +CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) +BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8)) +PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2)) + + +def torch_compile_for_eager(func): + if LAZY_MODE == 1: + return func + return torch.compile( + func, backend="hpu_backend", options={"keep_input_mutations": True} + ) + + +def round_up(number, k): + return (number + k - 1) // k * k + + +def to_tensor_indices(indices, device): + return torch.tensor(indices, dtype=torch.long, device=device) + + +def calculate_chunks(offset): + result = [] + while offset != 0: + sign = 1 if offset > 0 else -1 + best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1] + result.append(best_chunk) + offset = offset - best_chunk + return result + + +def biggest_single_chunk(offset): + if offset != 0: + idx = bisect.bisect(CHUNK_SIZES, abs(offset)) + return int(math.copysign(CHUNK_SIZES[idx - 1], offset)) + else: + return 0 + + +@torch_compile_for_eager +def grouped_pad(tensor_groups, dims, values): + grouped_result = [] + for tensors, dim, value in zip(tensor_groups, dims, values): + padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0 + if padding > 0: + assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}" + pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding) + result = [ + torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors + ] + else: + result = [t for t in tensors] + grouped_result.append(result) + htorch.core.mark_step() + return grouped_result + + +@torch_compile_for_eager +def roll(tensor, chunk, dim, merge_graphs): + if dim is None: + return tensor + tensor = torch.roll(tensor, chunk, dim) + if not merge_graphs: + htorch.core.mark_step() + return tensor + + +def grouped_roll(tensor_groups, chunk, dims, merge_graphs): + tensor_groups = [ + [roll(t, chunk, dim, merge_graphs) for t in tensors] + for tensors, dim in zip(tensor_groups, dims) + ] + if merge_graphs: + htorch.core.mark_step() + return tensor_groups + + +@torch_compile_for_eager +def grouped_shift(tensor_groups, dims, offset, merge_graphs): + chunks = calculate_chunks(offset) + for c in chunks: + tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs) + return tensor_groups + + +def move(dst_tensors, dst_indices, src_tensors): + bs_dim = 0 + num_indices = dst_indices.size(0) + for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)): + if src_t.size(bs_dim) != num_indices: + src_t = torch.narrow(src_t, bs_dim, 0, num_indices) + dst_t.index_copy_(bs_dim, dst_indices, src_t) + htorch.core.mark_step() + + +def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups): + for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups): + move(dst_tensors, dst_indices, src_tensors) + + +@torch_compile_for_eager +def extend_tensor(tensor, padding, dim): + result = torch.cat([tensor, padding], dim=dim) + htorch.core.mark_step() + return result + + +@torch_compile_for_eager +def extend_batch(tensors, target_bs, dim): + diff = target_bs - tensors[0].size(dim) + # TODO: add support for shrinking bs + if diff <= 0: + return tensors + shape = list(tensors[0].shape) + shape[dim] = diff + padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) + tensors = [extend_tensor(t, padding, dim) for t in tensors] + return tensors + + +def grouped_extend_batch(tensor_groups, target_bs, bs_dims): + tensor_groups = [ + extend_batch(tensors, target_bs, dim) + for tensors, dim in zip(tensor_groups, bs_dims) + ] + return tensor_groups + + +@torch_compile_for_eager +def merge(tensor_group): + tensor_group = [torch.stack(tensor_group)] + htorch.core.mark_step() + return tensor_group + + +@torch_compile_for_eager +def split(tensor_group, clone_data): + tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)] + if clone_data: + tensor_group = [t.clone() for t in tensor_group] + htorch.core.mark_step() + return tensor_group + + +def remove_kv_cache_from_output(module): + orig_fwd = module.forward + + @wraps(orig_fwd) + def forward(*args, **kwargs): + if kwargs["past_key_values"] is not None: + kwargs["return_dict"] = False + output = orig_fwd(*args, **kwargs) + first_value, second_value, *_ = output + if first_value.nelement() < 2: + return second_value + else: + return first_value + else: + kwargs["return_dict"] = True + return orig_fwd(*args, **kwargs) + + module.forward = forward + return module + + +@dataclass +class CausalLMRequest: + idx: int + data: generate_pb2.Request + input_length: int + prefix_offset: int + read_offset: int + stopping_criteria: StoppingCriteria + + all_input_ids: torch.Tensor + + @classmethod + def from_pb( + cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase + ): + return cls( + idx=idx, + data=data, + input_length=None, + prefix_offset=None, + read_offset=None, + stopping_criteria=StoppingCriteria.from_pb( + data.stopping_parameters, tokenizer + ), + all_input_ids=None, + ) + + def update_idx(self, new_idx): + prev = self.idx + self.idx = new_idx + return (new_idx, prev) + + +@dataclass +class CausalLMBatch(Batch): + batch_id: int + requests: List[CausalLMRequest] + + # Decoder values + input_ids: torch.Tensor + attention_mask: torch.Tensor + position_ids: torch.Tensor + past_key_values: Optional[List[Tuple]] + merged_kv_cache: bool + + # Lengths of all generations present in the batch + input_length: int + + # Generation helpers + next_token_chooser: HeterogeneousNextTokenChooser + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor + + input_length: int + + # Past metadata + logits = None + past = None + + keys_head_dim_last: bool = True + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.data.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + def detach_kv_cache(self): + past_keys = [past[0] for past in self.past_key_values] + past_values = [past[1] for past in self.past_key_values] + del self.past_key_values + return past_keys, past_values + + def attach_kv_cache(self, past_keys, past_values): + # TODO: Add support for models that don't store kv_cache in a list + self.past_key_values = list(zip(past_keys, past_values)) + + def merge_kv_cache_if_needed(self, target_bs, offset): + pad_needed = self.seq_length < MAX_TOTAL_TOKENS + shift_needed = offset != 0 + expand_needed = target_bs > self.batch_size + # Very simple heuristic to determine whether we should merge tensors + # this needs tuning for other models/scenarios + small_bs = len(self.past_key_values) > self.batch_size + if ( + not self.merged_kv_cache + and small_bs + and (pad_needed or shift_needed or expand_needed) + ): + past_keys, past_values = self.detach_kv_cache() + past_keys = merge(past_keys) + past_values = merge(past_values) + self.attach_kv_cache(past_keys, past_values) + self.merged_kv_cache = True + + def split_kv_cache_if_needed(self, clone_data): + if self.merged_kv_cache: + past_keys, past_values = self.detach_kv_cache() + past_keys = split(past_keys, clone_data) + past_values = split(past_values, clone_data) + self.attach_kv_cache(past_keys, past_values) + self.merged_kv_cache = False + + def get_tensor_groups(self): + past_keys, past_values = self.detach_kv_cache() + seq_dim = -1 + key_dim = -2 if self.keys_head_dim_last else -1 + value_dim = -2 + tensors = [ + [self.input_ids], + [self.attention_mask], + [self.position_ids], + past_keys, + past_values, + ] + # We don't need to align position_ids + seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim] + bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0]) + return tensors, seq_dims, bs_dims + + def set_tensor_groups(self, tensors): + self.input_ids = tensors.pop(0)[0] + self.attention_mask = tensors.pop(0)[0] + self.position_ids = tensors.pop(0)[0] + past_keys = tensors.pop(0) + past_values = tensors.pop(0) + self.attach_kv_cache(past_keys, past_values) + + def realign(self, target_bs, offset, pad_token_id): + tensors, seq_dims, _ = self.get_tensor_groups() + tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0]) + tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache) + self.set_tensor_groups(tensors) + + def expand_bs(self, target_bs): + tensors, _, bs_dims = self.get_tensor_groups() + tensors = grouped_extend_batch(tensors, target_bs, bs_dims) + self.set_tensor_groups(tensors) + + def used_indices(self): + return [req.idx for req in self.requests] + + def update_indices(self, new_indices): + for req, new_idx in zip(self.requests, new_indices): + req.idx = new_idx + return self.used_indices() + + def free_indices_generator(self): + used = set(req.idx for req in self.requests) + return (i for i in range(self.batch_size) if i not in used) + + def move_data(self, src_batches): + dst_tensors, _, dst_dims = self.get_tensor_groups() + free_indices_gen = self.free_indices_generator() + for src_b in src_batches: + dst_indices = to_tensor_indices( + src_b.update_indices(free_indices_gen), self.input_ids.device + ) + src_tensors, _, src_dims = src_b.get_tensor_groups() + grouped_move(dst_tensors, dst_indices, src_tensors) + self.set_tensor_groups(dst_tensors) + + @classmethod + def recombine( + cls, batches: List["CausalLMBatch"], pad_token_id: int + ) -> "CausalLMBatch": + if not all(b.past_key_values is not None for b in batches): + raise ValueError("KV cache not allocated! Cannot recombine before prefill!") + + total_requests = sum(len(b) for b in batches) + new_bs = total_requests + new_bs = round_up(total_requests, BATCH_BUCKET_SIZE) + + batch_id = batches[0].batch_id + device = batches[0].input_ids.device + + input_lengths = [b.input_length for b in batches] + max_input_length = max(input_lengths) + offsets = [max_input_length - b.input_length for b in batches] + + cur_padding = [b.right_padding for b in batches] + # For prefill there is a space allocated only for first token + # Need to add padding to the max total tokens before first decode + + moves_needed = [ + total_requests - len(b) if b.batch_size == new_bs else total_requests + for b in batches + ] + dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] + reshape = batches[dst_batch_idx].batch_size < new_bs + + # TODO: Add support for changing max seq len, i.e. due to output length bucketing + # FIXME: max_seq_len for non optimized code + if len(batches) > 1: + scenario = "CONCAT" + elif reshape: + scenario = "RESHAPE" + elif cur_padding[dst_batch_idx] <= 0: + scenario = "SHIFT" + offsets = [ + biggest_single_chunk(b.max_input_length - max_input_length) + for b in batches + ] + max_input_length = max_input_length + offsets[dst_batch_idx] + else: + # Nothing to do + return batches[0] + + dbg_trace( + scenario, + f"bs:{[b.batch_size for b in batches]}->{new_bs}" + f" reqs:{[len(b) for b in batches]}" + f" offsets:{offsets}" + f" input_lengths:{input_lengths}" + f" cur_padding:{cur_padding}" + f" dst_batch:{dst_batch_idx}", + ) + + grouped_requests = [[req for req in batch.requests] for batch in batches] + flat_requests = list(itertools.chain(*grouped_requests)) + + for i in range(len(batches)): + target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size + batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) + batches[i].realign(target_bs, offsets[i], pad_token_id) + batches[i].split_kv_cache_if_needed(i == dst_batch_idx) + batches[dst_batch_idx].expand_bs(new_bs) + batches[dst_batch_idx].move_data( + [batches[i] for i in range(len(batches)) if i != dst_batch_idx] + ) + + top_n_tokens = [r.data.top_n_tokens for r in flat_requests] + top_n_tokens.extend([-1] * (new_bs - total_requests)) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + + parameters = [r.data.parameters for r in flat_requests] + # append the dummy parameters for dummy requests + batch_size = batches[dst_batch_idx].batch_size + parameters = pad_next_token_chooser_parameters(parameters, batch_size) + + # update past grammar states + fsm_grammar_states = [0] * batch_size + for batch in batches: + for i, req in enumerate(batch.requests): + fsm_grammar_states[req.idx] = ( + batch.next_token_chooser.fsm_grammar_states[i] + ) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + parameters, + batches[dst_batch_idx].next_token_chooser.dtype, + batches[dst_batch_idx].next_token_chooser.device, + batches[dst_batch_idx].next_token_chooser.tokenizer, + fsm_grammar_states, + quantization_enabled=hq_env.is_quantization_enabled, + ) + + input_ids = batches[dst_batch_idx].input_ids + attention_mask = batches[dst_batch_idx].attention_mask + position_ids = batches[dst_batch_idx].position_ids + past_key_values = batches[dst_batch_idx].past_key_values + input_length = max_input_length + + htorch.core.mark_step() + + return cls( + batch_id=batch_id, + requests=flat_requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_length, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "CausalLMBatch": + dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") + requests = [ + CausalLMRequest.from_pb(idx, req, tokenizer) + for idx, req in enumerate(pb.requests) + ] + inputs = [] + top_n_tokens = [] + + # Parse batch + max_truncation = 0 + for i, r in enumerate(pb.requests): + inputs.append(concat_text_chunks(r.input_chunks.chunks)) + top_n_tokens.append(r.top_n_tokens) + max_truncation = max(max_truncation, r.truncate) + + max_input_length = max_truncation + if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF: + max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF + max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) + + # TODO: by tokenizing all inputs at once we loose information on actual input lengths + # this means that we cannot shift inputs to the left after a long input sequence + # was filtered out + new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE) + missing_inputs = new_bs - len(inputs) + dummy_inputs = ["?"] * missing_inputs + parameters = [r.parameters for r in pb.requests] + # append the dummy parameters for dummy request + parameters = pad_next_token_chooser_parameters(parameters, new_bs) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + pb=parameters, + dtype=dtype, + device=device, + tokenizer=tokenizer, + quantization_enabled=hq_env.is_quantization_enabled, + ) + + tokenized_inputs = tokenizer( + inputs + dummy_inputs, + return_tensors="pt", + padding="longest", + return_token_type_ids=False, + truncation=True, + max_length=max_truncation, + ) + + input_len = tokenized_inputs["input_ids"].shape[1] + # Round up sequence length + bucket_size = max_input_length + left_padding = max_input_length - input_len + if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: + assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, ( + "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" + ) + rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) + if rounded_seq_len <= max_input_length: + bucket_size = rounded_seq_len - 1 + else: + bucket_size = max_input_length - 1 + left_padding = bucket_size - input_len + + input_ids = tokenized_inputs["input_ids"] + attention_mask = tokenized_inputs["attention_mask"] + + # Allocate space for first token + input_ids = torch.nn.functional.pad( + input_ids, (left_padding, 1), value=tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (left_padding, 1), value=0 + ) + all_input_ids = torch.nn.functional.pad( + input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id + ).T.split(1, dim=1) + input_len = bucket_size + for r in requests: + r.input_length = input_len + r.prefix_offset = input_len - 5 + r.read_offset = input_len + r.all_input_ids = all_input_ids[r.idx] + + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + old_bs = len(requests) + top_n_tokens.extend([-1] * (new_bs - old_bs)) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + htorch.core.mark_step() + return cls( + batch_id=pb.id, + requests=requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_len, + ) + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: + dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}") + request_ids = set(request_ids) + self.requests = [req for req in self.requests if req.data.id in request_ids] + return self + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate( + cls, batches: List["CausalLMBatch"], pad_token_id: int = 0 + ) -> "CausalLMBatch": + return cls.recombine(batches, pad_token_id) + + def __len__(self): + return len(self.requests) + + @property + def max_input_length(self): + return max(req.input_length for req in self.requests) + + @property + def batch_size(self): + return self.attention_mask.size(0) + + @property + def seq_length(self): + return self.attention_mask.size(1) + + @property + def right_padding(self): + return self.seq_length - self.input_length + + # Maximum number of tokens this batch will grow to + @property + def max_tokens(self): + max_total_tokens = self.attention_mask.size(1) + return len(self.requests) * max_total_tokens + + +class CausalLM(Model): + def __init__( + self, + model_id: str, + model_class: Optional[Type[torch.nn.Module]] = None, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + config_class=AutoConfig, + batch_class=CausalLMBatch, + ): + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") + + self.prev_bs = 0 + self.quantize = quantize + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + # Create model + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("hpu") + + if hq_env.is_quantization_enabled: + htorch.core.hpu_set_env() + + if world_size > 1: + model = self.get_deepspeed_model(model_id, dtype, revision) + model = hq_env.prepare_model_for_quantization(model) + else: + get_repo_root(model_id) + + # Check support for rope scaling + model_kwargs = {} + config = AutoConfig.from_pretrained(model_id) + if hasattr(config, "rope_scaling"): + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + model = AutoModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + model = hq_env.prepare_model_for_quantization(model) + model = model.eval().to(device) + + self.enable_hpu_graph = ( + os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 + ) + self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + + if model.config.model_type not in [ + "gpt_bigcode" + ]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output() + model = remove_kv_cache_from_output(model) + + if self.enable_hpu_graph: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) + else: + if LAZY_MODE == 0: + # It is said that "keep_input_mutations" is safe for inference to be done + dbg_trace("TORCH COMPILE", f"Torch compiling of model") + model.model = torch.compile( + model.model, + backend="hpu_backend", + options={"keep_input_mutations": True}, + ) + + model = hq_env.setup_quantization(model) + + if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: + raise ValueError(f"Model type {model.config.model_type} is not supported!") + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + if isinstance(model.config.eos_token_id, int): + tokenizer.pad_token_id = model.config.eos_token_id + elif isinstance(model.config.eos_token_id, list): + tokenizer.pad_token_id = model.config.eos_token_id[0] + else: + raise ValueError( + f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" + ) + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + self.kwargs = { + "use_cache": True, + "return_dict": True, + } + + if model.config.model_type in [ + "llama", + "mistral", + "starcoder2", + "qwen2", + "falcon", + "gpt_bigcode", + ]: + if model.config.model_type not in ["falcon", "gpt_bigcode"]: + self.kwargs["attn_softmax_bf16"] = True + + if model.config.model_type not in ["gpt_bigcode"]: + self.kwargs["trim_logits"] = True + + if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": + self.kwargs["use_flash_attention"] = True + if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true": + self.kwargs["flash_attention_recompute"] = True + + self.speculate = get_speculate() + + super(CausalLM, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + ) + + # Create profiler + ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] + record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" + output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") + self.profiling_warmup_steps = ( + int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 + ) + self.profiling_steps = ( + int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 + ) + self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) + if self.profiling_steps > 0: + self.hb_profiler = HabanaProfile( + wait=self.profiling_wait_steps, + warmup=self.profiling_warmup_steps, + active=self.profiling_steps, + output_dir=output_dir, + record_shapes=record_shapes, + ) + self.hb_profiler.start() + else: + self.hb_profiler = None + self.step = 0 + + def get_deepspeed_model( + self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None + ) -> torch.nn.Module: + import deepspeed + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + + world_size, rank, local_rank = initialize_distributed_hpu() + model_kwargs = {"revision": revision} + + # Initialize process(es) for DeepSpeed + deepspeed.init_distributed(dist_backend="hccl") + logger.info( + "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format( + world_size, rank, local_rank + ) + ) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + load_to_meta = model_on_meta(config) + + # Check support for rope scaling + if hasattr(config, "rope_scaling"): + config.rope_scaling = self.get_rope_scaling() + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + if load_to_meta: + # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load + with deepspeed.OnDevice(dtype=dtype, device="meta"): + model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) + else: + get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) + # TODO: revisit placement on CPU when auto-injection is possible + with deepspeed.OnDevice(dtype=dtype, device="cpu"): + model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=dtype, **model_kwargs + ) + model = model.eval() + + # Initialize the model + ds_inference_kwargs = {"dtype": dtype} + ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} + ds_inference_kwargs["enable_cuda_graph"] = False + + if load_to_meta: + # model loaded to meta is managed differently + checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") + write_checkpoints_json(model_id, local_rank, checkpoints_json) + ds_inference_kwargs["checkpoint"] = checkpoints_json.name + model = deepspeed.init_inference(model, **ds_inference_kwargs) + + return model.module + + def get_rope_scaling(self) -> Optional[Dict]: + rope_scaling = os.getenv("ROPE_SCALING", None) + if rope_scaling is None: + return None + + rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) + return {"type": rope_scaling, "factor": float(rope_factor)} + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return CausalLMBatch + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + def decode_token( + self, + all_input_ids: List[int], + prefix_offset: int = 0, + read_offset: int = 0, + ) -> Tuple[str, int, int]: + if is_tokenizer_transparent(self.tokenizer): + new_text = self.tokenizer.decode( + all_input_ids[read_offset:], skip_special_tokens=False + ) + return new_text, read_offset, len(all_input_ids) + else: + return super().decode_token(all_input_ids, prefix_offset, read_offset) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + token_idx, + past_key_values: Optional[List[Tuple]] = None, + bypass_hpu_graph: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "token_idx": token_idx, + } + + # Optimum Habana got "lazy_mode" key-val only supported for llama type of models + if self.model.config.model_type == "llama": + kwargs["lazy_mode"] = LAZY_MODE == 1 + + if self.has_position_ids: + kwargs["position_ids"] = position_ids + + if bypass_hpu_graph != None: + kwargs["bypass_hpu_graphs"] = bypass_hpu_graph + + kwargs.update(self.kwargs) + + if past_key_values is not None and self.model.config.model_type not in [ + "gpt_bigcode" + ]: + return self.model.forward(**kwargs) + else: + outputs = self.model.forward(**kwargs) + return outputs.logits, outputs.past_key_values + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batches: List[CausalLMBatch] + ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + # Results + generations: List[Generation] = [] + prev_batches = [] + requests_to_generate = [] + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. Collect next token ids of any previously started generations + for batch_id, batch in enumerate(batches): + if batch.logits is not None: + logits = batch.logits + past = batch.past + prefill = batch.past_key_values is None + if prefill: + # no right padding for prefill + token_idx_scalar = batch.attention_mask.shape[-1] - 1 + token_idx = torch.tensor(token_idx_scalar).to(self.device) + else: + token_idx_scalar = ( + batch.attention_mask.shape[-1] - batch.right_padding + ) + token_idx = torch.tensor(token_idx_scalar).to(self.device) + + # Select next token + input_length = batch.input_length + if logits.shape[-2] > 1: + next_token_ids, next_token_logprobs, logprobs, _, _ = ( + batch.next_token_chooser( + batch.input_ids, + logits[:, input_length - 1 : input_length, :].squeeze(-2), + self.speculate, + ) + ) + else: + next_token_ids, next_token_logprobs, logprobs, _, _ = ( + batch.next_token_chooser( + batch.input_ids, logits.squeeze(-2), self.speculate + ) + ) + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) + + prev_batches.append( + { + "next_token_ids": next_token_ids, + "next_token_logprobs": next_token_logprobs, + } + ) + + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append( + { + "req": req, + "prev_req_idx": req.idx, + "batch_id": batch_id, + "seed": batch.next_token_chooser.seeds[req_idx], + "do_sample": batch.next_token_chooser.do_sample[req_idx], + "top_n_tokens": batch.top_n_tokens[req_idx], + "top_token_ids": batch_top_token_ids[req_idx], + "top_token_logprobs": batch_top_token_logprobs[req_idx], + "grammar_state": batch.next_token_chooser.fsm_grammar_states[ + req.idx + ], + } + ) + + htorch.core.mark_step() + + # Add new token into input_ids + batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask.index_fill_(1, token_idx, 1) + + # Adjust lengths + batch.input_length += 1 + + # Update position_ids + if prefill: + batch.position_ids = ( + torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 + ) + else: + batch.position_ids += 1 + # Update past key values + if prefill or self.model.config.model_type in ["gpt_bigcode"]: + batch.past_key_values = past + + htorch.core.mark_step() + + # Stage 2. Prepare new batch for speculative scheduling + if len(batches) > 1: + batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) + else: + batch = batches[0] + + prefill = batch.past_key_values is None + + # Check if we need to do any bookkeeping first + if not prefill: + batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) + + scenario = "PREFILL" if prefill else "GENERATE" + if ( + self.enable_hpu_graph + and self.limit_hpu_graph + and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs + ): + self.model.clear_cache() + self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE) + dbg_trace( + scenario, + f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", + ) + assert batch.right_padding > 0, "No more room for next token!" + + # Execute batch + if prefill: + # no right padding for prefill + token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) + batch.logits, batch.past = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + bypass_hpu_graph=prefill and self.limit_hpu_graph + if self.enable_hpu_graph + else None, + ) + elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): + # Don't schedule next forward if max_new_tokens for all requests equals 1 + # - we've already generated the first and only needed token in the prefill phase + pass + else: + token_idx = torch.tensor( + batch.attention_mask.shape[-1] - batch.right_padding + ).to(self.device) + input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) + logits = self.forward( + input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + bypass_hpu_graph=prefill and self.limit_hpu_graph + if self.enable_hpu_graph + else None, + ) + if self.model.config.model_type in ["gpt_bigcode"]: + batch.logits, batch.past = logits + else: + batch.logits = logits + + htorch.core.mark_step() + + start_decode = time.time_ns() + + # Stage 3. Finish and return previous generations + stopped = len(requests_to_generate) > 0 + for prev_batch in prev_batches: + prev_batch["next_token_logprobs"] = prev_batch[ + "next_token_logprobs" + ].tolist() + prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() + htorch.core.mark_step() + + for req_data in requests_to_generate: + req = req_data["req"] + i = req_data["prev_req_idx"] + prev_batch_id = req_data["batch_id"] + assert len(prev_batches) > prev_batch_id + next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] + next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] + + request = req.data + input_length = req.input_length + prefix_offset = req.prefix_offset + read_offset = req.read_offset + do_sample = req_data["do_sample"] + seed = req_data["seed"] + stopping_criteria = req.stopping_criteria + all_input_ids = req.all_input_ids + next_token_id = next_token_ids_cpu[i] + next_token_logprob = next_token_logprobs[i] + top_n_tokens = req_data["top_n_tokens"] + top_token_ids = req_data["top_token_ids"] + top_token_logprobs = req_data["top_token_logprobs"] + grammar_state = req_data["grammar_state"] + + # Append next token to all tokens + all_input_ids[input_length] = next_token_id + new_input_length = input_length + 1 + + # Generated token + if ( + is_tokenizer_transparent(self.tokenizer) + and len(stopping_criteria.stop_sequence_criterias) == 0 + ): + next_token_text = "" + else: + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[0:new_input_length, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + if is_tokenizer_transparent(self.tokenizer): + output_text = None + else: + output_text = self.decode( + all_input_ids[ + new_input_length + - stopping_criteria.current_tokens : new_input_length, + 0, + ] + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + next_token_logprobs + prefill_token_ids = all_input_ids[0 : new_input_length - 1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id], + [next_token_logprob], + [next_token_text], + [next_token_id in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single_with_past_state( + req.idx, next_token_id, grammar_state + ) + ) + + req.all_input_ids = all_input_ids + req.input_length = new_input_length + req.prefix_offset = prefix_offset + req.read_offset = read_offset + + htorch.core.mark_step() + self.step = self.step + 1 + if self.hb_profiler is not None: + if ( + self.step + > self.profiling_wait_steps + + self.profiling_warmup_steps + + self.profiling_steps + ): + self.hb_profiler.stop() + else: + self.hb_profiler.step() + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch if not stopped else None, (forward_ns, decode_ns) + + def generate_warmup_batch(self, request, seq_len, batch_size): + batch = copy.deepcopy(request.batch) + for req in batch.requests: + req.truncate = seq_len + + for i in range(len(batch.requests) - batch_size): + batch.requests.pop() + + return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) + + def warmup(self, request) -> None: + MAX_TOTAL_TOKENS = request.max_total_tokens + MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens + batch = self.batch_type.from_pb( + request.batch, self.tokenizer, self.dtype, self.device + ) + max_prefill_batch_size = batch.input_ids.shape[0] + try: + # max prefill batch size warmup + _, prefill_batch, _ = self.generate_token([batch]) + except: + raise RuntimeError( + f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"You need to decrease `--max-batch-prefill-tokens`" + ) + + del prefill_batch + + # Warmup prefill batch_size + max_input_length = request.max_input_length + prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)] + prefill_batch_size_list.append(max_prefill_batch_size) + prefill_seqlen_list = [ + seq + for seq in range( + PAD_SEQUENCE_TO_MULTIPLE_OF, + max_input_length, + PAD_SEQUENCE_TO_MULTIPLE_OF, + ) + ] + prefill_seqlen_list.append(max_input_length) + prefill_batch_size_list.sort(reverse=True) + prefill_seqlen_list.sort(reverse=True) + try: + for batch_size in prefill_batch_size_list: + for seq_len in prefill_seqlen_list: + batch = self.generate_warmup_batch(request, seq_len - 1, batch_size) + _, prefill_batch, _ = self.generate_token([batch]) + except: + prefill_batch_size_list.sort() + prefill_seqlen_list.sort() + raise RuntimeError( + f"Not enough memory to run following prefill batch_size." + f"Prefill batch size list:{prefill_batch_size_list}" + f"Prefill sequence length list:{prefill_seqlen_list}" + f"You need to decrease `--max-batch-prefill-tokens`" + ) + prefill_seqlen_list.sort() + prefill_batch_size_list.sort() + mem_stats = get_hpu_memory_stats(self.device) + logger.info( + f"\nFollowing prefill warmup successfully.\n" + f"Prefill batch size list:{prefill_batch_size_list}\n" + f"Prefill sequence length list:{prefill_seqlen_list}\n" + f"Memory stats: {mem_stats} " + ) + + # warmup decode batch size + max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) + max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE) + decode_batch_size_list = [ + i + for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE) + ] + decode_batch_size_list.append(max_decode_batch_size) + decode_batch_size_list.sort(reverse=True) + + try: + for batch_size in decode_batch_size_list: + batches = [] + iters = math.floor(batch_size / max_prefill_batch_size) + for i in range(iters): + batch = self.generate_warmup_batch( + request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size + ) + _, prefill_batch, _ = self.generate_token([batch]) + batches.append(prefill_batch) + + if batch_size % max_prefill_batch_size != 0: + batch = self.generate_warmup_batch( + request, + PAD_SEQUENCE_TO_MULTIPLE_OF - 1, + batch_size % max_prefill_batch_size, + ) + _, prefill_batch, _ = self.generate_token([batch]) + batches.append(prefill_batch) + + _, decode_batch, _ = self.generate_token(batches) + _, decode_batch, _ = self.generate_token([decode_batch]) + del decode_batch + batches.clear() + + except: + raise RuntimeError( + f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})." + f"You need to decrease `--max-batch-total-tokens`" + ) + + decode_batch_size_list.sort() + MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] + mem_stats = get_hpu_memory_stats(self.device) + logger.info( + f"\nFollowing decode warmup successfully.\n" + f"Decode batch size list:{decode_batch_size_list}\n" + f"Memory stats: {mem_stats} " + ) + + return MAX_BATCH_TOTAL_TOKENS diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py new file mode 100644 index 000000000..e2719fad2 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -0,0 +1,923 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BLOOM model.""" + +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.distributed +import torch.utils.checkpoint +from torch import nn +from torch.nn import LayerNorm +from torch.nn import functional as F + +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers import BloomConfig, PreTrainedModel + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + SpeculativeHead, +) + +CUSTOM_KERNELS_ENABLED = False +if ( + torch.cuda.is_available() + and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" +): + try: + from custom_kernels import fused_bloom_attention_cuda + + CUSTOM_KERNELS_ENABLED = True + except ImportError: + pass + +_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m" +_CONFIG_FOR_DOC = "BloomConfig" + +BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bigscience/bigscience-small-testing", + "bigscience/bloom-560m", + "bigscience/bloom-1b1", + "bigscience/bloom-1b7", + "bigscience/bloom-3b", + "bigscience/bloom-7b1", + "bigscience/bloom", +] + + +def _make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.ones( + (target_length, target_length + past_key_values_length), + dtype=torch.bool, + device=device, + ) + mask = mask.triu(1 + past_key_values_length) + + expanded_mask = mask.unsqueeze(0).expand( + batch_size, target_length, target_length + past_key_values_length + ) + return expanded_mask + + +def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, tgt_length, src_length) + + +def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) + powers = torch.arange( + 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32 + ) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange( + 1, + 1 + 2 * num_remaining_heads, + 2, + device=attention_mask.device, + dtype=torch.int32, + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + return alibi + + +# @torch.jit.script +def dropout_add( + x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool +) -> torch.Tensor: + """ + Dropout add function + + Args: + x (`torch.tensor`, *required*): + input tensor + residual (`torch.tensor`, *required*): + esidual tensor + prob (`float`, *required*): + dropout probability + training (`bool`, *required*): + training mode + """ + out = F.dropout(x, p=prob, training=training) + out = residual + out + return out + + +# @torch.jit.script # this is shit for unknow reasons. +def _split_heads( + fused_qkv: torch.Tensor, num_heads: int, head_dim: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim) + query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1) + + query_layer = query_layer.transpose(1, 2).reshape( + batch_size * num_heads, seq_length, head_dim + ) + key_layer = key_layer.permute(0, 2, 3, 1).reshape( + batch_size * num_heads, head_dim, seq_length + ) + value_layer = value_layer.transpose(1, 2).reshape( + batch_size * num_heads, seq_length, head_dim + ) + + return query_layer, key_layer, value_layer + + +# @torch.jit.script +def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor: + """ + Merge heads together over the last dimenstion + + Args: + x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] + + Returns: + torch.tensor: [batch_size, seq_length, num_heads * head_dim] + """ + # What we want to achieve is: + # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim + batch_size_and_num_heads, seq_length, _ = x.shape + batch_size = batch_size_and_num_heads // num_heads + + # First view to decompose the batch size + # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim + x = x.view(batch_size, num_heads, seq_length, head_dim) + + # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim + x = x.permute(0, 2, 1, 3) + + # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim + return x.reshape(batch_size, seq_length, num_heads * head_dim) + + +class BloomAttention(nn.Module): + def __init__(self, prefix, config: BloomConfig, weights): + super().__init__() + + self.pretraining_tp = config.pretraining_tp + self.slow_but_exact = config.slow_but_exact + + self.process_group = weights.process_group + + self.hidden_size = config.hidden_size + self.num_heads = config.n_head + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + self.hidden_dropout = config.hidden_dropout + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = 1.0 + + process_group = weights.process_group + if self.num_heads % process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {process_group.size()}" + ) + self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.query_key_value", + weights=weights, + bias=True, + ) + self.dense = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.dense", weights=weights, bias=True + ) + self.attention_dropout = nn.Dropout(config.attention_dropout) + + @staticmethod + def compute_attention( + fused_qkv: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]], + alibi: torch.Tensor, + attention_mask: torch.Tensor, + head_mask: Optional[torch.Tensor], + beta: float, + inv_norm_factor: float, + num_heads: int, + use_cache: bool, + ): + batch_size, q_length, three_times_hidden_size = fused_qkv.shape + head_dim = three_times_hidden_size // (3 * num_heads) + batch_size * num_heads + + ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that? + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = _split_heads( + fused_qkv, num_heads=num_heads, head_dim=head_dim + ) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + past_key = past_key.view(-1, *past_key.shape[-2:]) + key_layer = torch.cat((past_key, key_layer), dim=2) + past_value = past_value.view(-1, *past_value.shape[-2:]) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + ### + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + attention_scores = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=beta, + alpha=inv_norm_factor, + ) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + # torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34` + attn_weights = attention_scores.masked_fill_( + attention_mask, torch.finfo(attention_scores.dtype).min + ) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + input_dtype + ) + + # # [batch_size, num_heads, q_length, kv_length] + # attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs, value_layer, out=query_layer) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = _merge_heads( + context_layer, num_heads=num_heads, head_dim=head_dim + ) + + return context_layer, present, attention_probs + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value( + hidden_states + ) # [batch_size, seq_length, 3 x hidden_size] + batch_size, q_length, _ = fused_qkv.shape + + if layer_past is not None: + past_key, past_value = layer_past + layer_past = ( + past_key.view(-1, *past_key.shape[-2:]), + past_value.view(-1, *past_value.shape[-2:]), + ) + + if CUSTOM_KERNELS_ENABLED: + assert self.training is False, "Only foward pass was implemented" + assert ( + attention_mask.shape[-1] < 4096 + ), "Custom kernel support only up to 4096 tokens" + ( + context_layer, + present, + attention_probs, + ) = fused_bloom_attention_cuda.forward( + fused_qkv, + layer_past, + alibi, + attention_mask, + head_mask, + self.beta, + self.inv_norm_factor, + self.num_heads, + use_cache, + ) + else: + context_layer, present, attention_probs = self.compute_attention( + fused_qkv=fused_qkv, + layer_past=layer_past, + alibi=alibi, + attention_mask=attention_mask, + head_mask=head_mask, + beta=self.beta, + inv_norm_factor=self.inv_norm_factor, + num_heads=self.num_heads, + use_cache=use_cache, + ) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + output_tensor += residual + + outputs = (output_tensor, present) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + +class BloomMLP(nn.Module): + def __init__(self, prefix, config: BloomConfig, weights): + super().__init__() + + self.pretraining_tp = config.pretraining_tp + self.slow_but_exact = config.slow_but_exact + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ) + self.gelu_impl = torch.nn.GELU(approximate="tanh") + self.hidden_dropout = config.hidden_dropout + + def forward( + self, hidden_states: torch.Tensor, residual: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) + + if self.pretraining_tp > 1 and self.slow_but_exact: + intermediate_output = torch.zeros_like(residual) + slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp + for i in range(self.pretraining_tp): + intermediate_output = intermediate_output + F.linear( + hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense_4h_to_h.weight[ + :, int(i * slices) : int((i + 1) * slices) + ], + ) + else: + intermediate_output = self.dense_4h_to_h(hidden_states) + + # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) + intermediate_output += residual + + return intermediate_output + + +class BloomBlock(nn.Module): + def __init__(self, layer_id: int, config: BloomConfig, weights): + super().__init__() + + prefix = f"h.{layer_id}" + self.input_layernorm = LayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.num_heads = config.n_head + self.self_attention = BloomAttention( + prefix=f"{prefix}.self_attention", config=config, weights=weights + ) + self.post_attention_layernorm = LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm + ) + self.hidden_dropout = config.hidden_dropout + + def forward( + self, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attn_outputs = self.self_attention( + layernorm_output, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. + output = self.mlp(layernorm_output, residual) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class BloomPreTrainedModel(PreTrainedModel): + config_class = BloomConfig + base_model_prefix = "transformer" + _no_split_modules = ["BloomBlock"] + + @staticmethod + def _convert_to_standard_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, + num_heads, ...])) + """ + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + @staticmethod + def _convert_to_bloom_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + +class BloomModel(BloomPreTrainedModel): + def __init__(self, config: BloomConfig, weights): + super().__init__(config) + + self.embed_dim = config.hidden_size + self.num_heads = config.n_head + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.word_embeddings = TensorParallelEmbedding( + prefix="word_embeddings", weights=weights + ) + + self.word_embeddings_layernorm = LayerNorm.load( + prefix="word_embeddings_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + # Transformer blocks + self.h = nn.ModuleList( + [ + BloomBlock(layer_id=layer_id, config=config, weights=weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + + # Final Layer Norm + self.ln_f = LayerNorm.load( + prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon + ) + + def _prepare_attn_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, + ) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + device=device, + past_key_values_length=past_key_values_length, + ) + + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[-1] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), device=hidden_states.device + ) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = build_alibi_tensor(attention_mask, self.num_heads) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + if hasattr(self, "tp_rank"): + assert self.num_heads % self.tp_world_size == 0 + block_size = self.num_heads // self.tp_world_size + alibi = alibi[ + :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size + ] + alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past) + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) + else: + alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past) + causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0) + + alibi = alibi.to(hidden_states.dtype) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class BloomForCausalLM(BloomPreTrainedModel): + def __init__(self, prefix: str, config, weights): + super().__init__(config) + self.transformer = BloomModel(config, weights) + + self.lm_head = SpeculativeHead.load( + config, + prefix="word_embeddings", + weights=weights, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + logits, speculative_logits = self.lm_head(hidden_states) + loss = None + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ( + CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ), + speculative_logits, + ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py new file mode 100644 index 000000000..ab824da5b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py @@ -0,0 +1,817 @@ +from typing import Optional, Tuple + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPooling, +) +from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig + +from text_generation_server.layers import ( + TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + +class CLIPVisionEmbeddings(nn.Module): + def __init__(self, prefix, config: CLIPVisionConfig, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + # TODO Should we TP this ? + self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding") + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, device=weights.device).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class CLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = ( + input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=True, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_size) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + ] + * 3, + dim=2, + ) + query_states = query_states * self.scale + key_states = self._shape(key_states, -1, bsz) + value_states = self._shape(value_states, -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_size) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + causal_attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +class CLIPMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + def __init__(self, prefix, config: CLIPConfig, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps + ) + self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPPreTrainedModel(nn.Module): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = CLIPConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, prefix, config: CLIPConfig, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + CLIPEncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + """ + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + ) + + return hidden_states + + +class CLIPTextTransformer(nn.Module): + def __init__(self, prefix: str, config: CLIPTextConfig, weights=None): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + # Initialize weights and apply final processing with `self.post_init()` + self.encoder = CLIPEncoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # For `pooled_output` computation + self.eos_token_id = config.eos_token_id + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ): + r""" + Returns: + + """ + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + if self.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. + # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( + dim=-1 + ), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == self.eos_token_id + ) + .int() + .argmax(dim=-1), + ] + + return last_hidden_state + + +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] + + def __init__(self, prefix, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(prefix, config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, prefix, config: CLIPVisionConfig, weights): + super().__init__() + self.config = config + + self.embeddings = CLIPVisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.pre_layrnorm = nn.LayerNorm.load( + prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps + ) + self.encoder = CLIPEncoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + # self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + r""" + Returns: + + """ + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + ) + last_hidden_state = encoder_outputs + # pooled_output = last_hidden_state[:, 0, :] + # pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + # pooler_output=pooled_output, + # hidden_states=encoder_outputs, + ) + + +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + + return self.vision_model( + pixel_values=pixel_values, + ) + + +class CLIPModel(nn.Module): + def __init__(self, prefix, config: CLIPConfig, weights): + super().__init__() + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear( + self.vision_embed_dim, self.projection_dim, bias=False + ) + self.text_projection = nn.Linear( + self.text_embed_dim, self.projection_dim, bias=False + ) + self.logit_scale = nn.Parameter( + torch.tensor(self.config.logit_scale_init_value) + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + vision_outputs = self.vision_model( + pixel_values=pixel_values, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + vision_outputs = self.vision_model( + pixel_values=pixel_values, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + return logits_per_image, logits_per_text diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py new file mode 100644 index 000000000..30656038b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -0,0 +1,543 @@ +# coding=utf-8 +# Copyright 2024 Cohere team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.utils.weights import UnquantizedWeight + +if SYSTEM == "cuda": + import dropout_layer_norm +else: + dropout_layer_norm = None + + +class CohereRotary(PositionRotaryEmbedding): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # Such controlflows may add some overhead. + if SYSTEM == "cuda": + import rotary_emb + + q1 = query[..., ::2] + q2 = query[..., 1::2] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., ::2] + k2 = key[..., 1::2] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif SYSTEM == "rocm": + from vllm._C import ops + + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + ops.rotary_embedding(query, key, head_size, cos, sin, False) + elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), False + ) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) + + +class CohereLayerNorm(nn.Module): + def __init__(self, prefix, weights, eps): + super().__init__() + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + self.weight = nn.Parameter(weight) + # Fake weights + self.ones = weight.new_ones(weight.shape[1]) + self.eps = eps + + def forward(self, hidden_states): + if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda": + hidden_states = hidden_states.reshape( + -1, self.weight.shape[0], self.weight.shape[1] + ) + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + hidden_states_minus_mean = hidden_states - mean + variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) + hidden_states = self.weight.to(torch.float32) * hidden_states + hidden_states = hidden_states.view(-1, self.weight.shape[1]) + return hidden_states.to(input_dtype) + + ( + hidden_states, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + None, + self.ones, + None, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + + # Required to apply one weight matrix per head + hidden_states = hidden_states.view( + -1, self.weight.shape[0], self.weight.shape[1] + ) + hidden_states = self.weight * hidden_states + hidden_states = hidden_states.view(-1, self.weight.shape[1]) + + return hidden_states + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=config.attention_bias, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + ) + + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + if config.attention_bias: + w = [ + weights.get_sharded(f"{p}.bias", dim=0) + for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + ] + bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) + else: + bias = None + + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) + + +class FlashCohereAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = CohereRotary.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = CohereLayerNorm( + prefix=f"{prefix}.q_norm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.k_norm = CohereLayerNorm( + prefix=f"{prefix}.k_norm", + weights=weights, + eps=config.layer_norm_eps, + ) + else: + self.q_norm = None + self.k_norm = None + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=config.attention_bias, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, key, value = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_key_value_heads, + self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + + if self.use_qk_norm: + query = query.reshape(-1, self.head_size) + key = key.reshape(-1, self.head_size) + query = self.q_norm(query.contiguous()) + key = self.k_norm(key.contiguous()) + + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_key_value_heads, self.head_size) + value = value.view(-1, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, key, cos, sin) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), reduce=False + ) + + +class CohereMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False + ) + + +class FlashCohereLayer(nn.Module): + def __init__(self, prefix: str, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + self.self_attn = FlashCohereAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = FastLayerNorm.load_no_bias( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.process_group = weights.process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + mlp_output = self.mlp(normed_hidden_states) + output = attn_output + mlp_output + + if self.process_group.size() > 1: + torch.distributed.all_reduce(output, group=self.process_group) + + return output, res + + +class FlashCohereModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + FlashCohereLayer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastLayerNorm.load_no_bias( + prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashCohereForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashCohereModel(prefix, config, weights) + try: + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + except RuntimeError: + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix}.embed_tokens", + weights=weights, + ) + self.logit_scale = config.logit_scale + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + logits *= self.logit_scale + if speculative_logits is not None: + speculative_logits *= self.logit_scale + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py new file mode 100644 index 000000000..1137a453f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -0,0 +1,756 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple, Any +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM != "ipex": + from vllm.model_executor.layers.fused_moe import fused_moe + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, + PREFILL_IN_KV_CACHE, +) +from text_generation_server.layers import ( + FastLinear, + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) + + +class DbrxAttentionConfig(PretrainedConfig): + def __init__( + self, + attn_pdrop: float = 0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + +class DbrxFFNConfig(PretrainedConfig): + def __init__( + self, + ffn_act_fn: Optional[dict] = None, + ffn_hidden_size: int = 3584, + moe_num_experts: int = 4, + moe_top_k: int = 1, + moe_jitter_eps: Optional[float] = None, + moe_loss_weight: float = 0.01, + moe_normalize_expert_weights: Optional[float] = 1, + uniform_expert_assignment: bool = False, + **kwargs: Any, + ): + super().__init__() + if ffn_act_fn is None: + ffn_act_fn = {"name": "silu"} + self.ffn_act_fn = ffn_act_fn + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_loss_weight = moe_loss_weight + self.moe_normalize_expert_weights = moe_normalize_expert_weights + self.uniform_expert_assignment = uniform_expert_assignment + + if uniform_expert_assignment: + raise ValueError("`uniform_expert_assignment = True` is not supported") + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + +class DbrxConfig(PretrainedConfig): + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + max_seq_len: int = 2048, + vocab_size: int = 32000, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + attn_config: Optional[DbrxAttentionConfig] = None, + ffn_config: Optional[DbrxFFNConfig] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + output_router_logits: bool = False, + router_aux_loss_coef: float = 0.05, + **kwargs: Any, + ): + if attn_config is None: + self.attn_config = DbrxAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = DbrxAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + + if ffn_config is None: + self.ffn_config = DbrxFFNConfig() + elif isinstance(ffn_config, dict): + self.ffn_config = DbrxFFNConfig(**ffn_config) + else: + self.ffn_config = ffn_config + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.use_cache = use_cache + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError("tie_word_embeddings is not supported for Dbrx models.") + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def num_key_value_heads(self): + # We can't use the attribute map, since this the number of KV + # heads is not top-level. + return self.attn_config.kv_n_heads + + +def promote_scalar(x: torch.Tensor) -> torch.Tensor: + return x.view(1) if len(x.size()) == 0 else x + + +def load_attention(config, prefix, weights): + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.Wqkv", + weights=weights, + bias=False, + num_heads=config.n_heads, + num_key_value_heads=config.attn_config.kv_n_heads, + ) + + +def _load_experts(config, prefix, weights): + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.ffn_config.ffn_hidden_size % world_size == 0 + ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards" + + expert_size = config.ffn_config.ffn_hidden_size + block_size = expert_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.ffn_config.moe_num_experts * block_size, config.d_model), + dtype=weights.dtype, + device=weights.device, + ) + + slice_ = weights._get_slice(f"{prefix}") + + for i in range(config.ffn_config.moe_num_experts): + offset = i * expert_size + expert_slice = slice_[start + offset : stop + offset] + + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +def _load_experts_quantized(config, prefix, weights, cls): + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.ffn_config.ffn_hidden_size % world_size == 0 + ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards" + + expert_size = config.ffn_config.ffn_hidden_size + block_size = expert_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + slice_ = weights._get_slice(f"{prefix}") + + experts = [] + for i in range(config.ffn_config.moe_num_experts): + if config.quantize in ["gptq", "awq"]: + raise NotImplementedError( + "Dbrx does not support gptq/awq quantization yet." + ) + else: + offset = i * expert_size + expert_slice = ( + slice_[start + offset : stop + offset] + .to(dtype=weights.dtype) + .to(device=weights.device) + ) + + if cls == TensorParallelRowLinear: + expert_slice = expert_slice.t().contiguous() + linear = get_linear(expert_slice, None) + experts.append(cls(linear, weights.process_group)) + else: + linear = get_linear(expert_slice, None) + experts.append(cls(linear)) + + return experts + + +class DbrxAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.clip_qkv = config.attn_config.clip_qkv + self.num_heads = config.n_heads + self.hidden_size = config.d_model + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.attn_config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.attn_config.kv_n_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + qkv = self.query_key_value(hidden_states) + if self.clip_qkv is not None: + qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class DbrxNormAttentionNorm(nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.norm_1 = FastLayerNorm.load_no_bias( + prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5 + ) + self.self_attn = DbrxAttention( + prefix=f"{prefix}.attn", config=config, weights=weights + ) + self.norm_2 = FastLayerNorm.load_no_bias( + prefix=f"{prefix}.norm_2", + weights=weights, + eps=1e-5, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + normed_hidden_states, res = self.norm_1(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.norm_2(attn_output, res) + + return normed_attn_res_output, attn_res + + +@torch.jit.script +def select_experts( + gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int +): + # all_probs: (sequence_length, n_experts) and upcast for softmax + all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float) + # weights, selected_experts: (sequence_length, top-k) + weights, selected_experts = torch.topk(all_probs, top_k, dim=-1) + if moe_normalize_expert_weights: + weights = weights / torch.norm( + weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True + ) + weights = weights.view(-1) + selected_experts = selected_experts.view(-1) + + return selected_experts, weights + + +@torch.jit.script +def round_up(x: torch.Tensor, value: int): + return torch.div(x + (value - 1), value, rounding_mode="trunc") * value + + +class BlockSparseMoE(nn.Module): + def __init__(self, prefix, config: DbrxConfig, weights): + super().__init__() + self.moe_normalize_expert_weights = ( + config.ffn_config.moe_normalize_expert_weights + ) + self.hidden_dim = config.d_model + self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size() + self.num_experts = config.ffn_config.moe_num_experts + self.top_k = config.ffn_config.moe_top_k + + act = config.ffn_config.ffn_act_fn["name"] + if "gelu" in act: + self.act = lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + elif "silu" in act: + self.act = torch.nn.functional.silu + else: + self.act = ACT2FN[act] + + # gating + self.gate = FastLinear.load( + config, f"{prefix}.router.layer", weights, bias=False + ) + + # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) + w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + self.wv1 = torch.cat([w1, v1], dim=1) + self.w2 = ( + _load_experts(config, f"{prefix}.experts.mlp.w2", weights) + .view(self.num_experts, self.ffn_dim, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(x) + out = fused_moe( + x, + self.wv1, + self.w2, + router_logits, + self.top_k, + renormalize=self.moe_normalize_expert_weights, + inplace=True, + ) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DenseMoE(nn.Module): + def __init__(self, prefix, config: DbrxConfig, weights): + super().__init__() + + self.moe_normalize_expert_weights = ( + config.ffn_config.moe_normalize_expert_weights + ) + self.hidden_dim = config.d_model + self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size() + self.num_experts = config.ffn_config.moe_num_experts + self.top_k = config.ffn_config.moe_top_k + + act = config.ffn_config.ffn_act_fn["name"] + if "gelu" in act: + self.act = lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + elif "silu" in act: + self.act = torch.nn.functional.silu + else: + self.act = ACT2FN[act] + + # gating + self.gate = FastLinear.load( + config, f"{prefix}.router.layer", weights, bias=False + ) + + self.w1 = _load_experts_quantized( + config, + prefix=f"{prefix}.experts.mlp.w1", + weights=weights, + cls=TensorParallelColumnLinear, + ) + self.w2 = _load_experts_quantized( + config, + prefix=f"{prefix}.experts.mlp.w2", + weights=weights, + cls=TensorParallelRowLinear, + ) + self.v1 = _load_experts_quantized( + config, + prefix=f"{prefix}.experts.mlp.v1", + weights=weights, + cls=TensorParallelColumnLinear, + ) + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + # gate_logits: (sequence_length, n_experts) + gate_logits = self.gate(x) + # all_probs: (sequence_length, n_experts) and upcast for softmax + weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float) + + if self.top_k < self.num_experts: + _, not_selected_experts = torch.topk( + weights, + self.num_experts - self.top_k, + largest=False, + sorted=False, + dim=1, + ) + # Mask not selected experts + weights.scatter_(1, not_selected_experts, 0) + + # Re-normalize + if self.moe_normalize_expert_weights: + weights = weights / torch.norm( + weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True + ) + weights = weights.to(x.dtype) + + # Final output tensor + out = x.new_zeros(x.shape[0], self.hidden_dim) + for i in range(self.num_experts): + h = self.act(self.w1[i](x)) * self.v1[i](x) + h = self.w2[i](h, reduce=False) + # Add expert output to out with masking + out += h * weights[:, i].view(-1, 1) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class DbrxLayer(nn.Module): + def __init__(self, prefix: str, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.blocks.{layer_id}" + + self.attn = DbrxNormAttentionNorm( + prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights + ) + + moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE + self.moe = moe_cls(f"{prefix}.ffn", config, weights) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + # Self Attention + attn_output, attn_res = self.attn( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + moe_output = self.moe(attn_output) + + return moe_output, attn_res + + +class DbrxModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.wte", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DbrxLayer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.n_layers) + ] + ) + self.norm = FastLayerNorm.load_no_bias( + prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5 + ) + + self.head_size = self.layers[0].attn.self_attn.head_size + self.num_heads = self.layers[0].attn.self_attn.num_heads + self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDbrxForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.model = DbrxModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py new file mode 100644 index 000000000..88c2cf803 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -0,0 +1,659 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Type + +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + Seqlen, + attention, + paged_attention, + reshape_and_cache, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import Weights + +if SYSTEM == "rocm": + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + + +class DeepseekV2Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekV2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + ) + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV2MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + return self.down_proj(out, reduce=reduce) + else: + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class DeepseekV2MoE(nn.Module): + def __init__( + self, + prefix, + config: DeepseekV2Config, + moe_layer_cls: Type[MoELayer], + weights, + ): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.moe_layer = moe_layer_cls( + prefix=f"{prefix}.experts", + n_experts=config.n_routed_experts, + n_expert_group=config.n_group, + renormalize=config.norm_topk_prob, + topk=config.num_experts_per_tok, + topk_group=config.topk_group, + weights=weights, + ) + assert isinstance(self.moe_layer, MoELayer) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + + out = self.moe_layer(x, gating_output=router_logits) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DeepseekV2Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights) + else: + self.mlp = DeepseekV2MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV2Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py new file mode 100644 index 000000000..7a3d60c97 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -0,0 +1,560 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from text_generation_server.utils.weights import UnquantizedWeight + + +class Gemma2Config(PretrainedConfig): + def __init__( + self, + vocab_size=256128, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Gemma2FastRMSNorm(FastRMSNorm): + @classmethod + def load(cls, prefix: str, weights, eps=1e-6): + dtype = weights.dtype + weights.dtype = torch.float32 + weight = weights.get_tensor(f"{prefix}.weight") + 1 + weights.dtype = dtype + new = cls(weight, eps) + new.dtype = dtype + return new + + # perform the multiplication in full precision and downcast after + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * self.weight + return hidden_states.to(self.dtype), residual + + +def load_attention(config, prefix: str, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + ) + + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.head_dim + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear(get_linear(weight, bias=None)) + + +class FlashGemma2Attention(torch.nn.Module): + def __init__( + self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.head_dim + self.causal = causal + if is_sliding: + self.window_size = config.sliding_window + else: + self.window_size = -1 + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + # self.softmax_scale = self.head_size**-0.5 + self.softmax_scale = config.query_pre_attn_scalar**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + self.softcap = config.attn_logit_softcapping + + query_key_value = load_attention(config, prefix, weights) + self.query_key_value = TensorParallelMultiAdapterLinear.load( + query_key_value, + layer_id, + ["q_proj", "k_proj", "v_proj"], + sizes=[ + self.head_size * config.num_attention_heads, + self.head_size * config.num_key_value_heads, + self.head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) + + o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + layer_id, + "o_proj", + process_group=weights.process_group, + ) + + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + ): + qkv = self.query_key_value(hidden_states, adapter_data) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + seqlen, + block_tables, + self.softmax_scale, + causal=self.causal, + window_size_left=self.window_size, + softcap=self.softcap, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + softcap=self.softcap, + ) + + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) + + +class Gemma2MLP(nn.Module): + def __init__(self, prefix, config, weights, layer_id): + super().__init__() + act = config.hidden_activation + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + layer_id, + ["gate_proj", "up_proj"], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + layer_id, + "down_proj", + process_group=weights.process_group, + ) + + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) + + +class FlashGemma2Layer(nn.Module): + def __init__( + self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool + ): + super().__init__() + self.self_attn = FlashGemma2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + causal=causal, + is_sliding=is_sliding, + ) + self.mlp = Gemma2MLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id + ) + + self.input_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.pre_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.post_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + ) + + # faster post attention rms norm + normed_attn_res_output, _ = self.post_attention_layernorm(attn_output) + normed_attn_res_output = normed_attn_res_output + res + res = normed_attn_res_output + + pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output) + mlp_output = self.mlp(pre_normed, adapter_data) + post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output) + + return post_hidden_states, normed_attn_res_output + + +class FlashGemma2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights, causal: bool): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.layers = nn.ModuleList( + [ + FlashGemma2Layer( + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, + layer_id=layer_id, + causal=causal, + is_sliding=layer_id % 2 == 0, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = inputs_embeds + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + adapter_data, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashGemma2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): + super().__init__() + + embed_norm = config.hidden_size**0.5 + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.embed_tokens.weight *= embed_norm + + self.model = FlashGemma2Model( + prefix=prefix, config=config, weights=weights, causal=causal + ) + self.lm_head = SpeculativeHead.load( + prefix=( + f"{prefix}.embed_tokens" + if config.tie_word_embeddings + else f"{prefix}.lm_head" + ), + config=config, + weights=weights, + ) + self.softcap = config.final_logit_softcapping + assert isinstance(self.softcap, float) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( + input_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + + logits /= self.softcap + logits = torch.tanh(logits) + logits *= self.softcap + + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py new file mode 100644 index 000000000..4c1be6f60 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -0,0 +1,471 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, + PREFILL_IN_KV_CACHE, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from text_generation_server.utils.weights import UnquantizedWeight + + +class GemmaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=256128, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class GemmaFastRMSNorm(FastRMSNorm): + @classmethod + def load(cls, prefix: str, weights, eps=1e-6): + dtype = weights.dtype + weights.dtype = torch.float32 + weight = weights.get_tensor(f"{prefix}.weight") + 1 + weights.dtype = dtype + new = cls(weight, eps) + new.dtype = dtype + return new + + # perform the multiplication in full precision and downcast after + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * self.weight + return hidden_states.to(self.dtype), residual + + +def load_attention(config, prefix: str, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + ) + + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.head_dim + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear(get_linear(weight, bias=None)) + + +class FlashGemmaAttention(torch.nn.Module): + def __init__(self, prefix: str, config, weights, causal: bool): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.head_dim + self.causal = causal + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + seqlen, + block_tables, + self.softmax_scale, + causal=self.causal, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GemmaMLP(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashGemmaLayer(nn.Module): + def __init__(self, prefix: str, config, weights, causal: bool): + super().__init__() + self.self_attn = FlashGemmaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal + ) + self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = GemmaFastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GemmaFastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashGemmaModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights, causal: bool): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.layers = nn.ModuleList( + [ + FlashGemmaLayer( + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, + causal=causal, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = GemmaFastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = inputs_embeds + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashGemmaForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): + super().__init__() + + embed_norm = config.hidden_size**0.5 + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.embed_tokens.weight *= embed_norm + + self.model = FlashGemmaModel( + prefix=prefix, config=config, weights=weights, causal=causal + ) + self.lm_head = SpeculativeHead.load( + prefix=( + f"{prefix}.embed_tokens" + if config.tie_word_embeddings + else f"{prefix}.lm_head" + ), + config=config, + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( + input_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py new file mode 100644 index 000000000..44c015cf4 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -0,0 +1,461 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) + + +def load_qkv(config, prefix: str, weights, head_size, num_heads): + if config.quantize == "gptq": + return _load_qkv_gptq( + config, + prefix, + weights, + ) + elif config.quantize == "marlin": + raise RuntimeError( + "GPT-2 models with marlin quantization are not yet supported" + ) + else: + return _load_qkv(config, prefix, weights, head_size, num_heads) + + +def _load_qkv_gptq(config, prefix: str, weights): + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + # Weights + weight = weights.get_weights_col_packed_qkv( + f"{prefix}.c_attn", + config.num_attention_heads, + config.num_attention_heads, + ) + + # Bias + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + shape = slice_.get_shape() + total_size = shape[0] + assert total_size % 3 == 0, f"Prepacked is not divisible by {3}" + single_size = total_size // 3 + assert single_size % world_size == 0 + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensors = [] + for i in range(3): + tensor = slice_[start + i * single_size : stop + i * single_size] + tensors.append(tensor) + bias = torch.cat(tensors, dim=0) + bias = bias.to(device=weights.device) + + return TensorParallelColumnLinear(get_linear(weight, bias)) + + +def _load_qkv(config, prefix: str, weights, head_size, num_heads): + """Load QKV from a single, transposed matrix.""" + + slice_ = weights._get_slice(f"{prefix}.c_attn.weight") + shape = slice_.get_shape() + total_size = shape[1] + assert total_size % 3 == 0, f"Prepacked is not divisible by {3}" + world_size = weights.process_group.size() + single_size = total_size // 3 + assert single_size % world_size == 0 + rank = weights.process_group.rank() + + # Weights + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensors = [] + for i in range(3): + tensor = slice_[:, start + i * single_size : stop + i * single_size] + tensors.append(tensor) + weight = torch.cat(tensors, dim=1).T + weight = weight.to(dtype=weights.dtype) + weight = weight.to(device=weights.device) + + # Bias + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + shape = slice_.get_shape() + total_size = shape[0] + single_size = total_size // 3 + block_size = single_size // world_size + assert single_size % world_size == 0 + start = rank * block_size + stop = (rank + 1) * block_size + b = [] + for i in range(3): + tensor = slice_[start + i * single_size : stop + i * single_size] + b.append(tensor) + bias = torch.cat(b, dim=0) + bias = bias.to(dtype=weights.dtype) + bias = bias.to(device=weights.device) + assert list(bias.shape) == [ + 3 * num_heads * head_size + ], f"{weight.shape} != {[3 * num_heads * head_size]}" + + return TensorParallelColumnLinear(get_linear(weight, bias)) + + +def load_row(config, prefix: str, weights, bias: bool): + """load_row, but with transposed weight matrices.""" + + if config.quantize == "gptq": + weight = weights.get_weights_row(prefix) + else: + weight = weights.get_sharded(f"{prefix}.weight", dim=0).T + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + return TensorParallelRowLinear( + get_linear(weight, bias), process_group=weights.process_group + ) + + +def load_col(config, prefix: str, weights, bias: bool): + """load_col, but with transposed weight matrices.""" + if config.quantize == "gptq": + weight = weights.get_multi_weights_col([prefix], dim=1) + else: + weight = weights.get_sharded(f"{prefix}.weight", dim=1).T + + if bias: + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + else: + bias = None + + return TensorParallelColumnLinear(get_linear(weight, bias)) + + +class FlashGPT2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + + self.head_size = self.hidden_size // self.num_heads + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.query_key_value = load_qkv( + config, + prefix=prefix, + weights=weights, + head_size=self.head_size, + num_heads=self.num_heads, + ) + + self.o_proj = load_row( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + ) + + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + + def forward( + self, + hidden_states, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + query, key, value = self.query_key_value(hidden_states).split( + self.head_size * self.num_heads, dim=1 + ) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GPT2MLP(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + act = config.activation_function + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + self.c_fc = load_col( + config, prefix=f"{prefix}.c_fc", weights=weights, bias=True + ) + self.c_proj = load_row( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + ) + + intermediate_size = ( + config.n_inner if config.n_inner is not None else 4 * config.hidden_size + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + return self.c_proj(hidden_states) + + +class FlashGPT2Layer(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.self_attn = FlashGPT2Attention( + prefix=f"{prefix}.attn", config=config, weights=weights + ) + self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.ln_2", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + def forward( + self, + hidden_states, + residual, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_output = self.self_attn( + hidden_states, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states = attn_output + residual + residual = hidden_states + + hidden_states = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(hidden_states) + + return residual + mlp_output, residual + + +class FlashGPT2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.layers = nn.ModuleList( + [ + FlashGPT2Layer( + prefix=( + f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}" + ), + config=config, + weights=weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.norm = nn.LayerNorm.load( + prefix="ln_f" if not prefix else f"{prefix}.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = inputs_embeds + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class FlashGPT2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=("wte" if not prefix else f"{prefix}.wte"), + weights=weights, + ) + self.embed_positions = TensorParallelEmbedding( + prefix=("wpe" if not prefix else f"{prefix}.wpe"), + weights=weights, + ) + + self.model = FlashGPT2Model(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="wte" if not prefix else f"{prefix}.wte", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + token_embeds = self.embed_tokens(input_ids) + position_embeds = self.embed_positions(position_ids) + inputs_embeds = token_embeds + position_embeds + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py new file mode 100644 index 000000000..aca970044 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -0,0 +1,406 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) + + +def load_attention(config, prefix: str, weights): + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_weights_row(prefix) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias) + return TensorParallelRowLinear(linear, process_group=weights.process_group) + + +class GPTJRotary(PositionRotaryEmbedding): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # Such controlflows may add some overhead. + if SYSTEM == "cuda": + import rotary_emb + + q1 = query[..., ::2] + q2 = query[..., 1::2] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., ::2] + k2 = key[..., 1::2] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif SYSTEM == "rocm": + from vllm._C import ops + + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + ops.rotary_embedding(query, key, head_size, cos, sin, False) + elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), False + ) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) + + +class FlashGPTJAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + + self.head_size = self.hidden_size // self.num_heads + self.softmax_scale = self.head_size**-0.5 + self.rotary_dim = config.rotary_dim + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.query_key_value = load_attention( + config, + prefix=prefix, + weights=weights, + ) + + self.o_proj = load_row( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=False, + ) + + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + + self.rotary_emb = GPTJRotary.static( + config=config, + dim=self.rotary_dim, + base=10000, + device=weights.device, + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + query, key, value = self.query_key_value(hidden_states).split( + self.head_size * self.num_heads, dim=1 + ) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + + # Compute rotary embeddings on rotary_ndims + if self.rotary_dim is not None: + self.rotary_emb( + query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin + ) + else: + self.rotary_emb(query, key, cos, sin) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GPTJMLP(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + act = config.activation_function + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + self.fc_in = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.fc_in", weights=weights, bias=True + ) + + self.fc_out = load_row( + config, + prefix=f"{prefix}.fc_out", + weights=weights, + bias=True, + ) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + return self.fc_out(hidden_states) + + +class FlashGPTJLayer(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.self_attn = FlashGPTJAttention( + prefix=f"{prefix}.attn", config=config, weights=weights + ) + self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + hidden_states, residual = self.input_layernorm(hidden_states, residual) + # Self Attention + attn_output = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + feed_forward_hidden_states = self.mlp(hidden_states) + + return attn_output + feed_forward_hidden_states, residual + + +class FlashGPTJModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.config = config + + self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights) + self.layers = nn.ModuleList( + [ + FlashGPTJLayer( + prefix=( + f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}" + ), + config=config, + weights=weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.ln_f = FastLayerNorm.load( + prefix="ln_f" if not prefix else f"{prefix}.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def forward( + self, + input_ids: Optional[torch.LongTensor], + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.ln_f(hidden_states, residual) + + return hidden_states + + +class FlashGPTJForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + self.model = FlashGPTJModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py new file mode 100644 index 000000000..c9ec70cca --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -0,0 +1,668 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +from typing import List, Optional, Tuple, Type + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN + +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( + FastRMSNorm, + FastLayerNorm, +) +from text_generation_server.layers import ( + FastLinear, +) +from text_generation_server.utils.weights import ( + Weights, +) +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader + +if SYSTEM != "ipex": + pass + +if SYSTEM == "rocm": + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + + +def load_attention(config, prefix: str, weights, layer_id): + # Only defined in granite. + bias = getattr(config, "attention_bias", False) + head_size = config.hidden_size // config.num_attention_heads + sizes = None + prefixes = None + + if config.model_type == "phi3": + base_layer = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv_proj", + weights=weights, + bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + ) + prefixes = ["qkv_proj"] + elif config.model_type == "baichuan": + prefix = f"{prefix}.W_pack" + base_layer = TensorParallelColumnLinear.load_qkv( + config, + prefix=prefix, + weights=weights, + bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + ) + prefixes = [prefix] + else: + prefixes = ["q_proj", "k_proj", "v_proj"] + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] + base_layer = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=bias, + ) + + return TensorParallelMultiAdapterLinear.load( + base_layer=base_layer, + layer_id=layer_id, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + + +@contextmanager +def no_fp8(weights: Weights): + """De-activate fp8 auto conversion for the duration of this context manager""" + weights_loader = weights.weights_loader + if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8: + weights_loader = HybridFP8UnquantLoader( + weights_loader.activation_scale_ub, to_fp8=False + ) + + with weights.use_loader(weights_loader): + yield + + +class FlashLlamaAttention(torch.nn.Module): + def __init__( + self, + index: int, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + # Setting defaults for baichuan custom config which doesn't apply them. + config.rope_theta = getattr(config, "rope_theta", 10000) + config.num_key_value_heads = getattr( + config, "num_key_value_heads", config.num_attention_heads + ) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + if config.num_key_value_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights, index) + self.index = index + + o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + ): + qkv = self.query_key_value(hidden_states, adapter_data) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) + + +class Phi3MoE(nn.Module): + def __init__( + self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights + ): + super().__init__() + + # gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.moe = moe_layer_cls( + prefix=f"{prefix}.experts", + n_experts=config.num_local_experts, + n_expert_group=None, + renormalize=True, + topk=config.num_experts_per_tok, + topk_group=None, + weights=weights, + gate_proj_name="w1", + up_proj_name="w3", + down_proj_name="w2", + ) + + self.process_group = weights.process_group + + def forward(self, x, adapter_data) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(x) + out = self.moe(x, gating_output=router_logits) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class LlamaMLP(nn.Module): + def __init__(self, prefix, config, weights, index): + super().__init__() + self.hidden_act = config.hidden_act + self.act = ( + ACT2FN[self.hidden_act] + if "gelu" not in self.hidden_act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" + if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none" + ), + ) + ) + prefixes = None + sizes = None + + # Fuse gate and up proj + bias = getattr(config, "mlp_bias", False) + if config.model_type == "phi3": + gate_up_proj = TensorParallelColumnLinear.load_gate_up( + config, + prefix=f"{prefix}.gate_up_proj", + weights=weights, + bias=bias, + ) + else: + prefixes = ["gate_proj", "up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=bias, + ) + + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=bias, + ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) + + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + self.hidden_size = config.hidden_size + + def forward(self, hidden_states, adapter_data): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 + and hidden_states.shape[0] == 1 + and not self.quantize + and self.hidden_size + != 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed. + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu( + self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 + ) + return self.down_proj(out, adapter_data) + else: + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) + + +class FlashLlamaLayer(nn.Module): + def __init__(self, index, prefix, config, weights): + super().__init__() + + with no_fp8(weights): + self.self_attn = FlashLlamaAttention( + index=index, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if config.model_type == "phimoe": + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + self.dense = Phi3MoE( + f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights + ) + # with moe the layernorms are are not rmsnorms and they have bias + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + else: + self.dense = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, index=index + ) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + cross_attention_states, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.dense(normed_attn_res_output, adapter_data) + + return mlp_output, attn_res + + +class FlashLlamaModel(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + # Skip fp8 quant for first and last layers + self.layers = nn.ModuleList() + self.cross_attention_layers = getattr(config, "cross_attention_layers", []) + with no_fp8(weights): + self.layers.append( + FlashLlamaLayer( + index=0, + prefix=( + "model.layers.0" if not prefix else f"{prefix}.model.layers.0" + ), + config=config, + weights=weights, + ) + ) + + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1): + if layer_id in self.cross_attention_layers: + from text_generation_server.models.custom_modeling.mllama import ( + FlashLlamaCrossLayer, + ) + + self.layers.append( + FlashLlamaCrossLayer( + index=layer_id, + prefix=( + f"model.layers.{layer_id}" + if not prefix + else f"{prefix}.model.layers.{layer_id}" + ), + config=config, + weights=weights, + ) + ) + else: + self.layers.append( + FlashLlamaLayer( + index=layer_id, + prefix=( + f"model.layers.{layer_id}" + if not prefix + else f"{prefix}.model.layers.{layer_id}" + ), + config=config, + weights=weights, + ) + ) + + with no_fp8(weights): + last_layer_id = config.num_hidden_layers - 1 + self.layers.append( + FlashLlamaLayer( + index=last_layer_id, + prefix=( + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}.model.layers.{last_layer_id}" + ), + config=config, + weights=weights, + ) + ) + + self.norm = FastRMSNorm.load( + prefix="model.norm" if not prefix else f"{prefix}.model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + adapter_data, + cross_attention_states=None, + ) -> torch.Tensor: + hidden_states = inputs_embeds + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + adapter_data, + cross_attention_states, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashLlamaForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + with no_fp8(weights): + self.embed_tokens = TensorParallelEmbedding( + prefix=( + "model.embed_tokens" + if not prefix + else f"{prefix}.model.embed_tokens" + ), + weights=weights, + ) + self.model = FlashLlamaModel(prefix, config, weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + with no_fp8(weights): + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states=None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + adapter_data=adapter_data, + cross_attention_states=cross_attention_states, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py new file mode 100644 index 000000000..341a23524 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -0,0 +1,535 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) + + +if SYSTEM == "rocm": + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + + +class MistralConfig(PretrainedConfig): + model_type = "mistral" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class MistralAttention(torch.nn.Module): + def __init__(self, prefix: str, config, weights, layer_id): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + self.query_key_value = TensorParallelMultiAdapterLinear.load( + query_key_value, + layer_id, + ["q_proj", "k_proj", "v_proj"], + sizes=[ + self.head_size * config.num_attention_heads, + self.head_size * config.num_key_value_heads, + self.head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) + + o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + layer_id, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + adapter_data, + ): + qkv = self.query_key_value(hidden_states, adapter_data) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + seqlen, + block_tables, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) + + +class MistralMLP(nn.Module): + def __init__(self, prefix: str, config, weights, layer_id): + super().__init__() + self.hidden_act = config.hidden_act + self.act = ( + ACT2FN[self.hidden_act] + if "gelu" not in self.hidden_act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" + if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none" + ), + ) + ) + # Fuse gate and up proj + gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + layer_id, + ["gate_proj", "up_proj"], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + layer_id, + "down_proj", + process_group=weights.process_group, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + def forward(self, hidden_states, adapter_data): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu( + self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 + ) + return self.down_proj(out, adapter_data) + else: + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) + + +class MistralLayer(nn.Module): + def __init__(self, prefix: str, config, weights, layer_id): + super().__init__() + self.self_attn = MistralAttention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = MistralMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + adapter_data, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + adapter_data, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output, adapter_data) + + return mlp_output, attn_res + + +class MistralModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.layers = nn.ModuleList( + [ + MistralLayer( + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, + layer_id=layer_id, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + adapter_data: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, true_max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + adapter_data, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class FlashMistralForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" + super().__init__() + self.embed_tokens = TensorParallelEmbedding( + prefix=( + f"{name}.embed_tokens" + if not prefix + else f"{prefix}.{name}.embed_tokens" + ), + weights=weights, + ) + self.model = MistralModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) + self.lm_head = SpeculativeHead.load( + config, + # TODO dirty hack for idefics2. + prefix=( + "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head" + ), + weights=weights, + ) + self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + seqlen = seqlen.clamp(max=self.max_past_tensor) + + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s, + prefill_cache_indices, + adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py new file mode 100644 index 000000000..5836d30af --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -0,0 +1,542 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Type + +import torch +import torch.distributed +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + Seqlen, + attention, + paged_attention, + reshape_and_cache, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.utils.weights import UnquantizedWeight + + +class MixtralConfig(PretrainedConfig): + model_type = "mixtral" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=None, + num_experts_per_tok=2, + num_local_experts=8, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def promote_scalar(x: torch.Tensor) -> torch.Tensor: + return x.view(1) if len(x.size()) == 0 else x + + +def load_attention(config, prefix: str, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + ) + + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear(get_linear(weight, bias=None)) + + +def _load_experts(config, prefix: str, mat, weights): + if config.quantize is not None: + raise NotImplementedError("Mixtral does not support weight quantization yet.") + + assert mat in ["w1", "w2", "w3"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.intermediate_size % world_size == 0 + ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.num_local_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) + + for i in range(config.num_local_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "w2": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +class MixtralAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + seqlen, + block_tables, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +@torch.jit.script +def select_experts(gate_logits: torch.Tensor, top_k: int): + # all_probs: (sequence_length, n_experts) and upcast for softmax + all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float) + # weights, selected_experts: (sequence_length, top-k) + weights, selected_experts = torch.topk(all_probs, top_k, dim=-1) + weights /= weights.sum(dim=-1, keepdim=True) + weights = weights.view(-1) + selected_experts = selected_experts.view(-1) + + return selected_experts, weights + + +@torch.jit.script +def round_up(x: torch.Tensor, value: int): + return torch.div(x + (value - 1), value, rounding_mode="trunc") * value + + +class MixtralMoE(nn.Module): + def __init__( + self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights + ): + super().__init__() + + # gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.moe = moe_layer_cls( + n_expert_group=None, + n_experts=config.num_local_experts, + prefix=f"{prefix}.experts", + renormalize=True, + topk=config.num_experts_per_tok, + topk_group=None, + weights=weights, + gate_proj_name="w1", + up_proj_name="w3", + down_proj_name="w2", + ) + assert isinstance(self.moe, MoELayer) + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(x) + out = self.moe(x, gating_output=router_logits) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class MixtralLayer(nn.Module): + def __init__(self, prefix: str, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = MixtralAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + + moe_layer_cls = ( + SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer + ) + self.moe = MixtralMoE( + f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + moe_output = self.moe(normed_attn_res_output) + + return moe_output, attn_res + + +class MixtralModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=( + "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" + ), + weights=weights, + ) + + self.layers = nn.ModuleList( + [ + MixtralLayer( + "model" if not prefix else f"{prefix}.model", + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix="model.norm" if not prefix else f"{prefix}.model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, true_max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashMixtralForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + self.model = MixtralModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + seqlen = seqlen.clamp(max=self.max_past_tensor) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py new file mode 100644 index 000000000..ad4e382fe --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -0,0 +1,425 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig +from typing import Optional, List, Tuple +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.utils.weights import UnquantizedWeight + + +class GPTNeoXConfig(TransformersGPTNeoXConfig): + attribute_map = { + "num_key_value_heads": "num_attention_heads", + } + + +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_weights_row(prefix) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias) + if config.use_parallel_residual: + return linear + else: + return TensorParallelRowLinear(linear, process_group=weights.process_group) + + +def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): + weight = weights.get_multi_weights_col([prefix], dim=0) + if isinstance(weight, UnquantizedWeight): + # Only on non quantized versions + weight.weight = ( + weight.weight.view( + num_heads, + 3, + head_size, + hidden_size, + ) + .permute(1, 0, 2, 3) + .reshape(-1, hidden_size) + ) + + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) + + linear = get_linear(weight, bias) + if config.use_parallel_residual: + return linear + else: + return TensorParallelColumnLinear(linear) + + +class FlashNeoxAttention(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + num_heads = config.num_attention_heads + hidden_size = config.hidden_size + + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + self.rotary_dim = int(config.rotary_pct * self.head_size) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.rotary_dim, + base=config.rotary_emb_base, + device=weights.device, + ) + + self.softmax_scale = self.head_size ** (-0.5) + + self.query_key_value = load_qkv( + config, + prefix=f"{prefix}.query_key_value", + weights=weights, + num_heads=self.num_heads, + head_size=self.head_size, + hidden_size=self.hidden_size, + ) + self.dense = load_row( + config, prefix=f"{prefix}.dense", weights=weights, bias=True + ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + + # Compute rotary embeddings on rotary_ndims + query_rot = qkv[:, 0][..., : self.rotary_dim] + query_pass = qkv[:, 0][..., self.rotary_dim :] + key_rot = qkv[:, 1][..., : self.rotary_dim] + key_pass = qkv[:, 1][..., self.rotary_dim :] + + # Inplace rotary + self.rotary_emb(query_rot, key_rot, cos, sin) + qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) + qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) + + reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + qkv[:, 0], + kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], + kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + qkv[:, 0], + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) + + +class FlashMLP(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = load_row( + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ) + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class FlashNeoXLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + + layer_norm_eps = config.layer_norm_eps + + prefix = f"gpt_neox.layers.{layer_id}" + + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps + ) + self.post_attention_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=layer_norm_eps, + ) + self.attention = FlashNeoxAttention( + config, prefix=f"{prefix}.attention", weights=weights + ) + + self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights) + self.process_group = weights.process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + if self.use_parallel_residual: + ln1_hidden_states, _ = self.input_layernorm(hidden_states) + + attn_output = self.attention( + ln1_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(ln2_hidden_states) + intermediate = mlp_output + attn_output + + if self.process_group.size() > 1: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate + hidden_states, None + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.attention( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashGPTNeoXPreTrainedModel(PreTrainedModel): + config_class = GPTNeoXConfig + base_model_prefix = "gpt_neox" + supports_gradient_checkpointing = False + _no_split_modules = None + + +class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): + def __init__(self, prefix: str, config, weights): + super().__init__(config) + self.config = config + + self.embed_in = TensorParallelEmbedding( + prefix=f"{prefix}.embed_in", weights=weights + ) + + self.layers = nn.ModuleList( + [ + FlashNeoXLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = FastLayerNorm.load( + prefix=f"{prefix}.final_layer_norm", + weights=weights, + eps=config.layer_norm_eps, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].attention.head_size + self.num_heads = self.layers[0].attention.num_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_in(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.final_layer_norm(hidden_states, residual) + + return hidden_states + + +class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): + def __init__(self, prefix, config, weights): + super().__init__(config) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights) + + self.embed_out = SpeculativeHead.load( + config, prefix="embed_out", weights=weights + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.gpt_neox( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.embed_out(hidden_states) + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py new file mode 100644 index 000000000..0024f2bb9 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed +from torch import nn +from typing import Optional, List, Tuple + +from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear +from text_generation_server.layers.attention import Seqlen +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) + + +class PaliGemmaForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + self.vision_tower = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ) + self.post_vision_tower_layernorm = nn.LayerNorm.load( + prefix="vision_tower.vision_model.post_layernorm", + weights=weights, + eps=config.vision_config.layer_norm_eps, + ) + + self.multi_modal_projector = TensorParallelColumnLinear.load( + config, + prefix="multi_modal_projector.linear", + weights=weights, + bias=True, + ) + + self.vocab_size = config.text_config.vocab_size + self.config = config + + text_config = config.text_config + text_config.speculator = config.speculator + text_config.quantize = config.quantize + self.text_model = load_text_model( + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, + ) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + # Unused here + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + inputs_embeds = self.text_model.embed_tokens(input_ids) + # TODO This is odd but apparently pali gemma position ids start at 1. + if cu_seqlen_prefill is not None: + max_s += 1 + position_ids += 1 + + if pixel_values is not None: + pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) + image_outputs = self.vision_tower(pixel_values) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) + + # mask where image or padding tokens + mask = input_ids == self.config.image_token_index + + # insert image features into input embeddings + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + ) + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py new file mode 100644 index 000000000..2a0dc6066 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -0,0 +1,418 @@ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) + + +class PhiConfig(PretrainedConfig): + def __init__( + self, + vocab_size=51200, + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="gelu_fast", # llama uses silu + layer_norm_eps=1e-05, # rms in llama, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + resid_pdrop=0.1, # llama doesn't have this + partial_rotary_factor=0.5, # important difference between llama and phi + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.rope_theta = rope_theta + self.resid_pdrop = resid_pdrop + self.partial_rotary_factor = partial_rotary_factor + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +# this is the same as llama except for Phi uses bias=True +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + ) + + if config.quantize not in ["gptq", "awq", "marlin"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + # this is the same as llama except for Phi uses bias=True + return TensorParallelColumnLinear(get_linear(weight, bias=True)) + + +class FlashPhiAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.softmax_scale = self.head_size**-0.5 + self.rotary_dim = int(config.partial_rotary_factor * self.head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.rotary_dim, + base=config.rope_theta, + device=weights.device, + ) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + # in llama the dense layer is called "o_proj" and has bias=False + self.dense = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.dense", + weights=weights, + bias=True, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + # Compute query, key, value and split + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + + # Reshape query and key for rotary embeddings + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + # NOTE: this is the main difference between Llama and Phi + # in llama the rotary embeddings are applied to the whole query and key. + # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions + # + # Apply partial positional embeddings in place + self.rotary_emb( + query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin + ) + + # Reshape key and value and cache + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) + + +class PhiMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + # llama weights are up_proj and down_proj and bias=False + self.up_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.fc1", + weights=weights, + bias=True, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.fc2", + weights=weights, + bias=True, + ) + + def forward(self, hidden_states): + # NOTE: Llama requires the gate up states to an intermediate size + # Phi does not and we can avoid the `view` operation + return self.down_proj(self.act(self.up_proj(hidden_states))) + + +class FlashPhiLayer(nn.Module): + def __init__(self, prefix: str, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + self.self_attn = FlashPhiAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.resid_dropout = torch.nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + hidden_states, res = self.input_layernorm(hidden_states, residual) + # Self Attention + attn_output = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states = self.resid_dropout(attn_output).add( + self.resid_dropout(self.mlp(hidden_states)) + ) + + return hidden_states, res + + +class FlashPhiModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + FlashPhiLayer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + self.norm = FastLayerNorm.load( + prefix="model.final_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashPhiForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashPhiModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + + return self.lm_head(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py new file mode 100644 index 000000000..bb585cc4b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py @@ -0,0 +1,254 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Phi-MoE model.""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3.5-MoE-instruct": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json", +} + + +class PhiMoEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PhiMoEModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 6400): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and + `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_scale` must + be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of + the attention head size and the `original_max_position_embeddings` must be an integer. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `262144`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 16): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.0): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.01): + Amount of noise to add to the router. + + ```python + >>> from transformers import PhiMoEModel, PhiMoEConfig + + >>> # Initializing a Phi-3 style configuration + >>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct") + + >>> # Initializing a model from the configuration + >>> model = PhiMoEModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phimoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=4096, + intermediate_size=6400, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + rope_scaling=None, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.01, + input_jitter_noise=0.0, + attention_bias=False, + lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + self.input_jitter_noise = input_jitter_noise + + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor`, `long_factor`, " + f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None) + rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None) + original_max_position_embeddings = self.rope_scaling.get( + "original_max_position_embeddings", None + ) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}" + ) + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if ( + not len(rope_scaling_short_factor) + == self.hidden_size // self.num_attention_heads // 2 + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if ( + not len(rope_scaling_long_factor) + == self.hidden_size // self.num_attention_heads // 2 + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) + if not isinstance(rope_scaling_short_mscale, (int, float)): + raise ValueError( + f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}" + ) + if not isinstance(rope_scaling_long_mscale, (int, float)): + raise ValueError( + f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}" + ) + if not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}" + ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py new file mode 100644 index 000000000..02c788d3e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -0,0 +1,394 @@ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + + +class Qwen2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + seqlen, + block_tables, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class Qwen2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class Qwen2Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + self.self_attn = Qwen2Attention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class Qwen2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + prefix = f"{prefix}.model" if prefix else "model" + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + Qwen2Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, true_max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class Qwen2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + self.model = Qwen2Model(prefix, config, weights) + + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix}.{suffix}" if prefix else suffix, + weights=weights, + ) + self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + seqlen = seqlen.clamp(max=self.max_past_tensor) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py new file mode 100644 index 000000000..6671d85e2 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -0,0 +1,699 @@ +from typing import List, Optional, Tuple + +import torch +import torch.distributed +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel +from text_generation_server.layers import ( + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.layernorm import FastLayerNorm +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, + Seqlen, +) + + +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_weights_row(prefix) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias) + if config.parallel_attn: + return linear + else: + return TensorParallelRowLinear(linear, process_group=weights.process_group) + + +class RWConfig(PretrainedConfig): + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "num_key_value_heads": "n_head_kv", + } + + def __init__( + self, + model_type="RefinedWeb", + vocab_size=250880, + hidden_size=64, + num_hidden_layers=None, + num_attention_heads=None, + num_ln_in_prallel_attention=None, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + num_kv_heads=None, + multi_query=False, + alibi=False, + new_decoder_architecture=None, + bias=False, + parallel_attn=False, + rope_theta=10_000.0, + **kwargs, + ): + if alibi: + raise NotImplementedError( + "alibi is not supported by this version of the model" + ) + + self.model_type = model_type + self.alibi = False + self.rotary = True + self.rope_theta = rope_theta + + self.vocab_size = vocab_size + # Backward compatibility with n_embed kwarg + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = ( + num_hidden_layers + if num_hidden_layers is not None + else kwargs.pop("n_layer", 2) + ) + self.n_head = ( + num_attention_heads + if num_attention_heads is not None + else kwargs.pop("n_head", 8) + ) + self.layer_norm_epsilon = layer_norm_epsilon + self.num_ln_in_parallel_attn = num_ln_in_prallel_attention + self.initializer_range = initializer_range + self.use_cache = use_cache + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.bias = bias + self.parallel_attn = parallel_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + if num_kv_heads is not None: + self.n_head_kv = num_kv_heads + else: + old_n_head_kv = kwargs.pop("n_head_kv", None) + if old_n_head_kv is not None: + self.n_head_kv = old_n_head_kv + else: + self.n_head_kv = 1 if multi_query else self.n_head + + if new_decoder_architecture is not None: + self.new_decoder_architecture = new_decoder_architecture + elif model_type == "RefinedWeb": + self.new_decoder_architecture = True + else: + self.new_decoder_architecture = False + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class FlashRWAttention(torch.nn.Module): + def __init__( + self, + config, + prefix: str, + weights, + ): + super().__init__() + self.num_heads = config.n_head + self.num_heads_kv = config.n_head_kv + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + self.rope_theta = config.rope_theta + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=self.rope_theta, + device=weights.device, + ) + self.softmax_scale = self.head_size ** (-0.5) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.query_key_value = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.query_key_value", + weights=weights, + bias=config.bias, + ) + self.dense = load_row( + config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias + ) + + if self.num_heads_kv == 1: + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) + else: + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + qkv = self.query_key_value(hidden_states) + + # Split query from key_value + query, kv = qkv.split( + [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], + dim=1, + ) + + # Prepare query and key_value for indexing + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_heads_kv, self.head_size) + + # Inplace rotary + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) + + +class FlashRWLargeAttention(torch.nn.Module): + def __init__( + self, + config, + prefix: str, + weights, + ): + super().__init__() + + hidden_size = config.hidden_size + num_heads = config.n_head + # num_heads_kv = config.n_head_kv + num_groups = config.n_head_kv + + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + self.num_groups = num_groups + self.rope_theta = config.rope_theta + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=self.rope_theta, + device=weights.device, + ) + self.softmax_scale = self.head_size ** (-0.5) + + # self.num_groups = num_heads // (num_heads_kv * 2) + self.num_heads = num_heads // self.num_groups + # self.num_heads_kv = num_heads_kv // self.num_groups + process_group = weights.process_group + + if process_group.size() > self.num_groups: + raise NotImplementedError( + "Tensor Parallelism is not implemented for world_size > n groups" + ) + if self.num_groups % process_group.size() != 0: + raise NotImplementedError( + f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}" + ) + + self.num_groups = self.num_groups // process_group.size() + + self.query_key_value = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.query_key_value", + weights=weights, + bias=config.bias, + ) + self.dense = load_row( + config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias + ) + + self.kv_head_mapping = torch.arange( + 0, self.num_groups, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_heads) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) + + # Split on group dimension + query, kv = qkv.split( + [self.num_heads, 2], + dim=2, + ) + # Merge groups and heads + query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size) + + # Inplace rotary + self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) + + reshape_and_cache( + kv[:, :, 0].contiguous(), + kv[:, :, 1].contiguous(), + kv_cache[0], + kv_cache[1], + slots, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.dense( + attn_output.view(-1, self.num_groups * self.num_heads * self.head_size) + ) + + +class FlashMLP(nn.Module): + def __init__(self, config, prefix: str, weights): + super().__init__() + self.act = torch.nn.functional.gelu + + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias + ) + self.dense_4h_to_h = load_row( + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias + ) + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class FlashRWLayer(nn.Module): + def __init__( + self, + layer_id, + prefix: str, + config, + weights, + ): + super().__init__() + + parallel_attn = config.parallel_attn + self.parallel_attn = parallel_attn + + prefix = f"{prefix}.h.{layer_id}" + + # NOTE: Falcon 180B uses the ln_attn prefix + ln_prefix = "input_layernorm" + if config.num_hidden_layers == 80: + ln_prefix = "ln_attn" + + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.{ln_prefix}", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.self_attention = FlashRWAttention( + config, + prefix=f"{prefix}.self_attention", + weights=weights, + ) + self.post_attention_layernorm = ( + FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + if not parallel_attn + else None + ) + + self.mlp = FlashMLP( + config, + prefix=f"{prefix}.mlp", + weights=weights, + ) + + self.process_group = weights.process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + if self.parallel_attn: + ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_output = self.self_attention( + ln_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + mlp_output = self.mlp(ln_hidden_states) + intermediate = mlp_output + attn_output + + if self.process_group.size() > 1: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate, residual + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attention( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + if self.post_attention_layernorm is not None: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashRWLayerNorm(nn.Module): + def __init__(self, config, prefix: str, weights): + super().__init__() + # Falcon2 includes the number of layer norms in the config + # in the case no number of layer norms is provided, we default to 1 + self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1) + + # Falcon 180B uses the ln_attn prefix and has 2 layer norms + if config.num_hidden_layers == 80: + self.num_ln = 2 + + if self.num_ln == 1: + self.input_ln = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + elif self.num_ln == 2: + self.ln_attn = FastLayerNorm.load( + prefix=f"{prefix}.ln_attn", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.ln_mlp = FastLayerNorm.load( + prefix=f"{prefix}.ln_mlp", + weights=weights, + eps=config.layer_norm_epsilon, + ) + else: + raise ValueError("Number of layer norms can either be 1 or 2.") + + def forward( + self, + hidden_states, + residual, + ): + if self.num_ln == 1: + ln_hidden_states, residual = self.input_ln(hidden_states, residual) + return ln_hidden_states, ln_hidden_states, residual + elif self.num_ln == 2: + ln_attn, residual = self.ln_attn(hidden_states, residual) + ln_mlp, _ = self.ln_mlp(residual) + return ln_attn, ln_mlp, residual + + +class FlashRWLargeLayer(nn.Module): + def __init__(self, layer_id, prefix: str, config, weights): + super().__init__() + prefix = f"{prefix}.h.{layer_id}" + + self.ln_layer = FlashRWLayerNorm(config, prefix, weights) + + self.self_attention = FlashRWLargeAttention( + config, + prefix=f"{prefix}.self_attention", + weights=weights, + ) + assert config.parallel_attn, "This version doesn't support non parallel_attn" + + self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights) + + self.process_group = weights.process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + # Layer norm. + ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual) + + # Self attention. + attn_output = self.self_attention( + ln_attn, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + # MLP. + mlp_output = self.mlp(ln_mlp) + + intermediate = attn_output + mlp_output + + if self.process_group.size() > 1: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate, residual + + +class FlashRWPreTrainedModel(PreTrainedModel): + config_class = RWConfig + + +class FlashRWModel(FlashRWPreTrainedModel): + def __init__(self, prefix: str, config, weights): + super().__init__(config) + self.config = config + + self.word_embeddings = TensorParallelEmbedding( + prefix=f"{prefix}.word_embeddings", weights=weights + ) + + if config.new_decoder_architecture: + self.h = nn.ModuleList( + [ + FlashRWLargeLayer(layer_id, prefix, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.cache_size = self.h[0].self_attention.num_groups + else: + self.h = nn.ModuleList( + [ + FlashRWLayer(layer_id, prefix, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.cache_size = self.h[0].self_attention.num_heads_kv + + self.ln_f = FastLayerNorm.load( + prefix=f"{prefix}.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.head_size = self.h[0].self_attention.head_size + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.word_embeddings(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.h): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.ln_f(hidden_states, residual) + + return hidden_states + + +class FlashRWForCausalLM(FlashRWPreTrainedModel): + def __init__(self, prefix: str, config, weights): + super().__init__(config) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.transformer = FlashRWModel(prefix, config, weights) + + self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.transformer( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py new file mode 100644 index 000000000..43eb9687f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -0,0 +1,508 @@ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + SpeculativeHead, + TensorParallelEmbedding, + get_linear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.gptq import GPTQWeightsLoader +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) + + +def load_multi_mqa( + config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size +): + if config.quantize == "gptq": + return _load_multi_mqa_gptq( + config, prefix, weights, bias, head_size, num_heads, hidden_size + ) + elif config.quantize == "marlin": + raise RuntimeError( + "santacoder models with marlin quantization are not yet supported" + ) + else: + return _load_multi_mqa( + config, prefix, weights, bias, head_size, num_heads, hidden_size + ) + + +def _load_multi_mqa_gptq( + config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size +): + from text_generation_server.layers.gptq import GPTQWeight + + if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose: + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + slice_ = weights._get_slice(f"{prefix}.c_attn.qweight") + shape = slice_.get_shape() + block_size = (shape[1] - 2 * head_size) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + assert (shape[1] - 2 * head_size) % world_size == 0 + q_tensor = slice_[:, start:stop] + kv_tensor = slice_[:, -2 * head_size :] + qweight = torch.cat([q_tensor, kv_tensor], dim=1) + qweight = qweight.to(device=weights.device) + + slice_ = weights._get_slice(f"{prefix}.c_attn.scales") + shape = slice_.get_shape() + block_size = (shape[1] - 2 * head_size) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + assert (shape[1] - 2 * head_size) % world_size == 0 + q_tensor = slice_[:, start:stop] + kv_tensor = slice_[:, -2 * head_size :] + scales = torch.cat([q_tensor, kv_tensor], dim=1) + scales = scales.to(device=weights.device) + + slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros") + shape = slice_.get_shape() + block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + assert 2 * head_size % (32 // 4) == 0 + q_tensor = slice_[:, start:stop] + kv_tensor = slice_[:, -2 * head_size * 4 // 32 :] + qzeros = torch.cat([q_tensor, kv_tensor], dim=1) + qzeros = qzeros.to(device=weights.device) + + loader = weights.weights_loader + assert isinstance(loader, GPTQWeightsLoader) + loader._get_gptq_params(weights) + if loader.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") + g_idx = g_idx.to(device=weights.device) + elif loader.quant_method == "awq": + g_idx = None + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + + from text_generation_server.layers.gptq import HAS_EXLLAMA + + weight = GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=loader.bits, + groupsize=loader.groupsize, + use_awq_kernel=loader.quantize == "awq", + use_exllama=HAS_EXLLAMA, + ) + + if bias: + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + shape = slice_.get_shape() + block_size = (shape[0] - 2 * head_size) // world_size + assert (shape[0] - 2 * head_size) % world_size == 0 + q_tensor = slice_[start:stop] + start = rank * block_size + stop = (rank + 1) * block_size + q_tensor = slice_[start:stop] + kv_tensor = slice_[-2 * head_size :] + bias = torch.cat([q_tensor, kv_tensor], dim=0) + bias = bias.to(device=weights.device) + + return TensorParallelColumnLinear(get_linear(weight, bias)) + else: + raise NotImplementedError("Gptq loading with santacoder is not implemented") + + +def _load_multi_mqa( + config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size +): + if any("c_attn" in k for k in weights.routing.keys()): + slice_ = weights._get_slice(f"{prefix}.c_attn.weight") + shape = slice_.get_shape() + world_size = weights.process_group.size() + rank = weights.process_group.rank() + if config.transpose: + block_size = (shape[1] - 2 * head_size) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + assert (shape[1] - 2 * head_size) % world_size == 0 + q_tensor = slice_[:, start:stop] + kv_tensor = slice_[:, -2 * head_size :] + weight = torch.cat([q_tensor, kv_tensor], dim=1).T + else: + block_size = (shape[0] - 2 * head_size) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + assert (shape[0] - 2 * head_size) % world_size == 0 + q_tensor = slice_[start:stop] + kv_tensor = slice_[-2 * head_size :] + weight = torch.cat([q_tensor, kv_tensor], dim=0) + if bias: + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + shape = slice_.get_shape() + block_size = (shape[0] - 2 * head_size) // world_size + assert (shape[0] - 2 * head_size) % world_size == 0 + start = rank * block_size + stop = (rank + 1) * block_size + q_tensor = slice_[start:stop] + kv_tensor = slice_[-2 * head_size :] + bias = torch.cat([q_tensor, kv_tensor], dim=0) + else: + if config.transpose: + w = [ + weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T, + weights.get_tensor(f"{prefix}.kv_attn.weight").T, + ] + weight = torch.cat(w, dim=0) + else: + w = [ + weights.get_sharded(f"{prefix}.q_attn.weight", dim=0), + weights.get_tensor(f"{prefix}.kv_attn.weight"), + ] + weight = torch.cat(w, dim=1) + + if bias: + b = [ + weights.get_sharded(f"{prefix}.q_attn.bias", dim=0), + weights.get_tensor(f"{prefix}.kv_attn.bias"), + ] + bias = torch.cat(b, dim=0) + else: + bias = None + + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + assert list(weight.shape) == [ + (num_heads + 2) * head_size, + hidden_size, + ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}" + if bias is not None: + bias = bias.to(dtype=weights.dtype).to(device=weights.device) + assert list(bias.shape) == [ + (num_heads + 2) * head_size + ], f"{weight.shape} != {[(num_heads + 2) * head_size]}" + return TensorParallelColumnLinear(get_linear(weight, bias)) + + +def load_col(config, prefix: str, weights, bias: bool): + if config.transpose: + weight = weights.get_sharded(f"{prefix}.weight", dim=1).T + else: + weight = weights.get_multi_weights_col([prefix], dim=0) + + if bias: + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + else: + bias = None + return TensorParallelColumnLinear(get_linear(weight, bias)) + + +def load_row(config, prefix: str, weights, bias: bool): + if config.transpose: + weight = weights.get_sharded(f"{prefix}.weight", dim=0).T + else: + weight = weights.get_weights_row(prefix) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return TensorParallelRowLinear( + get_linear(weight, bias), process_group=weights.process_group + ) + + +class FlashMQAttention(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + num_heads = config.num_attention_heads + hidden_size = config.hidden_size + + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.softmax_scale = self.head_size ** (-0.5) + + self.c_attn = load_multi_mqa( + config, + prefix=prefix, + weights=weights, + bias=True, + head_size=self.head_size, + hidden_size=hidden_size, + num_heads=self.num_heads, + ) + self.c_proj = load_row( + config, prefix=f"{prefix}.c_proj", weights=weights, bias=True + ) + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) + + def forward( + self, + hidden_states, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + qkv = self.c_attn(hidden_states) + + # Split query from key_value + query, key_value = qkv.split( + [self.head_size * self.num_heads, 2 * self.head_size], dim=1 + ) + + # Prepare query and key_value for indexing + query = query.view(-1, self.num_heads, self.head_size) + key_value = key_value.view(-1, 2, 1, self.head_size) + + reshape_and_cache( + key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], + seqlen, + block_tables, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.activation_function + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + self.c_fc = load_col( + config, prefix=f"{prefix}.c_fc", weights=weights, bias=True + ) + self.c_proj = load_row( + config, prefix=f"{prefix}.c_proj", weights=weights, bias=True + ) + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +class Block(nn.Module): + def __init__(self, prefix: str, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.h.{layer_id}" + self.ln_1 = FastLayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + self.ln_2 = FastLayerNorm.load( + prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon + ) + self.self_attn = FlashMQAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.mlp = MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, + hidden_states, + residual, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ): + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.self_attn( + hidden_states, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, residual = self.ln_2(hidden_states, residual) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashSantacoderModel(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.config = config + + self.process_group = weights.process_group + self.wte = TensorParallelEmbedding( + prefix=f"{prefix}.wte", + weights=weights, + reduce=False, + ) + self.wpe = TensorParallelEmbedding( + prefix=f"{prefix}.wpe", + weights=weights, + reduce=False, + ) + + self.layers = nn.ModuleList( + [ + Block( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.ln_f = FastLayerNorm.load( + prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + self.wpe(position_ids) + + if self.process_group.size() > 1: + torch.distributed.all_reduce(hidden_states, group=self.process_group) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.ln_f(hidden_states, residual) + + return hidden_states + + +class FlashSantacoderForCausalLM(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + config.transpose = config.architectures[0].startswith("GPT2") + self.model = FlashSantacoderModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, prefix=f"{prefix}.wte", weights=weights + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py new file mode 100644 index 000000000..4975cf225 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -0,0 +1,554 @@ +# coding=utf-8 +# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, + Seqlen, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.layernorm import ( + FastLayerNorm, + FastRMSNorm, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.utils.weights import UnquantizedWeight + + +class Starcoder2Config(PretrainedConfig): + model_type = "starcoder2" + + def __init__( + self, + vocab_size=49152, + hidden_size=3072, + intermediate_size=12288, + num_hidden_layers=30, + num_attention_heads=24, + num_key_value_heads=2, + mlp_type="default", + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=4096, + initializer_range=0.018042, + norm_type="layer_norm", + norm_epsilon=1e-5, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + rope_theta=10000.0, + sliding_window=None, + attention_dropout=0.0, + residual_dropout=0.0, + embedding_dropout=0.0, + use_bias: bool = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.use_bias = use_bias + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_type = mlp_type + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_type = norm_type + self.norm_epsilon = norm_epsilon + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.residual_dropout = residual_dropout + self.embedding_dropout = embedding_dropout + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=config.use_bias, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + ) + + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + if config.use_bias: + w = [ + weights.get_sharded(f"{p}.bias", dim=0) + for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + ] + bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) + else: + bias = None + + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) + + +class Starcoder2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=config.use_bias, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + seqlen, + block_tables, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class Starcoder2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.c_fc = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.c_fc", + weights=weights, + bias=config.use_bias, + ) + self.c_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=config.use_bias, + ) + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + return self.c_proj(hidden_states) + + +class Starcoder2GatedMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=config.use_bias, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=config.use_bias, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +STARCODER2_NORMALIZATION_CLASSES = { + "layer_norm": FastLayerNorm, + "rms_norm": FastRMSNorm, +} + +STARCODER2_MLP_CLASSES = { + "default": Starcoder2MLP, + "gated": Starcoder2GatedMLP, +} + + +class Starcoder2Layer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = Starcoder2Attention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + + self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon + ) + self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[ + config.norm_type + ].load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.norm_epsilon, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class Starcoder2Model(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + Starcoder2Layer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( + prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, true_max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashStarcoder2ForCausalLM(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Starcoder2Model(prefix, config, weights) + try: + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + except RuntimeError: + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix}.embed_tokens", + weights=weights, + ) + + self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + seqlen = seqlen.clamp(max=self.max_past_tensor) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py new file mode 100644 index 000000000..a829c3741 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -0,0 +1,839 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics2 model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +import math + +from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, +) +from text_generation_server.layers.attention import Seqlen +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_heads, + self.head_size * self.num_heads, + ], + dim=2, + ) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Idefics2VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights + ) + self.mlp = Idefics2VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Idefics2Encoder(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Idefics2EncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embeddings = Idefics2VisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = Idefics2Encoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ): + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to( + dtype=torch.bool, device=pixel_values.device + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + else: + patch_attention_mask = _prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + ) + + last_hidden_state = encoder_outputs + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Idefics2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.text_config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + def forward(self, hidden_states): + start_shape = hidden_states.shape[:-1] + gate_up_states = self.gate_up_proj(hidden_states) + intermediate_size = gate_up_states.shape[-1] // 2 + gate_up_states = gate_up_states.view(-1, 2, intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1] + ).view(*start_shape, -1) + + +class Idefics2RMSNorm(nn.Module): + def __init__(self, prefix, weights, eps): + """ + Idefics2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.weight"), requires_grad=False + ) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Idefics2PerceiverAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.layer_idx = None + self.hidden_size = config.text_config.hidden_size + self.num_heads = config.perceiver_config.resampler_n_heads + self.head_size = config.perceiver_config.resampler_head_dim + self.num_key_value_heads = config.perceiver_config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.attention_dropout = config.perceiver_config.attention_dropout + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=False, + ) + self.kv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False + ) + + self.is_causal = False + + def forward( + self, + latents: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = latents.size() + kv_seq_len = q_len + context.size()[1] + + hidden_states = torch.concat([context, latents], dim=-2) + query_states = self.q_proj(latents) + kv = self.kv(hidden_states) + key_states, value_states = kv.split( + [ + self.head_size * self.num_key_value_heads, + self.head_size * self.num_key_value_heads, + ], + dim=2, + ) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + bsz, kv_seq_len, self.num_key_value_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + bsz, kv_seq_len, self.num_key_value_heads, self.head_size + ).transpose(1, 2) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_size) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class Idefics2PerceiverLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.hidden_size = config.text_config.hidden_size + self.n_latents = config.perceiver_config.resampler_n_latents + self.depth = config.perceiver_config.resampler_depth + self.rms_norm_eps = config.text_config.rms_norm_eps + + self.input_latents_norm = Idefics2RMSNorm( + prefix=f"{prefix}.input_latents_norm", + weights=weights, + eps=self.rms_norm_eps, + ) + self.input_context_norm = Idefics2RMSNorm( + prefix=f"{prefix}.input_context_norm", + weights=weights, + eps=self.rms_norm_eps, + ) + self.self_attn = Idefics2PerceiverAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.post_attention_layernorm = Idefics2RMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=self.rms_norm_eps, + ) + self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + def forward( + self, + latents: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + Args: + latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + """ + residual = latents + + latents = self.input_latents_norm(latents) + context = self.input_context_norm(context) + + latents = self.self_attn( + latents=latents, + context=context, + attention_mask=attention_mask, + ) + latents = residual + latents + residual = latents + + latents = self.post_attention_layernorm(latents) + latents = self.mlp(latents) + latents = residual + latents + + return latents + + +class Idefics2PerceiverResampler(nn.Module): + def __init__(self, prefix, config, weights) -> None: + super().__init__() + self.hidden_size = config.text_config.hidden_size + self.hidden_act = config.perceiver_config.hidden_act + self.n_latents = config.perceiver_config.resampler_n_latents + self.depth = config.perceiver_config.resampler_depth + self.rms_norm_eps = config.text_config.rms_norm_eps + + # Create Latents for Perceiver + self.latents = weights.get_tensor(f"{prefix}.latents") + + # Create Transformer Blocks + self.layers = nn.ModuleList( + [ + Idefics2PerceiverLayer( + prefix=f"{prefix}.layers.{idx}", config=config, weights=weights + ) + for idx in range(self.depth) + ] + ) + self.norm = Idefics2RMSNorm( + prefix=f"{prefix}.norm", + weights=weights, + eps=config.text_config.rms_norm_eps, + ) + + def forward( + self, + context: torch.Tensor, + attention_mask, + ) -> torch.Tensor: + # seq embed -> bsz seq embed + latents = self.latents.unsqueeze(0).expand( + (context.shape[0], *self.latents.size()) + ) + + latent_attention_mask = torch.ones( + (attention_mask.size(0), latents.size(1)), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) + attention_mask = _prepare_4d_attention_mask( + attention_mask, latents.dtype, tgt_len=self.n_latents + ) + + compressed_context = latents + for perceiver_layer in self.layers: + compressed_context = perceiver_layer( + compressed_context, + context, + attention_mask=attention_mask, + ) + compressed_context = self.norm(compressed_context) + + return compressed_context + + +class Idefics2Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = Idefics2MLP( + prefix=f"{prefix}.modality_projection", config=config, weights=weights + ) + self.perceiver_resampler = Idefics2PerceiverResampler( + prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights + ) + + def forward(self, image_hidden_states, attention_mask): + image_hidden_states = self.modality_projection(image_hidden_states) + image_hidden_states = self.perceiver_resampler( + context=image_hidden_states, attention_mask=attention_mask + ) + return image_hidden_states + + +class Idefics2ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + self.image_seq_len = config.perceiver_config.resampler_n_latents + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + # mask = input_ids == self.config.image_token_index + mask = input_ids == self.config.image_token_id + # Let's pray we have enabled enough slots ! + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), + ) + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_hidden_states + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py new file mode 100644 index 000000000..a55658194 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py @@ -0,0 +1,326 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Idefics model configuration""" +import copy + +from transformers import PretrainedConfig + +IDEFICS_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "HuggingFaceM4/idefics-9b": "https://huggingface.co/HuggingFaceM4/idefics-9b/blob/main/config.json", + "HuggingFaceM4/idefics-80b": "https://huggingface.co/HuggingFaceM4/idefics-80b/blob/main/config.json", +} + + +class IdeficsVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an + Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Idefics-9B. + e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `hidden_size`) + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + intermediate_size (`int`, *optional*, defaults to 5120): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + image_num_channels (`int`, *optional*, defaults to `3`): + Number of image channels. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization + testing). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + model_type = "idefics" + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + embed_dim=768, + image_size=224, + intermediate_size=5120, + patch_size=14, + num_hidden_layers=32, + num_attention_heads=16, + num_channels=3, + hidden_act="gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + self.embed_dim = embed_dim + self.image_size = image_size + self.intermediate_size = intermediate_size + self.patch_size = patch_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + + super().__init__(**kwargs) + + +class IdeficsPerceiverConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an + Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Idefics-9B. + e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + use_resampler (`bool`, *optional*, defaults to `False`): + Whether or not to use the resampler + resampler_n_latents (`int`, *optional*, defaults to ): + Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). + resampler_depth (`int`, *optional*, defaults to 6): + Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). + resampler_n_heads (`int`, *optional*, defaults to 16): + Number of heads in each Transformer block (for multi-headed self-attention). + resampler_head_dim (`int`, *optional*, defaults to 96): + Dimensionality of each head projection in the Transformer block. + qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`): + Whether or not to use qk layer norms in perceiver + """ + + model_type = "idefics" + + def __init__( + self, + use_resampler=False, + resampler_n_latents=64, + resampler_depth=6, + resampler_n_heads=16, + resampler_head_dim=96, + qk_layer_norms_perceiver=False, + **kwargs, + ): + self.use_resampler = use_resampler + self.resampler_n_latents = resampler_n_latents + self.resampler_depth = resampler_depth + self.resampler_n_heads = resampler_n_heads + self.resampler_head_dim = resampler_head_dim + self.qk_layer_norms_perceiver = qk_layer_norms_perceiver + + super().__init__(**kwargs) + + +class IdeficsConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an + Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Idefics-9B. + e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + additional_vocab_size (`int`, *optional`, defaults to 0): + Additional vocabulary size of the model, typically for the special "" token. Additional vocab tokens + are always trainable whereas regular vocab tokens can be frozen or not. + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~IdeficsModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + alpha_initializer (`str`, *optional*, defaults to `"zeros"`): + Initialization type for the alphas. + alphas_initializer_range (`float`, *optional*, defaults to 0.0): + The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross + Attention. + alpha_type (`str`, *optional*, defaults to `"float"`): + Whether the gating alphas should be vectors or single floats. + rms_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0) + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1) + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2) + End of stream token id. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + cross_layer_interval (`int`, *optional*, default to 1) + Interval for cross attention (from text to image) layers. + qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k + freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers + freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`): + Exceptions to freezing text layers when `freeze_text_layers` is `True` + freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head + freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers + freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`): + Exceptions to freezing vision layers when `freeze_vision_layers` is `True` + use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler + vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict + perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict + Example: + ```python + >>> from transformers import IdeficsModel, IdeficsConfig + >>> # Initializing a Idefics idefics-9b style configuration + >>> configuration = IdeficsConfig() + >>> # Initializing a model from the idefics-9b style configuration + >>> model = IdeficsModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "idefics" + is_composition = True + + def __init__( + self, + vocab_size=32000, + additional_vocab_size=0, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + dropout=0.0, + hidden_act="silu", + initializer_range=0.02, + alpha_initializer="zeros", + alphas_initializer_range=0.0, + alpha_type="float", + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + cross_layer_interval=1, + qk_layer_norms=False, + freeze_text_layers=True, + freeze_text_module_exceptions=[], + freeze_lm_head=False, + freeze_vision_layers=True, + freeze_vision_module_exceptions=[], + use_resampler=False, + vision_config=None, + perceiver_config=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.additional_vocab_size = additional_vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.alpha_initializer = alpha_initializer + self.alphas_initializer_range = alphas_initializer_range + self.alpha_type = alpha_type + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.cross_layer_interval = cross_layer_interval + self.qk_layer_norms = qk_layer_norms + self.freeze_vision_layers = freeze_vision_layers + + self.freeze_text_layers = freeze_text_layers + self.freeze_text_module_exceptions = freeze_text_module_exceptions + self.freeze_vision_module_exceptions = freeze_vision_module_exceptions + self.freeze_lm_head = freeze_lm_head + + self.use_resampler = use_resampler + + if perceiver_config is None: + self.perceiver_config = IdeficsPerceiverConfig() + elif isinstance(perceiver_config, dict): + self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config) + elif isinstance(perceiver_config, IdeficsPerceiverConfig): + self.perceiver_config = perceiver_config + + if vision_config is None: + self.vision_config = IdeficsVisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = IdeficsVisionConfig(**vision_config) + elif isinstance(vision_config, IdeficsVisionConfig): + self.vision_config = vision_config + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since + # PretrainedConfig.from_dict first instantiates the class with the config dict and only then + # updates the config object with `kwargs` from from_pretrained, so during the instantiation + # of this object many attributes have default values and haven't yet been overridden. + # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run. + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["vision_config"] = self.vision_config.to_dict() + output["perceiver_config"] = self.perceiver_config.to_dict() + output["model_type"] = self.__class__.model_type + + return output diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py new file mode 100644 index 000000000..afb8e1f9c --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py @@ -0,0 +1,297 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Idefics.""" + +from typing import Callable, Dict, List, Optional, Union, Iterable +import numpy as np + +from PIL import Image + +import transformers +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_transforms import ( + resize, + to_channel_dimension_format, + rescale, + normalize, +) +from transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_list_of_images, + to_numpy_array, + valid_images, +) +from io import BytesIO +import base64 +import requests +from transformers import TensorType, is_torch_available + + +IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073] +IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711] + + +def convert_to_rgb(image): + # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background + # for transparent images. The call to `alpha_composite` handles this case + if image.mode == "RGB": + return image + + image_rgba = image.convert("RGBA") + background = Image.new("RGBA", image_rgba.size, (255, 255, 255)) + alpha_composite = Image.alpha_composite(background, image_rgba) + alpha_composite = alpha_composite.convert("RGB") + return alpha_composite + + +class IdeficsImageProcessor(BaseImageProcessor): + r""" + Constructs a Idefics image processor. + Args: + image_size (`int`, *optional*, defaults to `224`): + Resize to image size + image_num_channels (`int`, *optional*, defaults to `3`): + Number of image channels. + image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + image_size: int = 224, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + image_num_channels: Optional[int] = 3, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.image_size = image_size + self.image_num_channels = image_num_channels + self.image_mean = image_mean + self.image_std = image_std + + def preprocess( + self, + images: ImageInput, + image_num_channels: Optional[int] = 3, + image_size: Optional[Dict[str, int]] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + transform: Callable = None, + **kwargs, + ) -> TensorType.PYTORCH: + """ + Preprocess a batch of images. + Args: + images (`ImageInput`): + A list of images to preprocess. + image_size (`int`, *optional*, defaults to `self.image_size`): + Resize to image size + image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`): + Number of image channels. + image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can + be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` + method. Can be overridden by the `image_std` parameter in the `preprocess` method. + transform (`Callable`, *optional*, defaults to `None`): + A custom transform function that accepts a single image can be passed for training. For example, + `torchvision.Compose` can be used to compose multiple transforms. If `None` - an inference mode is + assumed - and then a preset of inference-specific transforms will be applied to the images + Returns: + a PyTorch tensor of the processed images + """ + image_size = image_size if image_size is not None else self.image_size + image_num_channels = ( + image_num_channels + if image_num_channels is not None + else self.image_num_channels + ) + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + size = (image_size, image_size) + + if len(images) == 0: + return [] + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # For training a user needs to pass their own set of transforms as a Callable. + # For reference this is what was used in the original IDEFICS training: + # transform = transforms.Compose([ + # convert_to_rgb, + # transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), + # transforms.ToTensor(), + # transforms.Normalize(mean=image_mean, std=image_std), + # ]) + if transform is not None: + if not is_torch_available(): + raise ImportError("To pass in `transform` torch must be installed") + import torch + + images = [transform(x) for x in images] + return torch.stack(images) + + # for inference we do the exact transforms that were used to train IDEFICS + images = [convert_to_rgb(x) for x in images] + # further transforms expect numpy arrays + images = [to_numpy_array(x) for x in images] + images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images] + images = [self.rescale(image=image, scale=1 / 255) for image in images] + images = [self.normalize(x, mean=image_mean, std=image_std) for x in images] + images = [ + to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images + ] + # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available + images = BatchFeature( + data={"pixel_values": images}, tensor_type=TensorType.PYTORCH + )["pixel_values"] + + return images + + def fetch_images(self, image_url_or_urls: Union[str, List[str]]): + """ + Convert a single or a list of urls into the corresponding `PIL.Image` objects. + If a single url is passed, the return value will be a single object. If a list is passed a list of objects is + returned. + """ + headers = { + "User-Agent": ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0" + " Safari/537.36" + ) + } + if isinstance(image_url_or_urls, list): + return [self.fetch_images(x) for x in image_url_or_urls] + elif isinstance(image_url_or_urls, str): + image = image_url_or_urls + + if image.startswith("http://") or image.startswith("https://"): + response = requests.get( + image_url_or_urls, stream=True, headers=headers, timeout=(1, 5) + ) + response.raise_for_status() + content = response.content + elif image.startswith("data:"): + # https://stackoverflow.com/questions/17090571/is-there-a-way-to-set-background-image-as-a-base64-encoded-image + # data:image/png;base64,xxx + image = image.split(",")[-1] + content = base64.b64decode(image) + else: + raise ValueError(f"Unrecognized image {image}") + + try: + image = Image.open(BytesIO(content)) + # image.verify() + except Exception: + raise ValueError(f"Could not load image from url {image_url_or_urls}") + return image + else: + raise ValueError( + f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}" + ) + + def rescale( + self, + image: np.ndarray, + scale: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Rescale an image by a scale factor. image = image * scale. + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The rescaled image. + """ + # return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs) + # requires 4.32 + return rescale(image, scale=scale, data_format=data_format, **kwargs) + + def normalize( + self, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. + + Args: + image (`np.ndarray`): + Image to normalize. + mean (`float` or `Iterable[float]`): + Image mean to use for normalization. + std (`float` or `Iterable[float]`): + Image standard deviation to use for normalization. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The normalized image. + """ + # TODO 4.32 + return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs) + + +transformers.IdeficsImageProcessor = IdeficsImageProcessor diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py new file mode 100644 index 000000000..fc6becc4b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -0,0 +1,1556 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics model.""" +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + dataclass, +) +from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig +from text_generation_server.models.custom_modeling.idefics_vision import ( + IdeficsVisionTransformer, +) +from text_generation_server.models.custom_modeling.idefics_perceiver import ( + IdeficsPerceiverResampler, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + SpeculativeHead, + FastLinear, +) +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.utils.import_utils import SYSTEM +from loguru import logger + +if SYSTEM == "cuda": + import dropout_layer_norm +elif SYSTEM == "rocm": + from vllm._C import ops +else: + dropout_layer_norm = None + + +@dataclass +class BaseModelOutputWithPastImage(BaseModelOutputWithPast): + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +class CausalLMOutputWithPastImage(CausalLMOutputWithPast): + image_hidden_states: Optional[torch.FloatTensor] = None + + +# logger = logging.get_logger(__name__) + +# _CONFIG_FOR_DOC = "IdeficsConfig" + +# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [ +# "HuggingFaceM4/idefics-9b", +# "HuggingFaceM4/idefics-80b", +# # See all Idefics models at https://huggingface.co/models?filter=idefics +# ] + + +def expand_inputs_for_generation( + input_ids, + expand_size=1, + is_encoder_decoder=False, + attention_mask=None, + encoder_outputs=None, + **model_kwargs, +): + expanded_return_idx = ( + torch.arange(input_ids.shape[0]) + .view(-1, 1) + .repeat(1, expand_size) + .view(-1) + .to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select( + 0, expanded_return_idx + ) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select( + 0, expanded_return_idx + ) + model_kwargs["image_attention_mask"] = model_kwargs[ + "image_attention_mask" + ].index_select(0, expanded_return_idx) + model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select( + 0, expanded_return_idx + ) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + encoder_outputs["last_hidden_state"] = ( + encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) + ) + ) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + + +def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # must have this key set to at least None + model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None) + + # update past + if "past_key_values" in outputs: + model_kwargs["past"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past"] = outputs.past_buckets_states + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 + ) + + # update attention masks + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, + ) + if "image_attention_mask" in model_kwargs: + image_attention_mask = model_kwargs["image_attention_mask"] + last_mask = image_attention_mask[:, -1, :].unsqueeze(1) + model_kwargs["image_attention_mask"] = last_mask + + return model_kwargs + + +def prepare_inputs_for_generation(input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + + pixel_values = kwargs.get("pixel_values", None) + image_attention_mask = kwargs.get("image_attention_mask", None) + # if pixel_values is None or image_attention_mask is None: + # raise ValueError("pixel values and image attention mask cannot be None") + + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "pixel_values": pixel_values, + "image_attention_mask": image_attention_mask, + } + + +def freeze_model(model, module_exceptions=[]): + mapping = { + "LayerNorm": nn.LayerNorm, + "Linear": nn.Linear, + "Embedding": nn.Embedding, + } + module_exceptions_mapped = [mapping[m] for m in module_exceptions] + for module in model.modules(): + if module_exceptions and any( + [isinstance(module, t) for t in module_exceptions_mapped] + ): + module.requires_grad_( + True + ) # Explicitely setting it to true to avoid any mistakes + else: + module.requires_grad_(False) + return model + + +class IdeficsDecoupledPartialTPEmbedding(nn.Module): + def __init__( + self, + config, + weights, + ): + super().__init__() + self.num_embeddings = config.vocab_size + self.weight = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.additional_weight = nn.Parameter( + weights.get_tensor("model.embed_tokens.additional_embedding.weight") + ) + + def forward(self, input_ids): + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = torch.nn.functional.embedding( + input_ids_additional_vocab - self.num_embeddings, self.additional_weight + ) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = self.weight(input_ids) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + +class IdeficsDecoupledTensorParallelLinear(nn.Module): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, + then it will create `out_additional_features * in_features` additional parameters that are always trained. If + `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`. + """ + + def __init__( + self, + config, + weights, + ) -> None: + super().__init__() + self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights) + self.additional_fc = FastLinear.load( + config=config, + prefix="lm_head.additional_fc", + weights=weights, + bias=False, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output, speculative_logits = self.fc(input) + additional_features = self.additional_fc(input) + output = torch.cat((output, additional_features), -1) + + return output, speculative_logits + + def extra_repr(self) -> str: + """Overwriting `nn.Linear.extra_repr` to include new parameters.""" + return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format( + self.in_features, + self.out_features, + self.out_additional_features, + self.bias is not None, + self.partially_freeze, + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +class IdeficsRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + out = ipex.llm.functional.add_rms_norm( + residual, + hidden_states, + self.weight, + None, + self.variance_epsilon, + residual is not None, + ) + return out + elif hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + elif SYSTEM == "cuda": + # faster post attention rms norm + unwrap = False + if len(hidden_states.shape) > 2: + unwrap = True + shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, shape[-1]) + + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + if unwrap: + normed_hidden_states = normed_hidden_states.view(*shape) + + return normed_hidden_states + elif SYSTEM == "rocm": + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + unwrap = False + if len(hidden_states.shape) > 2: + unwrap = True + shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, shape[-1]) + + out = torch.empty_like(hidden_states) + ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + + if unwrap: + out = out.view(*shape) + + return out + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) + + +# this was adapted from LlamaMLP +class IdeficsMLP(nn.Module): + def __init__( + self, + config, + prefix, + weights, + ): + super().__init__() + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + shape = gate_up_states.shape + gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2) + return self.down_proj( + self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1] + ) + + +# this was adapted from LlamaAttention +class IdeficsAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config, + prefix, + weights, + qk_layer_norms: bool = False, + is_cross_attention: bool = False, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.dropout = config.dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.is_cross_attention = is_cross_attention + + # if not hasattr(nn.functional, "scaled_dot_product_attention"): + # raise ValueError("this model requires pytorch 2.0 or higher") + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads //= weights.process_group.size() + + if self.is_cross_attention: + # kv_input_dim = ( + # self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim + # ) + self.q_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q_proj", weights=weights, bias=False + ) + self.k_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k_proj", weights=weights, bias=False + ) + self.v_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v_proj", weights=weights, bias=False + ) + else: + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.o_proj", weights=weights, bias=False + ) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, dim=self.head_dim, base=10000.0, device=weights.device + ) + self.qk_layer_norms = qk_layer_norms + if self.qk_layer_norms: + self.q_layer_norm = IdeficsRMSNorm( + prefix=f"{prefix}.q_layer_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.k_layer_norm = IdeficsRMSNorm( + prefix=f"{prefix}.q_layer_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = self.is_cross_attention or key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + if is_cross_attention: + query_states = self.q_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim + ) # .transpose(1, 2) + query_states = query_states.transpose(1, 2) + ( + _, + kv_len, + _, + ) = ( + key_value_states.size() + ) # Note that, in this case, `kv_len` == `kv_seq_len` + key_states = ( + self.k_proj(key_value_states) + .view(bsz, kv_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(key_value_states) + .view(bsz, kv_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + else: + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + self.num_heads * self.head_dim, dim=2 + ) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ) # .transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_heads, self.head_dim + ) # . transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_heads, self.head_dim + ) # .transpose(1, 2) + kv_seq_len = q_len + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + max_s = max(kv_seq_len, q_len) + cos, sin = self.rotary_emb.get_cos_sin( + position_ids.view(-1), max_s, hidden_states.dtype + ) + + query_shape = query_states.shape + key_shape = key_states.shape + self.rotary_emb( + query_states.view(-1, *query_shape[2:]), + key_states.reshape(-1, *key_shape[2:]), + cos, + sin, + ) + + query_states = query_states.view(query_shape) + key_states = key_states.view(key_shape) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if self.qk_layer_norms: + query_states = self.q_layer_norm(query_states) + key_states = self.k_layer_norm(key_states) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_output = nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout, + ) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + attn_weights = None + if output_attentions: + logger.warning_once( + "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead" + ) + + return attn_output, attn_weights, past_key_value + + +# this was adapted from LlamaDecoderLayer +class IdeficsDecoderLayer(nn.Module): + def __init__(self, layer_id: int, config: IdeficsConfig, weights): + super().__init__() + self.process_group = weights.process_group + self.hidden_size = config.hidden_size + prefix = f"model.layers.{layer_id}" + self.self_attn = IdeficsAttention( + config=config, + prefix=f"{prefix}.self_attn", + weights=weights, + qk_layer_norms=False, + is_cross_attention=False, + ) + self.mlp = IdeficsMLP( + config=config, + prefix=f"{prefix}.mlp", + weights=weights, + ) + self.input_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.dropout = config.dropout + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class IdeficsGatedCrossAttentionLayer(nn.Module): + def __init__(self, layer_id, config: IdeficsConfig, weights): + super().__init__() + self.process_group = weights.process_group + self.hidden_size = config.hidden_size + prefix = f"model.gated_cross_attn_layers.{layer_id}" + self.cross_attn = IdeficsAttention( + config=config, + prefix=f"{prefix}.cross_attn", + weights=weights, + qk_layer_norms=True, + is_cross_attention=True, + ) + self.mlp = IdeficsMLP( + config=config, + prefix=f"{prefix}.mlp", + weights=weights, + ) + self.input_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.config = config.dropout + + self.act_cross_attn = nn.Tanh() + self.act_dense = nn.Tanh() + + self.alpha_cross_attn = nn.Parameter( + weights.get_tensor(f"{prefix}.alpha_cross_attn") + ) + self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense")) + + if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): + raise ValueError("Alpha parameters not initialized correctly!") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_hidden_states: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + no_images: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored + """ + if image_hidden_states is None: + raise ValueError( + "`image_hidden_states` is required for Idefics cross attention module which are visual features to be" + " conditioned on." + ) + + if past_key_value is not None: + raise NotImplementedError( + "Past key value states are not implemented for Idefics cross attention module." + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=image_hidden_states, + attention_mask=image_attention_mask, + output_attentions=output_attentions, + ) + # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) + # when there are no images the model is used in pure language mode + gate = 0 if no_images else 1 + hidden_states = ( + residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) + hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`IdeficsConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# @add_start_docstrings( +# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", +# LLAMA_START_DOCSTRING, +# ) +class IdeficsPreTrainedModel(PreTrainedModel): + config_class = IdeficsConfig + # base_model_prefix = "model" + # supports_gradient_checkpointing = True + # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] + + # def _init_weights(self, module): + # # important: this ported version of Idefics isn't meant for training from scratch - only + # # inference and fine-tuning - so the proper init weights code has been removed - the m4 code + # # base should be used for training from scratch and it contains the correct code. + # std = self.config.initializer_range + # if isinstance(module, nn.Linear): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.bias is not None: + # module.bias.data.zero_() + # elif isinstance(module, nn.Embedding): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.padding_idx is not None: + # module.weight.data[module.padding_idx].zero_() + + # def _set_gradient_checkpointing(self, module, value=False): + # if isinstance(module, IdeficsModel): + # module.gradient_checkpointing = value + + +# LLAMA_INPUTS_DOCSTRING = r""" +# Args: +# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): +# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide +# it. + +# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and +# [`PreTrainedTokenizer.__call__`] for details. + +# [What are input IDs?](../glossary#input-ids) +# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): +# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + +# - 1 for tokens that are **not masked**, +# - 0 for tokens that are **masked**. + +# [What are attention masks?](../glossary#attention-mask) + +# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and +# [`PreTrainedTokenizer.__call__`] for details. + +# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see +# `past_key_values`). + +# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] +# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more +# information on the default strategy. + +# - 1 indicates the head is **not masked**, +# - 0 indicates the head is **masked**. +# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): +# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, +# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) +# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): +# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape +# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape +# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + +# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention +# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + +# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that +# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all +# `decoder_input_ids` of shape `(batch_size, sequence_length)`. +# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): +# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This +# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the +# model's internal embedding lookup matrix. +# use_cache (`bool`, *optional*): +# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see +# `past_key_values`). +# output_attentions (`bool`, *optional*): +# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned +# tensors for more detail. +# output_hidden_states (`bool`, *optional*): +# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for +# more detail. +# return_dict (`bool`, *optional*): +# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +# """ + + +# @add_start_docstrings( +# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", +# LLAMA_START_DOCSTRING, +# ) +class IdeficsModel(IdeficsPreTrainedModel): + # """ + # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`] + + # Args: + # config: IdeficsConfig + # """ + + def __init__(self, config: IdeficsConfig, weights): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = IdeficsDecoupledPartialTPEmbedding( + config=config, + weights=weights, + ) + + self.image_size = config.vision_config.image_size + self.vision_config = config.vision_config + self.vision_model = IdeficsVisionTransformer( + prefix="model.vision_model", + config=config.vision_config, + weights=weights, + ) + + # Perceiver Resampler + if config.use_resampler: + perceiver_config = config.perceiver_config + self.perceiver_resampler = IdeficsPerceiverResampler( + prefix="model.perceiver_resampler", + config=config, + embed_dim=config.vision_config.embed_dim, + depth=perceiver_config.resampler_depth, + n_heads=perceiver_config.resampler_n_heads, + head_dim=perceiver_config.resampler_head_dim, + n_latents=perceiver_config.resampler_n_latents, + weights=weights, + ) + + self.layers = nn.ModuleList( + [ + IdeficsDecoderLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.cross_layer_interval = config.cross_layer_interval + num_cross_layers = config.num_hidden_layers // self.cross_layer_interval + self.gated_cross_attn_layers = nn.ModuleList( + [ + IdeficsGatedCrossAttentionLayer(layer_id, config, weights) + for layer_id in range(num_cross_layers) + ] + ) + # self.gradient_checkpointing = False + + self.norm = IdeficsRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + # self.gradient_checkpointing = False + # Initialize weights and apply final processing + # self.post_init() + + # self.freeze_relevant_params(config) + + # def freeze_relevant_params(self, config=None): + # if config is None: + # config = self.config + + # if config.freeze_text_layers: + # self.freeze_text_layers(config.freeze_text_module_exceptions) + + # if config.freeze_vision_layers: + # freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions) + + # def freeze_text_layers(self, module_exceptions=[]): + # for module in [self.layers, self.norm]: + # freeze_model(module, module_exceptions=module_exceptions) + + # def freeze_vision_layers(self, module_exceptions=[]): + # freeze_model(self.vision_model, module_exceptions=module_exceptions) + + # def get_input_embeddings(self): + # return self.embed_tokens + + # def set_input_embeddings(self, value): + # self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastImage]: + device = input_ids.device if input_ids is not None else inputs_embeds.device + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + elif position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + no_images = False + + if image_hidden_states is None: + if pixel_values is None and image_embeddings is None: + raise ValueError( + "Either pixel_values and image_embeddings have to be not-None." + ) + + elif pixel_values is not None and image_embeddings is not None: + raise ValueError( + "You cannot specify both pixel_values and image_embeddings at the same time" + ) + + elif pixel_values is not None: + no_images = len(torch.nonzero(pixel_values)) == 0 + pixel_values = pixel_values.to( + dtype=self.dtype, device=device + ) # fp16 compatibility + batch_size, num_images = pixel_values.shape[:2] + pixel_values = pixel_values.contiguous().view( + batch_size * num_images, *pixel_values.shape[2:] + ) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values + ).last_hidden_state + + elif image_embeddings is not None: + ( + batch_size, + num_images, + image_seq_len, + image_hidden_size, + ) = image_embeddings.size() + image_hidden_states = image_embeddings.to( + dtype=self.dtype, device=input_ids.device + ) + image_hidden_states = image_hidden_states.view( + batch_size * num_images, image_seq_len, image_hidden_size + ) + + if self.config.use_resampler: + image_hidden_states = self.perceiver_resampler(image_hidden_states) + image_seq_len, image_hidden_size = image_hidden_states.size( + 1 + ), image_hidden_states.size(2) + image_hidden_states = image_hidden_states.view( + batch_size, num_images * image_seq_len, image_hidden_size + ) + else: + no_images = False + num_images = pixel_values.shape[1] + image_seq_len = image_hidden_states.shape[1] // num_images + + # # Hack to use the model in full language modeling mode + # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device) + # Make image_attention_mask compatible with hidden states + text_seq_len = image_attention_mask.size(1) + image_attention_mask = image_attention_mask.unsqueeze(-1) + image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len) + image_attention_mask = image_attention_mask.view( + batch_size, text_seq_len, num_images * image_seq_len + ) + image_batch_size, image_sequence_length, _ = image_hidden_states.size() + image_hidden_shape = (image_batch_size, image_sequence_length) + if image_attention_mask is None: + image_attention_mask = torch.ones(image_hidden_shape, device=device) + image_attention_mask = self.invert_attention_mask(image_attention_mask) + + # if list(image_attention_mask.shape) != [4, 1, 1024, 64]: + # raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}") + + # if image_hidden_states is not None: + # else: + # image_attention_mask = None + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + # if self.gradient_checkpointing and self.training: + # if use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + # ) + # use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + def vblock( + main_block, + hidden_states, + attention_mask, + position_ids, + past_key_value, + image_hidden_states, + image_attention_mask, + output_attentions, + use_cache, + no_images, + layer_idx, + cross_layer_interval, + gated_cross_attn_layers, + ): + # TODO(ls): Add cross attention values to respective lists + if layer_idx % cross_layer_interval == 0: + xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] + outputs = xblock( + hidden_states, + attention_mask=attention_mask, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + output_attentions=output_attentions, + use_cache=use_cache, + past_key_value=None, # not implemented + no_images=no_images, + ) + hidden_states = outputs[0] + + layer_outputs = main_block( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + return layer_outputs + + # if self.gradient_checkpointing and self.training: + # past_key_value = None + # if use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + # ) + # use_cache = False + + # layer_outputs = torch.utils.checkpoint.checkpoint( + # vblock, + # decoder_layer, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_value, + # image_hidden_states, + # image_attention_mask, + # output_attentions, + # use_cache, + # no_images, + # idx, + # self.cross_layer_interval, + # self.gated_cross_attn_layers, + # ) + # else: + layer_outputs = vblock( + decoder_layer, + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + output_attentions=output_attentions, + use_cache=use_cache, + no_images=no_images, + layer_idx=idx, + cross_layer_interval=self.cross_layer_interval, + gated_cross_attn_layers=self.gated_cross_attn_layers, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPastImage( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + image_hidden_states=image_hidden_states, + ) + + +class IdeficsForVisionText2Text(IdeficsPreTrainedModel): + def __init__( + self, + config, + weights, + ): + super().__init__(config) + self.model = IdeficsModel( + config=config, + weights=weights, + ) + + self.lm_head = IdeficsDecoupledTensorParallelLinear( + config=config, + weights=weights, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPastImage]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_embeddings=image_embeddings, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits, speculative_logits = self.lm_head(hidden_states) + + loss = None + + return ( + CausalLMOutputWithPastImage( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ), + speculative_logits, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) + unwanted_kwargs = ["token_type_ids"] + for kwarg in unwanted_kwargs: + inputs.pop(kwarg, None) + return inputs + + @staticmethod + def _expand_inputs_for_generation( + *args, + **model_kwargs, + ): + return expand_inputs_for_generation(*args, **model_kwargs) + + @staticmethod + def _update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False + ): + return update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py new file mode 100644 index 000000000..6da8045bc --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py @@ -0,0 +1,276 @@ +# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License. +# +# MIT License +# +# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +""" + +Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially +time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note +that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to +prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that +to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore. + +References: + - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model + - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch + +""" +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + +EPS = 1e-5 + + +class IdeficsPerceiverResampler(nn.Module): + def __init__( + self, + prefix, + config, + embed_dim: int, + depth: int, + n_heads: int, + head_dim: int, + n_latents: int, + weights, + ) -> None: + """ + Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or + MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then + returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed + to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler. + Could be e.g., VIT embed_dim, ResNet pool dim, and so on. + + Args: + config (`IdeficsConfig`): config object + embed_dim (`int`): The size of each embedding vector + depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). + n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention). + head_dim (`int`): Dimensionality of each head projection in the Transformer block. + n_latents (`int`): + Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). + + """ + super().__init__() + self.embed_dim, self.n_heads, self.head_dim, self.n_latents = ( + embed_dim, + n_heads, + head_dim, + n_latents, + ) + self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver + + # Create Latents for Perceiver + self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents")) + + self.intermediate_dim = ( + self.embed_dim * 4 + if not hasattr(config.vision_config, "embed_dim") + else config.vision_config.embed_dim * 4 + ) + # Create Transformer Blocks + self.blocks = nn.ModuleList( + [ + nn.ModuleList( + [ + IdeficsPerceiverAttention( + prefix=f"{prefix}.blocks.{layer_id}.0", + config=config, + embed_dim=self.embed_dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + qk_layer_norms=self.qk_layer_norms, + weights=weights, + ), + IdeficsMLP( + prefix=f"{prefix}.blocks.{layer_id}.1", + intermediate_size=self.intermediate_dim, + config=config, + weights=weights, + ), + ] + ) + for layer_id in range(depth) + ] + ) + self.layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS + ) + + def forward(self, context: torch.Tensor) -> torch.Tensor: + """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" + # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) + latents = self.latents.repeat(context.shape[0], 1, 1) + + # Feed through Perceiver Attention blocks... + for attn, ff in self.blocks: + latents = attn(context, latents) + latents + latents = ff(latents) + latents + + return self.layer_norm(latents) + + +class IdeficsPerceiverAttention(nn.Module): + def __init__( + self, + prefix, + config, + embed_dim: int, + n_heads: int, + head_dim: int, + qk_layer_norms: bool, + weights, + ) -> None: + """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" + super().__init__() + self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim + self.qk_layer_norms = qk_layer_norms + # Normalization & Scaling + self.context_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS + ) + self.latents_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS + ) + if self.qk_layer_norms: + self.q_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS + ) + self.k_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS + ) + + self.qk_scale = self.head_dim**-0.5 + + if n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.n_heads //= weights.process_group.size() + + # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers). + self.q_proj = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False + ) + self.k_proj = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False + ) + self.v_proj = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False + ) + + self.output_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False + ) + + def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """ + Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! + + Args: + context (`torch.Tensor`): + Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample. + latents (`torch.Tensor`): + Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to. + + Returns: + `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross + from context. + """ + context = self.context_layer_norm(context) + latents = self.latents_layer_norm(latents) + batch_size, seq_length, embed_dim = context.shape[:3] + + # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! + # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` + q = self.q_proj(latents) + k = self.k_proj(torch.cat([context, latents], dim=-2)) + v = self.v_proj(torch.cat([context, latents], dim=-2)) + + # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) + # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] + # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) + q, k, v = [ + x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose( + 1, 2 + ) + for x in (q, k, v) + ] + + if self.qk_layer_norms: + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + + scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k) + stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach()) + attn = stabilized_scores.softmax(dim=-1) + + # Attend & project back to output... + resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v) + # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads) + return self.output_proj(resampled.transpose(1, 2).flatten(-2)) + + +class IdeficsMLP(nn.Module): + def __init__( + self, + prefix, + intermediate_size, + config, + weights, + ): + """Simple MLP block with intermediate_size and embedding size""" + super().__init__() + self.embed_dim = config.vision_config.embed_dim + self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS) + self.fc = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.fc", + weights=weights, + bias=False, + ) + self.act = nn.ReLU() + self.c_proj = TensorParallelRowLinear.load( + config=config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=False, + ) + + def forward( + self, hidden_states: Optional[Tuple[torch.FloatTensor]] + ) -> torch.FloatTensor: + hidden_states = self.ln(hidden_states) + hidden_states = self.fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + + return hidden_states diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py new file mode 100644 index 000000000..ca61e27d4 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py @@ -0,0 +1,443 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for IDEFICS. +""" + +from typing import Callable, List, Optional, Union +from urllib.parse import urlparse + +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import ( + BatchEncoding, + PaddingStrategy, + TextInput, + TruncationStrategy, +) +from transformers.utils import TensorType, is_torch_available + + +if is_torch_available(): + import torch + + +IMAGE_TOKEN = "" + + +# copied from m4.training.packing +def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1): + # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]] + + # If any of images index are more than num_classes, set them to -1. + # Words after the max number of images allowed have been seen don't attend on anything + if num_classes != -1: + incremental_mask[incremental_mask >= num_classes] = -1 + + negatives = incremental_mask == -1 + incremental_mask[negatives] = 0 + attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes) + attn_mask[negatives, :] = 0 + return attn_mask + + +# copied from m4.training.packing +def image_attention_mask_for_packed_input_ids(input_ids, tokenizer): + image_attention_mask = torch.full_like(input_ids, fill_value=-1) + next_image_attention_mask = torch.full_like(input_ids, fill_value=-1) + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + eod_token_id = tokenizer.eos_token_id + for batch_idx in range(input_ids.size(0)): + count = -1 + seen_eod = False + for idx, token_id in enumerate(input_ids[batch_idx]): + if token_id == image_token_id: + count += 1 + image_attention_mask[batch_idx][idx] = count + seen_eod = False + else: + image_attention_mask[batch_idx][idx] = count + + if seen_eod: + image_attention_mask[batch_idx][idx] = -1 + + if token_id == eod_token_id: + seen_eod = True + + for batch_idx in range(input_ids.size(0)): + count = -1 + seen_eod = False + for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1): + token_id = input_ids[batch_idx][idx] + if token_id == image_token_id: + count += 1 + next_image_attention_mask[batch_idx][idx] = count + seen_eod = False + else: + next_image_attention_mask[batch_idx][idx] = count + + if token_id == eod_token_id: + seen_eod = True + + if seen_eod: + next_image_attention_mask[batch_idx][idx] = -1 + + non_negative_indices = next_image_attention_mask[batch_idx] != -1 + next_image_attention_mask[batch_idx][non_negative_indices] -= count + next_image_attention_mask[batch_idx][non_negative_indices] *= -1 + + return image_attention_mask, next_image_attention_mask + + +def is_url(string): + """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately + invalidated the url""" + if " " in string: + return False + result = urlparse(string) + return all([result.scheme, result.netloc]) + + +def is_image(string): + """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately + invalidated the url""" + return is_url(string) or string.startswith("data:") + + +class IdeficsProcessor(ProcessorMixin): + r""" + Constructs a IDEFICS processor which wraps a LLama tokenizer and IDEFICS image processor into a single processor. + + [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See + the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information. + + Args: + image_processor (`IdeficsImageProcessor`): + An instance of [`IdeficsImageProcessor`]. The image processor is a required input. + tokenizer (`LlamaTokenizerFast`): + An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input. + image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image) + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "IdeficsImageProcessor" + tokenizer_class = "LlamaTokenizerFast" + + def __init__( + self, + image_processor, + tokenizer=None, + image_size=224, + add_end_of_utterance_token=None, + **kwargs, + ): + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + + self.default_image_dims = ( + self.image_processor.image_num_channels, + self.image_processor.image_size, + self.image_processor.image_size, + ) + + self.tokenizer_was_trained_with_end_of_utterance_token = ( + True + if "" + in self.tokenizer.special_tokens_map.get("additional_special_tokens", []) + else False + ) + + def __call__( + self, + prompts: Union[List[TextInput], List[List[TextInput]]], + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + transform: Callable = None, + add_eos_token=False, + add_end_of_utterance_token=None, + debug=False, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchEncoding: + """This method takes batched or non-batched prompts made of text and images and converts them into prompts that + the model was trained on and prepares the image pixel values for the model to process. + + Args: + prompts (`Union[List[TextInput], [List[List[TextInput]]]]`): + either a single prompt or a batched list of prompts - see the detailed description immediately after + the end of the arguments doc section. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + transform (`Callable`, *optional*): + A custom transform function that accepts a single image can be passed for training. For example, + `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific + set of transforms will be applied to the images + add_eos_token (`bool`, *optional*, defaults to `False`): + Adds `eos_token` at the end of the final prompt if True` + add_end_of_utterance_token (`bool`, *optional*) + Whether to automatically add `` after each prompt's text input (unless followed by an + image). If `None` the tokenizer will be checked instead and if this token is found in + `additional_special_tokens` then the value will be `True`. + debug (`bool`, *optional*, defaults to `False`): + `True` value will help debug prompt generation by dumping useful information + return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`): + The type of tensors to return. Can be one of: + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + + Returns: + a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be + directly passed to `model.generate` + + Detailed explanation: + + Each entry in `prompts` is either a text to be passed as is or an image that will be processed. + + An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved. + + When the processor encounters an image it'll inject `` + entry into the prompt. + + Example: + + ```python + checkpoint = "HuggingFaceM4/idefics-9b" + processor = AutoProcessor.from_pretrained(checkpoint) + url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg" + img = processor.image_processor.fetch_images([url])[0] + + prompts = [ + "User:", + img, + "Describe this image.\nAssistant: An image of two kittens in grass.\n", + "User:", + "https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg", + "Describe this image.\nAssistant:", + ] + + inputs = processor(prompts, return_tensors="pt") + generated_ids = model.generate(**inputs, max_length=100) + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ``` + + In this example the `prompts` will be converted into: + + ``` + User:Describe this image. + Assistant: An image of two kittens in grass. + User:Describe this image. + Assistant:' + ``` + + and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the + `pixel_values` dict entry of the return value. + + This example also examplifies that images can be passed as objects or as text urls. It can be seen that the + first image is passed as object and the second one as a url. + + To do training do: + + ```python + image_transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize(mean=self.image_mean, std=self.image_std), + ] + ) + inputs = processor(prompts, transform=image_transform, return_tensors="pt") + ``` + + In order to help debug prompt generation enable `debug=True` which will show you what's happening. + + """ + + # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it + if add_end_of_utterance_token is None: + add_end_of_utterance_token = ( + self.tokenizer_was_trained_with_end_of_utterance_token + ) + + # turn non-batched prompts into batched + if not any(isinstance(i, list) for i in prompts): + prompts = [prompts] + + fake_token = "" + image_token = "" + end_of_utterance_token = "" + + def image_tokens(last_was_image): + if last_was_image: + return image_token + fake_token + else: + return fake_token + image_token + fake_token + + all_texts = [] + all_images = [] + for sample in prompts: + # the model was trained on samples starting with + full_text = f"{self.tokenizer.bos_token}" + + # an image can either be an image object in the item or the url, everything else is a verbatim prompt text + image_objects = [] + last_was_image = False + last_was_text = False + for i, item in enumerate(sample): + if i > 0: + last_was_text = True if not last_was_image else False + + if isinstance(item, str): + item = item.strip(" ") + if is_image(item): + image = self.image_processor.fetch_images(item) + full_text += image_tokens(last_was_image) + image_objects.append(image) + last_was_image = True + else: + # we add end_of_utterance_token between each subsequent text prompts (but not at the last one!) + if add_end_of_utterance_token and last_was_text: + full_text += end_of_utterance_token + full_text += item + last_was_image = False + else: + # must be an image obj + full_text += image_tokens(last_was_image) + image_objects.append(item) + last_was_image = True + + if add_eos_token: + full_text += self.tokenizer.eos_token + + if debug is True: + print(f"{full_text=}") + + image_objects = self.image_processor(image_objects, transform=transform) + + text_encoding = self.tokenizer( + text=full_text, + add_special_tokens=False, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + + all_texts.append(text_encoding["input_ids"]) + all_images.append(image_objects) + + max_seq_len = max(len(x) for x in all_texts) + + # max_num_images has to be at least 1 even when there are no images + max_num_images = max(len(x) for x in all_images) + max_num_images = max(1, max_num_images) + + at_least_one_image = sum(len(x) for x in all_images) > 0 + output_input_ids = [] + output_images = [] + output_attention_masks = [] + for text, images in zip(all_texts, all_images): + padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len + unpadded_seq_len = len(text) + start = max_seq_len - unpadded_seq_len + padded_input_ids[start:] = text[:max_seq_len] + + attention_mask = torch.zeros((max_seq_len,), dtype=torch.long) + attention_mask[start:] = 1 + + image_count = padded_input_ids.count(self.image_token_id) + local_max_num_images = min(image_count, max_num_images) + + current_images = images[:local_max_num_images] + + if len(current_images) > 0: + padded_image_tensor = torch.zeros( + max_num_images, *current_images.size()[1:] + ) + padded_image_tensor[: current_images.size(0)] = current_images + else: + padded_image_tensor = torch.zeros( + max_num_images, *self.default_image_dims + ) + + output_images.append(padded_image_tensor) + output_input_ids.append(torch.tensor(padded_input_ids)) + + output_attention_masks.append(attention_mask) + + output_input_ids = torch.stack(output_input_ids) + output_images = torch.stack(output_images) + output_attention_masks = torch.stack(output_attention_masks) + + if at_least_one_image: + image_attention_mask, _ = image_attention_mask_for_packed_input_ids( + output_input_ids, self.tokenizer + ) + image_attention_mask = incremental_to_binary_attention_mask( + image_attention_mask, num_classes=max_num_images + ) + else: + # in full language mode we set the image mask to all-0s + image_attention_mask = torch.zeros( + output_input_ids.shape[0], + output_input_ids.shape[1], + 1, + dtype=torch.bool, + ) + + return BatchFeature( + data={ + "input_ids": output_input_ids, + "attention_mask": output_attention_masks, + "pixel_values": output_images, + "image_attention_mask": image_attention_mask, + } + ) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py new file mode 100644 index 000000000..dd8f76bc4 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py @@ -0,0 +1,529 @@ +# coding=utf-8 +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.utils import ( + ModelOutput, + logging, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +@dataclass +class IdeficsVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics +class IdeficsVisionEmbeddings(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.class_embedding") + ) + + self.patch_embedding = nn.Conv2d.load_no_bias( + prefix=f"{prefix}.patch_embedding", + weights=weights, + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = TensorParallelEmbedding( + prefix="model.vision_model.embeddings.position_embedding", weights=weights + ) + self.position_ids = ( + torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device) + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision +class IdeficsVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.k_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k_proj", weights=weights, bias=True + ) + self.v_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v_proj", weights=weights, bias=True + ) + self.q_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q_proj", weights=weights, bias=True + ) + self.out_proj = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + causal_attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision +class IdeficsVisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.fc1", weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.fc2", weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision +class IdeficsVisionEncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = IdeficsVisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps + ) + self.mlp = IdeficsVisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision +class IdeficsVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`IdeficsVisionEncoderLayer`]. + + Args: + config: IdeficsVisionConfig + """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + IdeficsVisionEncoderLayer( + prefix=f"{prefix}.encoder.layers.{layer_id}", + config=config, + weights=weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + # self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # if self.gradient_checkpointing and self.training: + + # def create_custom_forward(module): + # def custom_forward(*inputs): + # return module(*inputs, output_attentions) + + # return custom_forward + + # layer_outputs = torch.utils.checkpoint.checkpoint( + # create_custom_forward(encoder_layer), + # hidden_states, + # attention_mask, + # causal_attention_mask, + # ) + # else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer +class IdeficsVisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + + self.embeddings = IdeficsVisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.pre_layrnorm = nn.LayerNorm.load( + prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps + ) + self.encoder = IdeficsVisionEncoder( + prefix=prefix, config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py new file mode 100644 index 000000000..1ef550190 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py @@ -0,0 +1,296 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Llava-NeXT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.models.llava_next.modeling_llava_next import ( + unpad_image, +) +from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration +from transformers.image_processing_utils import select_best_resolution + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): + + def _merge_input_ids_with_image_features( + self, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + input_ids: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + mask = input_ids == self.config.image_token_index + # Let's pray we have enabled enough slots ! + try: + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + except Exception as e: + raise RuntimeError( + f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}" + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + ): + + if token_idx is not None: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + + logits = outputs[0] + + if not return_dict: + output = (logits,) + outputs[1:] + return output + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + **kwargs, + ): + """ + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 + The only differences are: + - add new args token_idx + - add the process of merging images into inputs_embeds + """ + token_idx = kwargs.get("token_idx", None) + if token_idx is None: + return super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + **kwargs, + ) + else: + use_flash_attention = kwargs.get("use_flash_attention", False) + flash_attention_recompute = kwargs.get("flash_attention_recompute", False) + + position_ids = kwargs.get("position_ids", None) + labels = kwargs.get("labels", None) + if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: + vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) + vision_feature_layer = kwargs.get("vision_feature_layer", None) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + batch_size, num_patches, num_channels, height, width = pixel_values.shape + reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) + image_features = self.vision_tower( + reshaped_pixel_values, + output_hidden_states=True, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + + selected_image_feature = image_features.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx].tolist(), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids) + self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position. + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif past_key_values is not None: + seq_len = input_ids.shape[1] + pad_len = seq_len - token_idx.item() + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = extended_attention_mask + attention_mask[:, -pad_len:] = 0 + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "labels": labels, + "use_flash_attention": use_flash_attention, + "flash_attention_recompute": flash_attention_recompute, + } + ) + + return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py new file mode 100644 index 000000000..293051c2b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -0,0 +1,232 @@ +import torch +import torch.distributed + +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from torch import nn +from typing import Optional, Tuple, Any +from transformers.configuration_utils import PretrainedConfig +import torch.nn.functional as F + +from text_generation_server.layers import ( + SpeculativeHead, + TensorParallelEmbedding, + FastLinear, +) +from text_generation_server.layers.layernorm import FastRMSNorm + +from einops import rearrange +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +import math +from dataclasses import dataclass + + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + max_seqlen: int + max_batch_size: int + conv_states: torch.Tensor + ssm_states: torch.Tensor + seqlen_offset: int + + +class MambaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=50280, + d_model=768, + d_state=16, + n_layer=32, + layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + expand=2, + dt_rank="auto", + **kwargs, + ): + self.vocab_size = vocab_size + self.n_layer = n_layer + self.layer_norm_epsilon = layer_norm_epsilon + self.d_model = d_model + self.d_inner = d_model * 2 + self.d_conv = 4 + self.d_state = d_state + self.expand = expand + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class MambaBlock(nn.Module): + def __init__(self, prefix, config, weights, layer_id): + super().__init__() + self.layer_id = layer_id + self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) + self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) + self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) + self.dt_proj_no_bias = FastLinear.load( + config, f"{prefix}.dt_proj", weights, bias=False + ) + self.out_proj = FastLinear.load( + config, f"{prefix}.out_proj", weights, bias=False + ) + self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True) + self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float()) + self.D = weights.get_tensor(f"{prefix}.D") + self.activation = "silu" + self.dt_rank = config.dt_rank + self.d_state = config.d_state + self.d_conv = config.d_conv + self.act = nn.SiLU() + + # inference_params + def forward(self, hidden_states: torch.Tensor, inference_params=None): + if inference_params.seqlen_offset > 0: + conv_state = inference_params.conv_states[self.layer_id] + ssm_state = inference_params.ssm_states[self.layer_id] + out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) + return out, conv_state, ssm_state + + _, seqlen, _ = hidden_states.shape + projected_states = self.in_proj(hidden_states).transpose(1, 2) + # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]" + x, z = projected_states.chunk(2, dim=1) + conv_state = F.pad(x, (self.d_conv - seqlen, 0)) + x = causal_conv1d_fn( + x=x, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split( + x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1 + ) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + y, last_state = selective_scan_fn( + x, + dt, + self.negA, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=True, + ) + y = rearrange(y, "b d l -> b l d") + attn_outputs = self.out_proj(y) + return attn_outputs, conv_state, last_state + + def step(self, hidden_states, conv_state, ssm_state): + xz = self.in_proj(hidden_states.squeeze(1)) + x, z = xz.chunk(2, dim=-1) # (B D) + x = causal_conv1d_update( + x, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = F.linear(dt, self.dt_proj.weight) + A = self.negA + y = selective_state_update( + ssm_state, + x, + dt, + A, + B, + C, + self.D, + z=z, + dt_bias=self.dt_proj.bias, + dt_softplus=True, + ) + out = self.out_proj(y) + return out.unsqueeze(1), conv_state.clone(), ssm_state.clone() + + +class ResidualBlock(nn.Module): + def __init__(self, prefix, config, weights, layer_id): + super().__init__() + self.mamba_block = MambaBlock( + prefix=f"{prefix}.mixer", config=config, weights=weights, layer_id=layer_id + ) + self.layer_norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_epsilon + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] = None, + inference_params: Optional[Any] = None, + ): + residual = (hidden_states + residual) if residual is not None else hidden_states + shape = residual.shape + hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1])) + hidden_states, conv_state, last_ssm_state = self.mamba_block( + hidden_states.view(*shape), inference_params + ) + return hidden_states, residual, conv_state, last_ssm_state + + +class MambaModel(nn.Module): + def __init__(self, config, weights): + super().__init__() + prefix = "backbone" + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + self.blocks = nn.ModuleList( + [ + ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) + for i in range(config.n_layer) + ] + ) + self.norm_f = FastRMSNorm.load( + f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon + ) + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) + self.config = config + + def forward( + self, input_ids: torch.Tensor, inference_params=None, residual=None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.embed_tokens(input_ids) + for i, block in enumerate(self.blocks): + hidden_states, residual, conv_state, ssm_state = block( + hidden_states, residual, inference_params + ) + inference_params.conv_states[i].copy_(conv_state) + inference_params.ssm_states[i].copy_(ssm_state) + + hidden_states = ( + hidden_states + residual if residual is not None else hidden_states + ) + hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) + hidden_states = hidden_states.view(residual.shape) + logits, speculative_logits = self.lm_head(hidden_states) + + # update the offset for the next inference using these params + inference_params.seqlen_offset += input_ids.size(1) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py new file mode 100644 index 000000000..73536bd65 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py @@ -0,0 +1,995 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn +import flash_attn_2_cuda + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + FastLinear, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, +) + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape( + batch_size, max_num_tiles * target_length, 1 + ) + attention_mask = ( + attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave( + num_vision_tokens, dim=3 + ) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value) + .any(dim=-1) + .type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision +class MllamaVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + + self.embed_dim = config.hidden_size + self.head_dim = config.hidden_size // config.attention_heads + self.num_heads = config.attention_heads // weights.process_group.size() + + self.qkv_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv = self.qkv_proj(hidden_state) + query, key, value = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + ], + dim=2, + ) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = MllamaVisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05 + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05 + ) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False + ) + self.gate_ffn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int): + super().__init__() + self.config = config + self.layers = [ + MllamaVisionEncoderLayer( + prefix=f"{prefix}.layers.{i}", + config=config, + weights=weights, + is_gated=is_gated, + ) + for i in range(num_layers) + ] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + encoder_states = [hidden_states] + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + ) + + hidden_states = layer_outputs + encoder_states.append(hidden_states) + + return hidden_states, encoder_states + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + + self.embedding = TensorParallelEmbedding( + prefix=f"{prefix}.embedding", weights=weights + ) + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + # Always gated. + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + # position embedding + embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.embedding"), requires_grad=False + ) + self.gated_position_embedding = (1 - self.gate.tanh()) * embedding + self.tile_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.tile_embedding", weights=weights + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + # position embeddings + hidden_state = hidden_state + self.gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaVisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + self.dtype = weights.dtype + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + + self.class_embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False + ) + + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + prefix=f"{prefix}.gated_positional_embedding", + config=config, + weights=weights, + ) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.pre_tile_positional_embedding", + config=config, + weights=weights, + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.post_tile_positional_embedding", + config=config, + weights=weights, + ) + + ## layer norms + self.layernorm_pre = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_pre", + weights=weights, + # torch default + eps=1e-05, + ) + self.layernorm_post = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_post", + weights=weights, + # torch default + eps=1e-05, + ) + + ## encoders + self.transformer = MllamaVisionEncoder( + prefix=f"{prefix}.transformer", + config=config, + weights=weights, + is_gated=False, + num_layers=config.num_hidden_layers, + ) + self.global_transformer = MllamaVisionEncoder( + prefix=f"{prefix}.global_transformer", + config=config, + weights=weights, + is_gated=True, + num_layers=config.num_global_layers, + ) + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + # patch embedding + patch_embeds = self.patch_embedding(pixel_values) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + if attention_mask is not None: + attention_mask = attention_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + hidden_state, all_intermediate_hidden_states = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + intermediate_hidden_states = [ + hidden_state + for idx, hidden_state in enumerate(all_intermediate_hidden_states) + if idx in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + dim, + ) + hidden_state, _ = self.global_transformer( + hidden_state, attention_mask=attention_mask + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + return hidden_state + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, *, prefix, config, weights, layer_idx): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_size = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.layer_idx = layer_idx + + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=False, + ) + self.k_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=False, + ) + self.v_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.q_norm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps + ) + self.k_norm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps + ) + self.softmax_scale = self.head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + # past_key_value=None, + # attention_mask: Optional[torch.Tensor] = None, + # cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # hidden_states = hidden_states.unsqueeze(0) + # bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(-1, self.num_heads, self.head_size) + query_states = self.q_norm(query_states) + + ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + max_q, + max_k, + indices, + ) = cross_attention_states + + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) + value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) + key_states = self.k_norm(key_states) + + # key_states = key_states.repeat(1, self.num_key_value_groups, 1) + # value_states = value_states.repeat(1, self.num_key_value_groups, 1) + + causal = False + # logger.info( + # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" + # ) + attn_output = flash_attn_2_cuda.varlen_fwd( + query_states, + key_states, + value_states, + None, + cu_seqlen_q, + cu_seqlen_k, + None, + None, + None, # block_tables + None, + max_q, + max_k, + 0.0, + self.softmax_scale, + False, + causal, # Causal + -1, # window_size_left, + -1, + 0.0, # softcap + False, + None, + )[0] + attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + return attn_output + + +# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText +class MllamaTextMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + shape = x.shape + gate_up_states = self.gate_up_proj(x) + gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) + result = self.down_proj( + self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] + ) + return result + + +class FlashLlamaCrossLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __init__(self, *, prefix, config, weights, index) -> None: + layer_idx = index + super().__init__() + self.cross_attn = MllamaTextCrossAttention( + prefix=f"{prefix}.cross_attn", + config=config, + weights=weights, + layer_idx=layer_idx, + ) + + self.input_layernorm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.cross_attn_attn_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False + ) + + self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.post_attention_layernorm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.cross_attn_mlp_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False + ) + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + cross_attention_states, # [ IB, ...] + ) -> Tuple[torch.Tensor, torch.Tensor]: + if cross_attention_states is None: + return hidden_states, residual + if residual is not None: + hidden_states += residual + + indices = cross_attention_states[-1] + out_hidden_states = hidden_states[:] + if len(indices) > 0: + assert max(indices) < hidden_states.shape[0] + hidden_states = hidden_states[indices] + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + # attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + out_hidden_states[indices] = hidden_states + hidden_states = out_hidden_states + + return hidden_states, None + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class MllamaTextRMSNorm(nn.Module): + def __init__(self, weight, eps): + super().__init__() + self.weight = weight + self.variance_epsilon = eps + + @classmethod + def load(cls, *, prefix, weights, eps): + weight = nn.Parameter( + weights.get_tensor(f"{prefix}.weight"), requires_grad=False + ) + return cls(weight=weight, eps=eps) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + config.text_config._attn_implementation = "sdpa" + self.hidden_size = config.text_config.hidden_size + self.vision_model = MllamaVisionModel( + prefix="vision_model", config=config.vision_config, weights=weights + ) + self.multi_modal_projector = FastLinear.load( + prefix="multi_modal_projector", config=config, weights=weights, bias=True + ) + self.text_model = FlashLlamaForCausalLM( + prefix="language_model", config=config.text_config, weights=weights + ) + self.config = config + self.dtype = weights.dtype + self.device = weights.device + + def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask): + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # logger.info(f"PIxel values {pixel_values.shape}") + batch_size = pixel_values.shape[0] + vision_states = self.vision_model( + pixel_values, aspect_ratio_ids, aspect_ratio_mask + ) + cross_attention_states = self.multi_modal_projector(vision_states).reshape( + -1, vision_states.shape[-2], self.hidden_size + ) + _, _, h = cross_attention_states.shape + cross_attention_states = cross_attention_states.view(batch_size, -1, h) + # logger.info(f"cross {cross_attention_states.shape}") + return cross_attention_states + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + adapter_data: Optional[torch.Tensor] = None, + # XXX: Putting these as optional so that the cuda warmup calls can go through. + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + if cross_attention_states is not None: + seqlen_q = len(image_indices) + n_images = cross_attention_states.shape[0] + seqlen_k = cross_attention_states.shape[1] + device = cross_attention_states.device + if cu_seqlen_prefill is not None: + offset = 0 + cu_q = [] + indices = [] + for index in image_indices: + cu_q.append(offset) + length = seqlen.input_lengths[index].item() + assert index < seqlen.cu_seqlen_q.shape[0] + input_ids_offset = seqlen.cu_seqlen_q[index] + indices.extend(range(input_ids_offset, input_ids_offset + length)) + offset += length + cu_q.append(offset) + cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32) + + assert max(indices) < input_ids.shape[0] + + cu_seqlen_k = ( + torch.arange( + n_images + 1, + device=device, + dtype=torch.int32, + ) + * seqlen_k + ) + max_q = cu_seqlen_q[-1].item() + max_k = seqlen_k + else: + cu_seqlen_q = torch.arange( + seqlen_q + 1, device=device, dtype=torch.int32 + ) + seqlen_k = cross_attention_states.shape[1] + n_images = cross_attention_states.shape[0] + cu_seqlen_k = ( + torch.arange( + n_images + 1, + device=device, + dtype=torch.int32, + ) + * seqlen_k + ) + max_q = seqlen_q + max_k = seqlen_k + indices = image_indices[:] + + cross_attention_states = ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + max_q, + max_k, + indices, + ) + + outputs = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + cross_attention_states=cross_attention_states, + ) + + return outputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py new file mode 100644 index 000000000..988a74a39 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -0,0 +1,1215 @@ +"""A simple, flexible implementation of a GPT model. + +Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py +""" + +import math +import warnings +from typing import List, Optional, Tuple, Union +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from einops import rearrange +from packaging import version +from text_generation_server.layers import ( + TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, + SpeculativeHead, + get_linear, +) + +EPS = 1e-5 + + +def load_col(config, prefix, weights, bias): + assert config.quantize != "gptq", NotImplementedError + slice_ = weights._get_slice(f"{prefix}.weight") + rank = weights.process_group.rank() + size = weights.process_group.size() + + h3, h = slice_.get_shape() + block_size = h // size + + q_part = slice_[rank * block_size : (rank + 1) * block_size] + k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size] + v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size] + + weight = torch.cat([q_part, k_part, v_part], dim=0) + if weight.dtype != torch.int32: + weight = weight.to(dtype=weights.dtype) + weight = weight.to(device=weights.device) + + if bias: + bias_slice_ = weights._get_slice(f"{prefix}.bias") + bias_rank = weights.process_group.rank() + bias_size = weights.process_group.size() + + bias_h = bias_slice_.get_shape() + bias_h = bias_h[0] + bias_block_size = bias_h // bias_size + + bias_q_part = bias_slice_[ + bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size + ] + bias_k_part = bias_slice_[ + bias_h + + bias_rank * bias_block_size : bias_h + + (bias_rank + 1) * bias_block_size + ] + bias_v_part = bias_slice_[ + 2 * bias_h + + bias_rank * bias_block_size : 2 * bias_h + + (bias_rank + 1) * bias_block_size + ] + + bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0) + if bias.dtype != torch.int32: + bias = bias.to(dtype=weights.dtype) + bias = bias.to(device=weights.device) + else: + bias = None + linear = get_linear(weight, bias) + return TensorParallelColumnLinear(linear) + + +def _reset_is_causal( + num_query_tokens: int, num_key_tokens: int, original_is_causal: bool +): + if original_is_causal and num_query_tokens != num_key_tokens: + if num_query_tokens != 1: + raise NotImplementedError( + "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." + ) + else: + return False + return original_is_causal + + +def scaled_multihead_dot_product_attention( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): + q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) + kv_n_heads = 1 if multiquery else n_heads + k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) + v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) + if past_key_value is not None: + if len(past_key_value) != 0: + k = torch.cat([past_key_value[0], k], dim=3) + v = torch.cat([past_key_value[1], v], dim=2) + past_key_value = (k, v) + (b, _, s_q, d) = q.shape + s_k = k.size(-1) + attn_weight = q.matmul(k) * softmax_scale + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - s_q) + _s_k = max(0, attn_bias.size(3) - s_k) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if ( + attn_bias.size(-1) != 1 + and attn_bias.size(-1) != s_k + or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) + ): + raise RuntimeError( + f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." + ) + attn_weight = attn_weight + attn_bias + min_val = torch.finfo(q.dtype).min + if key_padding_mask is not None: + if attn_bias is not None: + warnings.warn( + "Propogating key_padding_mask to the attention module " + + "and applying it within the attention module can cause " + + "unneccessary computation/memory usage. Consider integrating " + + "into attn_bias once and passing that to each attention " + + "module instead." + ) + attn_weight = attn_weight.masked_fill( + ~key_padding_mask.view((b, 1, 1, s_k)), min_val + ) + if is_causal and (not q.size(2) == 1): + s = max(s_q, s_k) + causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) + causal_mask = causal_mask.tril() + causal_mask = causal_mask.to(torch.bool) + causal_mask = ~causal_mask + causal_mask = causal_mask[-s_q:, -s_k:] + attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) + attn_weight = torch.softmax(attn_weight, dim=-1) + if dropout_p: + attn_weight = torch.nn.functional.dropout( + attn_weight, p=dropout_p, training=training, inplace=True + ) + out = attn_weight.to(v.dtype).matmul(v) + out = rearrange(out, "b h s d -> b s (h d)") + if needs_weights: + return (out, attn_weight, past_key_value) + return (out, None, past_key_value) + + +def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): + for tensor in tensors: + if tensor.dtype not in valid_dtypes: + raise TypeError( + f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." + ) + if not tensor.is_cuda: + raise TypeError( + f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." + ) + + +def flash_attn_fn( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): + try: + from flash_attn import bert_padding, flash_attn_interface + except Exception: + raise RuntimeError("Please install flash-attn==1.0.3.post0") + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - query.size(1)) + _s_k = max(0, attn_bias.size(3) - key.size(1)) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if attn_bias is not None: + raise NotImplementedError("attn_bias not implemented for flash attn.") + (batch_size, seqlen) = query.shape[:2] + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1) :] + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( + query, query_padding_mask + ) + query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( + key, key_padding_mask + ) + key_unpad = rearrange( + key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads + ) + (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) + value_unpad = rearrange( + value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads + ) + if multiquery: + key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) + value_unpad = value_unpad.expand( + value_unpad.size(0), n_heads, value_unpad.size(-1) + ) + dropout_p = dropout_p if training else 0.0 + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + output_unpad = flash_attn_interface.flash_attn_unpadded_func( + query_unpad, + key_unpad, + value_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale=softmax_scale, + causal=reset_is_causal, + return_attn_probs=needs_weights, + ) + output = bert_padding.pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen + ) + return (output, None, past_key_value) + + +def triton_flash_attn_fn( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): + try: + from .flash_attn_triton import flash_attn_func + except Exception: + _installed = False + if version.parse(torch.__version__) < version.parse("2.0.0"): + _installed = True + try: + from flash_attn.flash_attn_triton import flash_attn_func + except Exception: + _installed = False + if not _installed: + raise RuntimeError( + "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." + ) + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - query.size(1)) + _s_k = max(0, attn_bias.size(3) - key.size(1)) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if dropout_p: + raise NotImplementedError("Dropout not implemented for attn_impl: triton.") + if needs_weights: + raise NotImplementedError("attn_impl: triton cannot return attn weights.") + if key_padding_mask is not None: + warnings.warn( + "Propagating key_padding_mask to the attention module " + + "and applying it within the attention module can cause " + + "unnecessary computation/memory usage. Consider integrating " + + "into attn_bias once and passing that to each attention " + + "module instead." + ) + (b_size, s_k) = key_padding_mask.shape[:2] + if attn_bias is None: + attn_bias = query.new_zeros(b_size, 1, 1, s_k) + attn_bias = attn_bias.masked_fill( + ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min + ) + query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) + key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) + value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) + if multiquery: + key = key.expand(*key.shape[:2], n_heads, key.size(-1)) + value = value.expand(*value.shape[:2], n_heads, value.size(-1)) + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + attn_output = flash_attn_func( + query, key, value, attn_bias, reset_is_causal, softmax_scale + ) + output = attn_output.view(*attn_output.shape[:2], -1) + return (output, None, past_key_value) + + +class MultiheadAttention(nn.Module): + """Multi-head self attention. + + Using torch or triton attention implementation enables user to also use + additive bias. + """ + + def __init__( + self, + config, + prefix, + weights, + ): + super().__init__() + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln + self.d_model = config.d_model + d_model = config.d_model + self.n_heads = config.n_heads + self.softmax_scale = config.attn_config.softmax_scale + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) + self.attn_dropout_p = config.attn_config.attn_pdrop + + if self.n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.n_heads = self.n_heads // weights.process_group.size() + self.Wqkv = load_col( + config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias + ) + if self.qk_ln: + bias = not config.no_bias + hidden_size = config.d_model + head_dim = hidden_size // self.n_heads + + self.q_ln = LPLayerNorm( + d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights + ) + self.k_ln = LPLayerNorm( + self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights + ) + if self.attn_impl == "flash": + self.attn_fn = flash_attn_fn + elif self.attn_impl == "triton": + self.attn_fn = triton_flash_attn_fn + elif self.attn_impl == "torch": + self.attn_fn = scaled_multihead_dot_product_attention + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=not config.no_bias, + ) + + def forward( + self, + x, + past_key_value=None, + attn_bias=None, + attention_mask=None, + is_causal=True, + needs_weights=False, + ): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + (query, key, value) = qkv.chunk(3, dim=2) + + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + (context, attn_weights, past_key_value) = self.attn_fn( + query, + key, + value, + self.n_heads, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + ) + out = self.out_proj(context) + return (out, attn_weights, past_key_value) + + +class MultiQueryAttention(nn.Module): + """Multi-Query self attention. + + Using torch or triton attention implementation enables user to also use + additive bias. + """ + + def __init__(self, config, prefix, weights, verbose=False): + super().__init__() + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln + self.d_model = config.d_model + d_model = config.d_model + self.n_heads = config.n_heads + self.softmax_scale = config.attn_config.softmax_scale + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.head_dim) + self.attn_dropout_p = config.attn_config.attn_pdrop + # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) + self.Wqkv = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias + ) + (d_model, d_model + self.head_dim) + if self.qk_ln: + raise NotImplementedError("qk_ln not supported") + if self.attn_impl == "flash": + self.attn_fn = flash_attn_fn + elif self.attn_impl == "triton": + self.attn_fn = triton_flash_attn_fn + if verbose: + warnings.warn( + "While `attn_impl: triton` can be faster than `attn_impl: flash` " + + "it uses more memory. When training larger models this can trigger " + + "alloc retries which hurts performance. If encountered, we recommend " + + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." + ) + elif self.attn_impl == "torch": + self.attn_fn = scaled_multihead_dot_product_attention + if torch.cuda.is_available() and verbose: + warnings.warn( + "Using `attn_impl: torch`. If your model does not use `alibi` or " + + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + + "we recommend using `attn_impl: triton`." + ) + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=not config.no_bias, + ) + # self.out_proj._is_residual = True + + def forward( + self, + x, + past_key_value=None, + attn_bias=None, + attention_mask=None, + is_causal=True, + needs_weights=False, + ): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + (query, key, value) = qkv.split( + [self.d_model, self.head_dim, self.head_dim], dim=2 + ) + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + (context, attn_weights, past_key_value) = self.attn_fn( + query, + key, + value, + self.n_heads, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + multiquery=True, + ) + return (self.out_proj(context), attn_weights, past_key_value) + + +def attn_bias_shape( + attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id +): + if attn_impl == "flash": + return None + elif attn_impl in ["torch", "triton"]: + if alibi: + if (prefix_lm or not causal) or use_sequence_id: + return (1, n_heads, seq_len, seq_len) + return (1, n_heads, 1, seq_len) + elif prefix_lm or use_sequence_id: + return (1, 1, seq_len, seq_len) + return None + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + + +def build_attn_bias( + attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 +): + if attn_impl == "flash": + return None + elif attn_impl in ["torch", "triton"]: + if alibi: + (device, dtype) = (attn_bias.device, attn_bias.dtype) + attn_bias = attn_bias.add( + build_alibi_bias( + n_heads, + seq_len, + full=not causal, + alibi_bias_max=alibi_bias_max, + device=device, + dtype=dtype, + ) + ) + return attn_bias + else: + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + + +def gen_slopes(n_heads, alibi_bias_max=8, device=None): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) + m = m.mul(alibi_bias_max / _n_heads) + slopes = 1.0 / torch.pow(2, m) + if _n_heads != n_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] + return slopes.view(1, n_heads, 1, 1) + + +def build_alibi_bias( + n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None +): + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( + 1, 1, 1, seq_len + ) + if full: + alibi_bias = alibi_bias - torch.arange( + 1 - seq_len, 1, dtype=torch.int32, device=device + ).view(1, 1, seq_len, 1) + alibi_bias = alibi_bias.abs().mul(-1) + slopes = gen_slopes(n_heads, alibi_bias_max, device=device) + alibi_bias = alibi_bias * slopes + return alibi_bias.to(dtype=dtype) + + +ATTN_CLASS_REGISTRY = { + "multihead_attention": MultiheadAttention, + "multiquery_attention": MultiQueryAttention, +} + +"""GPT Blocks used for the GPT Model.""" + + +class MPTMLP(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) + self.up_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias + ) + self.act = nn.GELU(approximate="none") + # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=not config.no_bias, + ) + # self.down_proj._is_residual = True + + def forward(self, x): + return self.down_proj(self.act(self.up_proj(x))) + + +class MPTBlock(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.prefix = prefix + if config.attn_config.attn_type != "multihead_attention": + raise NotImplementedError( + f"""Not implemented attn {config.attn_config.attn_type}""" + ) + resid_pdrop = config.resid_pdrop + if config.no_bias: + self.norm_1 = nn.LayerNorm.load_no_bias( + prefix=f"{prefix}.norm_1", weights=weights, eps=EPS + ) + self.norm_2 = nn.LayerNorm.load_no_bias( + prefix=f"{prefix}.norm_2", weights=weights, eps=EPS + ) + else: + self.norm_1 = nn.LayerNorm.load( + prefix=f"{prefix}.norm_1", weights=weights, eps=EPS + ) + self.norm_2 = nn.LayerNorm.load( + prefix=f"{prefix}.norm_2", weights=weights, eps=EPS + ) + self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) + self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) + self.resid_ffn_dropout = nn.Dropout(resid_pdrop) + + def forward( + self, + x: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_bias: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.ByteTensor] = None, + is_causal: bool = True, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + a = self.norm_1(x) + (b, attn_weights, past_key_value) = self.attn( + a, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=is_causal, + ) + x = x + self.resid_attn_dropout(b) + m = self.norm_2(x) + n = self.ffn(m) + x = x + self.resid_ffn_dropout(n) + return (x, attn_weights, past_key_value) + + +def _cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == "cuda": + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == "cpu": + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor + + +class LPLayerNorm(torch.nn.LayerNorm): + def __init__( + self, + normalized_shape, + eps=1e-05, + elementwise_affine=True, + device=None, + dtype=None, + bias: Optional[bool] = True, + prefix=None, + weights=None, + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + device=device, + dtype=dtype, + bias=bias, + ) + if weights is not None: + self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0)) + if bias: + self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0)) + self.normalized_shape = self.weight.shape + + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) + downcast_bias = ( + _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + ) + with torch.autocast(enabled=False, device_type=module_device.type): + return torch.nn.functional.layer_norm( + downcast_x, + self.normalized_shape, + downcast_weight, + downcast_bias, + self.eps, + ) + + +def rms_norm(x, weight=None, eps=1e-05): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + if weight is not None: + return output * weight + return output + + +class RMSNorm(torch.nn.Module): + def __init__( + self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None + ): + super().__init__() + self.eps = eps + if weight: + self.weight = torch.nn.Parameter( + torch.ones(normalized_shape, dtype=dtype, device=device) + ) + else: + self.register_parameter("weight", None) + + def forward(self, x): + return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) + + +class LPRMSNorm(RMSNorm): + def __init__( + self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + weight=weight, + dtype=dtype, + device=device, + ) + + def forward(self, x): + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) + with torch.autocast(enabled=False, device_type=x.device.type): + return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) + + +NORM_CLASS_REGISTRY = { + "layernorm": torch.nn.LayerNorm, + "low_precision_layernorm": LPLayerNorm, + "rmsnorm": RMSNorm, + "low_precision_rmsnorm": LPRMSNorm, +} + +Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + + +class MPTPreTrainedModel(PreTrainedModel): + base_model_prefix = "model" + _no_split_modules = ["MPTBlock"] + + +class MPTModel(MPTPreTrainedModel): + def __init__(self, prefix: str, config, weights): + # config._validate_config() + super().__init__(config) + self.world_size = weights.process_group.size() + self.rank = weights.process_group.rank() + self.n_heads = config.n_heads + self.attn_impl = config.attn_config.attn_impl + self.prefix_lm = config.attn_config.prefix_lm + self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id + self.alibi = config.attn_config.alibi + self.alibi_bias_max = config.attn_config.alibi_bias_max + if config.init_device == "mixed": + # TODO: reimplement mixed device initialization + # dist.get_local_rank() == 0: + if True: + config.init_device = "cpu" + else: + config.init_device = "meta" + if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): + norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) + raise NotImplementedError( + f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." + ) + if config.norm_type.lower() != "low_precision_layernorm": + raise NotImplementedError( + f"Requested norm type ({config.norm_type}) is not implemented within this repo." + ) + + self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) + + if not self.alibi: + self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) + self.blocks = nn.ModuleList( + [ + MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) + for i in range(config.n_layers) + ] + ) + if config.no_bias: + self.norm_f = nn.LayerNorm.load_no_bias( + prefix="transformer.norm_f", weights=weights, eps=EPS + ) + else: + self.norm_f = nn.LayerNorm.load( + prefix="transformer.norm_f", weights=weights, eps=EPS + ) + self.is_causal = not self.prefix_lm + self._attn_bias_initialized = False + self.attn_bias = None + self.attn_bias_shape = attn_bias_shape( + self.attn_impl, + config.n_heads, + config.max_seq_len, + self.alibi, + prefix_lm=self.prefix_lm, + causal=self.is_causal, + use_sequence_id=self.attn_uses_sequence_id, + ) + if config.no_bias: + for module in self.modules(): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + if config.verbose: + warnings.warn(f"Removing bias ({module.bias}) from {module}.") + module.register_parameter("bias", None) + if hasattr(self.config, "verbose"): + if config.verbose and config.verbose > 2: + print(self) + if "verbose" not in self.config.init_config: + self.config.init_config["verbose"] = self.config.verbose + if self.config.init_config["verbose"] > 1: + init_fn_name = self.config.init_config["name"] + warnings.warn(f"Using {init_fn_name} initialization.") + + @torch.no_grad() + def _attn_bias( + self, + device, + dtype, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + ): + if not self._attn_bias_initialized: + if self.attn_bias_shape: + self.attn_bias = torch.zeros( + self.attn_bias_shape, device=device, dtype=dtype + ) + self.attn_bias = build_attn_bias( + self.attn_impl, + self.attn_bias, + self.config.n_heads, + self.config.max_seq_len, + causal=self.is_causal, + alibi=self.alibi, + alibi_bias_max=self.alibi_bias_max, + ) + assert self.n_heads % self.world_size == 0 + block_size = self.n_heads // self.world_size + self.attn_bias = self.attn_bias[ + :, self.rank * block_size : (self.rank + 1) * block_size + ] + self._attn_bias_initialized = True + if self.attn_impl == "flash": + return (self.attn_bias, attention_mask) + if self.attn_bias is not None: + self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) + attn_bias = self.attn_bias + if self.prefix_lm: + assert isinstance(attn_bias, torch.Tensor) + assert isinstance(prefix_mask, torch.Tensor) + attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) + if self.attn_uses_sequence_id and sequence_id is not None: + assert isinstance(attn_bias, torch.Tensor) + attn_bias = self._apply_sequence_id(attn_bias, sequence_id) + if attention_mask is not None: + s_k = attention_mask.shape[-1] + if attn_bias is None: + attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) + else: + _s_k = max(0, attn_bias.size(-1) - s_k) + attn_bias = attn_bias[:, :, :, _s_k:] + if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: + raise ValueError( + f"attention_mask shape={attention_mask.shape} " + + f"and prefix_mask shape={prefix_mask.shape} are not equal." + ) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill( + ~attention_mask.view(-1, 1, 1, s_k), min_val + ) + return (attn_bias, None) + + def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): + (s_k, s_q) = attn_bias.shape[-2:] + if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: + raise ValueError( + "attn_bias does not match the expected shape. " + + f"The last two dimensions should both be {self.config.max_length} " + + f"but are {s_k} and {s_q}." + ) + seq_len = prefix_mask.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError( + f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" + ) + attn_bias = attn_bias[..., :seq_len, :seq_len] + causal = torch.tril( + torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) + ).view(1, 1, seq_len, seq_len) + prefix = prefix_mask.view(-1, 1, 1, seq_len) + cannot_attend = ~torch.logical_or(causal, prefix.bool()) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def _apply_sequence_id( + self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor + ): + seq_len = sequence_id.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError( + f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" + ) + attn_bias = attn_bias[..., :seq_len, :seq_len] + cannot_attend = torch.logical_not( + torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) + ).unsqueeze(1) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + if attention_mask is not None: + attention_mask = attention_mask.bool() + if prefix_mask is not None: + prefix_mask = prefix_mask.bool() + if not return_dict: + raise NotImplementedError( + "return_dict False is not implemented yet for MPT" + ) + if output_attentions: + if self.attn_impl != "torch": + raise NotImplementedError( + "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." + ) + if ( + attention_mask is not None + and attention_mask[:, 0].sum() != attention_mask.shape[0] + and self.training + ): + raise NotImplementedError( + "MPT does not support training with left padding." + ) + if self.prefix_lm and prefix_mask is None: + raise ValueError( + "prefix_mask is a required argument when MPT is configured with prefix_lm=True." + ) + if self.training: + if self.attn_uses_sequence_id and sequence_id is None: + raise ValueError( + "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + + "and the model is in train mode." + ) + elif self.attn_uses_sequence_id is False and sequence_id is not None: + warnings.warn( + "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " + + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." + ) + S = input_ids.size(1) + assert ( + S <= self.config.max_seq_len + ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" + tok_emb = self.wte(input_ids) + if self.alibi: + x = tok_emb + else: + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + "past_key_values must provide a past_key_value for each attention " + + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." + ) + past_position = past_key_values[0][0].size(1) + if self.attn_impl == "torch": + past_position = past_key_values[0][0].size(3) + if S + past_position > self.config.max_seq_len: + raise ValueError( + f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." + ) + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) + if attention_mask is not None: + pos = torch.clamp( + pos + - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ + :, past_position: + ], + min=0, + ) + pos_emb = self.wpe(pos) + x = tok_emb + pos_emb + (attn_bias, attention_mask) = self._attn_bias( + device=x.device, + dtype=torch.float32, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + ) + if use_cache and past_key_values is None: + past_key_values = [() for _ in range(self.config.n_layers)] + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + for b_idx, block in enumerate(self.blocks): + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + past_key_value = ( + past_key_values[b_idx] if past_key_values is not None else None + ) + (x, attn_weights, past_key_value) = block( + x, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=self.is_causal, + ) + if past_key_values is not None: + past_key_values[b_idx] = past_key_value + if output_attentions: + assert all_self_attns is not None + all_self_attns = all_self_attns + (attn_weights,) + x = self.norm_f(x) + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + return BaseModelOutputWithPast( + last_hidden_state=x, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MPTForCausalLM(MPTPreTrainedModel): + def __init__(self, prefix: str, config, weights): + super().__init__(config) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + if not config.tie_word_embeddings: + raise ValueError("MPTForCausalLM only supports tied word embeddings") + self.transformer = MPTModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, prefix=f"{prefix}.wte", weights=weights + ) + self.logit_scale = None + if config.logit_scale is not None: + logit_scale = config.logit_scale + if isinstance(logit_scale, str): + if logit_scale == "inv_sqrt_d_model": + logit_scale = 1 / math.sqrt(config.d_model) + else: + raise ValueError( + f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." + ) + self.logit_scale = logit_scale + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + outputs = self.transformer( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + return_dict=return_dict, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + ) + logits, speculative_logits = self.lm_head(outputs.last_hidden_state) + if self.logit_scale is not None: + if self.logit_scale == 0: + warnings.warn( + f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." + ) + logits *= self.logit_scale + loss = None + if labels is not None: + labels = torch.roll(labels, shifts=-1) + labels[:, -1] = -100 + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) + ) + return ( + CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + if inputs_embeds is not None: + raise NotImplementedError("inputs_embeds is not implemented for MPT yet") + attention_mask = kwargs["attention_mask"].bool() + if attention_mask[:, -1].sum() != attention_mask.shape[0]: + raise NotImplementedError( + "MPT does not support generation with right padding." + ) + if self.transformer.attn_uses_sequence_id and self.training: + sequence_id = torch.zeros_like(input_ids[:1]) + else: + sequence_id = None + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if self.transformer.prefix_lm: + prefix_mask = torch.ones_like(attention_mask) + if kwargs.get("use_cache") is False: + raise NotImplementedError( + "MPT with prefix_lm=True does not support use_cache=False." + ) + else: + prefix_mask = None + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prefix_mask": prefix_mask, + "sequence_id": sequence_id, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", True), + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + """Used by HuggingFace generate when using beam search with kv-caching. + + See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 + for an example in transformers. + """ + reordered_past = [] + for layer_past in past_key_values: + reordered_past += [ + tuple( + (past_state.index_select(0, beam_idx) for past_state in layer_past) + ) + ] + return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py new file mode 100644 index 000000000..06731a6f9 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -0,0 +1,796 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch GPTNeoX model.""" + +from typing import Optional, Tuple, Union + +import os +import torch +import torch.distributed +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + SpeculativeHead, +) + + +CUSTOM_KERNELS_ENABLED = False +if ( + torch.cuda.is_available() + and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" +): + try: + from custom_kernels import fused_attention_cuda + + CUSTOM_KERNELS_ENABLED = True + except ImportError: + pass + + +def make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.ones( + (target_length, target_length + past_key_values_length), + dtype=torch.bool, + device=device, + ) + mask = mask.triu(1 + past_key_values_length) + + expanded_mask = mask.unsqueeze(0).expand( + batch_size, target_length, target_length + past_key_values_length + ) + return expanded_mask + + +def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, tgt_length, src_length) + + +def prepare_attn_mask( + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, +) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + +class GPTNeoXPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + +class GPTNeoXAttention(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + # ??? TODO + # self.register_buffer( + # "bias", + # torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + # 1, 1, max_positions, max_positions + # ), + # ) + # self.register_buffer("masked_bias", torch.tensor(-1e9)) + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, + config.max_position_embeddings, + base=config.rotary_emb_base, + ) + self.rotary_emb.inv_freq = nn.Parameter( + weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") + ) + self.inv_norm_factor = 1.0 / torch.sqrt( + torch.tensor(self.head_size, dtype=torch.float32) + ).to(torch.get_default_dtype()) + + if self.num_attention_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_attention_heads` must be divisible by `num_shards` " + f"(got `num_attention_heads`: {self.num_attention_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_attention_heads = ( + self.num_attention_heads // weights.process_group.size() + ) + self.query_key_value = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True + ) + self.dense = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.dense", weights=weights, bias=True + ) + + def forward( + self, + hidden_states, + position_ids, + attention_mask, + head_mask=None, + layer_past=None, + use_cache=False, + output_attentions=False, + ): + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3) + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query, key, value = qkv.split(self.head_size, -1) + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + if has_layer_past: + seq_len += layer_past[0].shape[-2] + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + key_rot = key[..., : self.rotary_ndims] + + query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len) + + query[..., : self.rotary_ndims] = query_rot + key[..., : self.rotary_ndims] = key_rot + + if CUSTOM_KERNELS_ENABLED: + attn_output, present, attn_weights = fused_attention_cuda.forward( + query, + key, + value, + layer_past, + attention_mask, + head_mask, + self.inv_norm_factor, + self.num_attention_heads, + use_cache, + ) + else: + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # Compute attention + attn_output, attn_weights = self._attn( + query, key, value, attention_mask, head_mask + ) + + # Reshape outputs + attn_output = self._merge_heads( + attn_output, self.num_attention_heads, self.head_size + ) + + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + @classmethod + def _split_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view( + tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size + ) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + query = query.reshape( + batch_size * num_attention_heads, query_length, attn_head_size + ) + key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + 1, + dtype=query.dtype, + device=key.device, + ).expand(batch_size * num_attention_heads, query_length, key_length) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=self.inv_norm_factor, + ) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attn_scores.dtype + if input_dtype in [torch.float16, torch.bfloat16]: + attn_scores = attn_scores.to(torch.float) + attn_scores = torch.where( + attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores + ) + attn_scores = attn_scores.view( + batch_size, num_attention_heads, query_length, key_length + ) + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, device=None): + super().__init__() + self.true_inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2).float().to(device) / dim) + ) + self.register_buffer("inv_freq", self.true_inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + self.cos_cached = None + self.sin_cached = None + + @staticmethod + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + @staticmethod + def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): + t = torch.arange( + max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype + ) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype) + + def forward(self, q, k, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if ( + seq_len > self.max_seq_len_cached + or self.cos_cached is None + or self.sin_cached is None + ): + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.cos_cached, self.sin_cached = self._create_cos_sin( + self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device + ) + return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids) + + +@torch.jit.script +def rotary_forward(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + + chunk_size = q.shape[-1] // 2 + q1, q2 = q.split(chunk_size, -1) + q_rotated = torch.cat((-q2, q1), dim=-1) + k1, k2 = k.split(chunk_size, -1) + k_rotated = torch.cat((-k2, k1), dim=-1) + + q_embed = (q * cos) + (q_rotated * sin) + k_embed = (k * cos) + (k_rotated * sin) + return q_embed, k_embed + + +class GPTNeoXMLP(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.act = ( + ACT2FN[config.hidden_act] + if "gelu_fast" not in config.hidden_act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ) + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + def __init__(self, layer_id, prefix: str, config, weights): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.layers.{layer_id}.input_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.attention = GPTNeoXAttention( + config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights + ) + self.mlp = GPTNeoXMLP( + config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights + ) + + def forward( + self, + hidden_states, + position_ids, + attention_mask=None, + head_mask=None, + use_cache=False, + layer_past=None, + output_attentions=False, + ): + attention_layer_outputs = self.attention( + self.input_layernorm(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attention_layer_outputs[ + 0 + ] # output_attn: attn_output, present, (attn_weights) + outputs = attention_layer_outputs[1:] + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + hidden_states = mlp_output + attn_output + + if use_cache: + outputs = ( + hidden_states, + ) + outputs # hidden_states, present, (attn_weights) + else: + outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) + + return outputs + + +class GPTNeoXModel(GPTNeoXPreTrainedModel): + def __init__(self, prefix: str, config, weights): + super().__init__(config) + self.config = config + + self.num_attention_heads = config.num_attention_heads + + self.embed_in = TensorParallelEmbedding( + prefix=f"{prefix}.embed_in", weights=weights + ) + self.layers = nn.ModuleList( + [ + GPTNeoXLayer(layer_id, prefix, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.final_layer_norm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.tp_world_size = weights.process_group.size() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids=None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * self.config.num_hidden_layers) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_length, seq_length + past_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + + hidden_states = inputs_embeds + + # Attention mask. + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[-1] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), device=hidden_states.device + ) + else: + attention_mask = attention_mask.to(hidden_states.device) + + causal_mask = prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + assert self.num_attention_heads % self.tp_world_size == 0 + block_size = self.num_attention_heads // self.tp_world_size + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = layer( + hidden_states, + position_ids=position_ids, + attention_mask=causal_mask, + head_mask=head_mask[i], + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_attentions = all_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.final_layer_norm(hidden_states) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_attentions] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, prefix: str, config, weights): + super().__init__(config) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = GPTNeoXModel(prefix, config, weights) + self.embed_out = SpeculativeHead.load( + config, prefix="embed_out", weights=weights + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are + only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see + `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config.is_decoder = True + >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.gpt_neox( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + lm_logits, speculative_logits = self.embed_out(hidden_states) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # we are doing next-token prediction; shift prediction scores and input ids by one + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) + ) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return ( + CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + input_shape = input_ids.shape + + # cut decoder_input_ids if past is used + if past_key_values and past_key_values[0] is not None: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + ) + + return model_inputs + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past[:2] + ) + + layer_past[2:], + ) + return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py new file mode 100644 index 000000000..bd4403214 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -0,0 +1,857 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OPT model.""" +import random +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers import OPTConfig +from text_generation_server.layers import ( + FastLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + SpeculativeHead, +) + +EPS = 1e-5 + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full( + (tgt_len, tgt_len), + torch.tensor(torch.finfo(dtype).min, device=device), + device=device, + ) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +class OPTLearnedPositionalEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, prefix: str, weights): + super().__init__() + self.offset = 2 + self.weight = nn.Parameter( + weights.get_tensor( + f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" + ) + ) + + def forward( + self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask + ).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return torch.nn.functional.embedding(positions + self.offset, self.weight) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config, + prefix, + weights, + is_decoder: bool = False, + bias: bool = True, + process_group=None, + ): + super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.dropout = config.dropout + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + process_group = weights.process_group + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // process_group.size() + self.hidden_size = self.hidden_size // process_group.size() + + self.q_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias + ) + self.k_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias + ) + self.v_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias + ) + self.out_proj = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OPTDecoderLayer(nn.Module): + def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): + super().__init__() + self.process_group = weights.process_group + self.hidden_size = config.hidden_size + prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" + self.self_attn = OPTAttention( + config, + prefix=f"{prefix}.self_attn", + weights=weights, + is_decoder=True, + bias=config.enable_bias, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS + ) + self.fc1 = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias + ) + self.fc2 = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias + ) + self.final_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + + +class OPTDecoder(OPTPreTrainedModel): + def __init__(self, prefix: str, config: OPTConfig, weights): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + prefix = prefix + "." if prefix else "" + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}decoder.embed_tokens", weights=weights + ) + self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = FastLinear.load( + config, + prefix=f"{prefix}decoder.project_out", + weights=weights, + bias=False, + ) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = FastLinear.load( + config, + prefix=f"{prefix}decoder.project_in", + weights=weights, + bias=False, + ) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList( + [ + OPTDecoderLayer(layer_id, prefix, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + causal_attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OPTModel(OPTPreTrainedModel): + def __init__(self, prefix: str, config: OPTConfig, weights): + super().__init__(config) + self.decoder = OPTDecoder(prefix, config, weights) + # Initialize weights and apply final processing + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +class OPTForCausalLM(OPTPreTrainedModel): + def __init__(self, prefix, config, weights): + super().__init__(config) + + self.model = OPTModel(prefix, config, weights) + + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens", + weights=weights, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits, speculative_logits = self.lm_head(outputs.last_hidden_state) + + loss = None + + return ( + CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py new file mode 100644 index 000000000..3f2ed010f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -0,0 +1,336 @@ +# imlementation of the PhiModel and PhiForCausalLM classes + +import torch +import torch.distributed + +import math +from torch import nn +from typing import Optional, List, Tuple +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import CausalLMOutputWithPast + +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + FastLinear, +) + + +# PhiConfig is the configuration class for the PhiModel. +class PhiConfig(PretrainedConfig): + def __init__( + self, + vocab_size=51200, + n_positions=2048, + n_embd=2560, + n_layer=32, + n_inner=None, + n_head=32, + rotary_dim=32, + layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + pad_vocab_size_multiple=64, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + no_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_inner = n_inner + self.n_head = n_head + self.rotary_dim = rotary_dim + + self.layer_norm_epsilon = layer_norm_epsilon + self.tie_word_embeddings = tie_word_embeddings + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.no_bias = no_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +# RotaryEmbedding is a class that implements the rotary embedding. +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)] + inv_freq_len = len(inv_freq) + inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len) + t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1) + freqs = t.matmul(inv_freq) + self.sin = freqs.sin() + self.cos = freqs.cos() + + def apply_rotary_emb_qkv(self, qkv, seqlen_offset): + b_size, seqlen, three, _, _headdim = qkv.shape + if three != 3: + raise Exception("unexpected shape for qkv") + _, rotary_dim = self.cos.shape + rotary_dim = rotary_dim * 2 + q_rot = qkv[:, :, 0, :, :rotary_dim] + q_pass = qkv[:, :, 0, :, rotary_dim:] + k_rot = qkv[:, :, 1, :, :rotary_dim] + k_pass = qkv[:, :, 1, :, rotary_dim:] + q12 = torch.chunk(q_rot, 2, dim=-1) + k12 = torch.chunk(k_rot, 2, dim=-1) + q1, q2 = q12[0], q12[1] + k1, k2 = k12[0], k12[1] + c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1) + s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1) + q_rot = torch.cat( + [ + q1 * c - q2 * s, + q1 * s + q2 * c, + ], + dim=-1, + ) + k_rot = torch.cat( + [ + k1 * c - k2 * s, + k1 * s + k2 * c, + ], + dim=-1, + ) + q = torch.cat([q_rot, q_pass], dim=-1) + k = torch.cat([k_rot, k_pass], dim=-1) + v = qkv[:, :, 2] + return q, k, v + + +# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm. +class PhiCausalLMHead(nn.Module): + def __init__(self, config, weights): + super().__init__() + self.ln = nn.LayerNorm.load( + prefix="lm_head.ln", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.linear = SpeculativeHead.load( + config=config, prefix="lm_head.linear", weights=weights + ) + + def forward(self, hidden_states): + hidden_states = self.ln(hidden_states) + hidden_states = self.linear(hidden_states) + return hidden_states + + +# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens. +class PhiMHA(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.Wqkv = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias + ) + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=not config.no_bias, + ) + self.op_size = config.n_embd + self.head_dim = int(config.n_embd / config.n_head) + self.num_heads = config.n_head + self.rotary_emb = RotaryEmbedding( + config.rotary_dim, + config.n_positions, + ) + self.softmax_scale = 1.0 / math.sqrt(self.head_dim) + + def forward( + self, + hidden_states, + past_kv_cache, + attention_mask=None, + ): + b_size, seq_len, _n_embd = hidden_states.shape + qkv = self.Wqkv(hidden_states) + qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim) + seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1] + q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset) + + # if there is a kv_cache, then we need to concatenate + if past_kv_cache is not None: + prev_k, prev_v = past_kv_cache + k = torch.cat([prev_k, k], dim=1) + v = torch.cat([prev_v, v], dim=1) + + past_kv_cache = [k, v] + attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale) + + if attention_mask is not None: + seqlen_k = k.shape[1] + seqlen_q = q.shape[1] + causal_mask = torch.triu( + torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), + 1, + ) + attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0) + attn_output = ( + attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)) + .transpose(1, 2) + .flatten(-2) + ) + return self.out_proj(attn_output), past_kv_cache + + +# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function. +class PhiMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.n_inner = config.n_inner + self.fc1 = FastLinear.load( + config=config, + prefix=f"{prefix}.fc1", + weights=weights, + bias=False, + ) + self.fc2 = FastLinear.load( + config=config, + prefix=f"{prefix}.fc2", + weights=weights, + bias=False, + ) + self.activation = torch.nn.functional.gelu + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron. +class PhiBlock(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.layer_id = layer_id + self.layer_norm = nn.LayerNorm.load( + prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon + ) + self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights) + self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights) + + def forward( + self, + hidden_states, + kv_cache, + attention_mask, + ): + residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + attn_outputs, past_kv_cache = self.mixer( + hidden_states, kv_cache, attention_mask + ) + feed_forward_hidden_states = self.mlp(hidden_states) + out = attn_outputs + feed_forward_hidden_states + residual + return out, past_kv_cache + + +# PhiModel implements the embedding layer and the transformer blocks. +class PhiModel(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.tp_rank = weights.process_group.rank() + self.tp_world_size = weights.process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embd.wte", weights=weights + ) + self.blocks = nn.ModuleList( + [ + PhiBlock(f"{prefix}.h.{layer_id}", config, weights) + for layer_id in range(config.n_layer) + ] + ) + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + hidden_states = self.embed_tokens(input_ids) + seq_len = hidden_states.shape[1] + mask = None if seq_len <= 1 else attention_mask + + past_key_values = ( + [None] * len(self.blocks) if past_key_values is None else past_key_values + ) + + for index, block in enumerate(self.blocks): + hidden_states, new_key_values = block( + hidden_states, past_key_values[index], mask + ) + past_key_values[index] = new_key_values + + return hidden_states, past_key_values + + +# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. +class PhiForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.model = PhiModel(prefix, config, weights) + self.lm_head = PhiCausalLMHead(config, weights) + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + model_output = self.model( + input_ids, past_key_values, attention_mask, return_dict, use_cache + ) + logits = self.lm_head(model_output[0]) + + loss = None + if labels is not None: + loss = nn.CrossEntropyLoss()( + logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1) + ) + + if not return_dict: + return ( + ((loss,) + (logits,) + model_output[1:]) + if loss is not None + else (logits,) + model_output[1:] + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=model_output[1], + hidden_states=None, + attentions=None, + ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py new file mode 100644 index 000000000..95ac9edee --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py @@ -0,0 +1,410 @@ +from typing import Optional, Tuple +import warnings +import math +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPooling, +) +from transformers import SiglipConfig, SiglipVisionConfig +from torch.nn.init import _calculate_fan_in_and_fan_out + +from text_generation_server.layers.tensor_parallel import ( + TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, device=weights.device).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding( + pixel_values + ) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.head_size = self.head_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k_proj", weights=weights, bias=True + ) + self.v_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v_proj", weights=weights, bias=True + ) + self.q_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q_proj", weights=weights, bias=True + ) + self.out_proj = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + # scale post matmul + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(attn_weights.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( # config.hidden_size, config.intermediate_size + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( # config.intermediate_size, config.hidden_size + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + def __init__(self, prefix, config: SiglipConfig, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SiglipAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps + ) + self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states, None + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(prefix, config, weights) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + + Args: + config: SiglipConfig + """ + + def __init__(self, prefix, config: SiglipConfig, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + SiglipEncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + hidden_states, _ = encoder_layer( + hidden_states, + attention_mask, + ) + + return hidden_states + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + self.config = config + + self.embeddings = SiglipVisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = SiglipEncoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + # NOTE: up until this point, the code logits are exactly + # the same as the transformers code. The values evaulate + # slightly differently in our encoder layer. + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + ) + last_hidden_state = encoder_outputs + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + # pooler_output=pooled_output, + # hidden_states=encoder_outputs, + ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py new file mode 100644 index 000000000..e6666acd3 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -0,0 +1,1227 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model.""" + +import copy +import math +import warnings +from typing import Optional, Tuple, Union + +from loguru import logger + +import torch +import torch.distributed +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + is_torch_fx_proxy, +) +from transformers import T5Config +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + SpeculativeHead, +) + +# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316 +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +class PartialTPEmbedding(nn.Module): + def __init__(self, prefix: str, weights): + super().__init__() + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + self.weight = nn.Parameter(weight) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.embedding(input, self.weight) + + +@torch.jit.script +def layer_norm(hidden_states, weight, epsilon): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + epsilon) + + # convert into half-precision if necessary + if weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(weight.dtype) + + return weight * hidden_states + + +class T5LayerNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = torch.tensor(eps) + + def forward(self, hidden_states): + return layer_norm(hidden_states, self.weight, self.variance_epsilon) + + +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm # noqa + + logger.info( + "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" + ) +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(T5LayerNorm) + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config, prefix, weights): + super().__init__() + self.wi = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.wi", weights=weights, bias=False + ) + + ### XXX: T5 models do not handle well both f16 and quantization. + ### Overidding specifically this layer for that reason. + ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 + ### https://github.com/huggingface/transformers/issues/20287 + _q = config.quantize + _dtype = weights.dtype + weights.dtype = torch.float32 + config.quantize = None + self.wo_cast = (torch.float32, _dtype) + self.wo = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.wo", weights=weights, bias=False + ) + weights.dtype = _dtype + config.quantize = _q + + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ( + ACT2FN[config.dense_act_fn] + if "gelu" not in config.dense_act_fn + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states.to(dtype=self.wo_cast[0]) + hidden_states = self.wo(hidden_states) + # XXX: Recasting is already done within the layer norm. + # Casting back to float16 here modifies results + # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config, prefix, weights): + super().__init__() + self.wi_0 = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.wi_0", weights=weights, bias=False + ) + self.wi_1 = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.wi_1", weights=weights, bias=False + ) + ### XXX: T5 models do not handle well both f16 and quantization. + ### Overidding specifically this layer for that reason. + ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 + ### https://github.com/huggingface/transformers/issues/20287 + _q = config.quantize + _dtype = weights.dtype + weights.dtype = torch.float32 + config.quantize = None + self.wo_cast = (torch.float32, _dtype) + self.wo = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.wo", weights=weights, bias=False + ) + weights.dtype = _dtype + config.quantize = _q + + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ( + ACT2FN[config.dense_act_fn] + if "gelu" not in config.dense_act_fn + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states.to(dtype=self.wo_cast[0]) + hidden_states = self.wo(hidden_states) + # XXX: Recasting is already done within the layer norm. + # Casting back to float16 here modifies results + # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config, prefix, weights): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense( + config, prefix=f"{prefix}.DenseReluDense", weights=weights + ) + else: + self.DenseReluDense = T5DenseActDense( + config, prefix=f"{prefix}.DenseReluDense", weights=weights + ) + + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__( + self, config: T5Config, prefix, weights, has_relative_attention_bias=False + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + process_group = weights.process_group + # Mesh TensorFlow initialization to avoid scaling before softmax + assert self.n_heads % process_group.size() == 0 + self.q = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q", weights=weights, bias=False + ) + self.k = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k", weights=weights, bias=False + ) + self.v = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v", weights=weights, bias=False + ) + self.o = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.o", weights=weights, bias=False + ) + if self.n_heads % weights.process_group.size() != 0: + raise ValueError( + f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.n_heads = self.n_heads // process_group.size() + self.inner_dim = self.inner_dim // process_group.size() + + if self.has_relative_attention_bias: + self.relative_attention_bias = PartialTPEmbedding( + prefix=f"{prefix}.relative_attention_bias", weights=weights + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) + + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, prefix, weights, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, + prefix=f"{prefix}.SelfAttention", + weights=weights, + has_relative_attention_bias=has_relative_attention_bias, + ) + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.EncDecAttention = T5Attention( + config, + prefix=f"{prefix}.EncDecAttention", + weights=weights, + has_relative_attention_bias=False, + ) + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + prefix=f"{prefix}.layer.0", + weights=weights, + has_relative_attention_bias=has_relative_attention_bias, + ) + ) + if self.is_decoder: + i = 2 + self.layer.append( + T5LayerCrossAttention( + config, prefix=f"{prefix}.layer.1", weights=weights + ) + ) + else: + i = 1 + + self.layer.append( + T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) + ) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." + " See T5 docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert ( + pad_token_id is not None + ), "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, prefix, weights, embed_tokens): + super().__init__(config) + + self.is_decoder = config.is_decoder + + self.embed_tokens = embed_tokens + self.block = nn.ModuleList( + [ + T5Block( + config, + prefix=f"{prefix}.block.{layer_id}", + weights=weights, + has_relative_attention_bias=(layer_id == 0), + ) + for layer_id in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + prefix=f"{prefix}.final_layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + assert ( + self.embed_tokens is not None + ), "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) + + if use_cache is True: + assert ( + self.is_decoder + ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + if ( + self.is_decoder + and encoder_attention_mask is None + and encoder_hidden_states is not None + ): + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long, + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3 + ] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class T5ForConditionalGeneration(T5PreTrainedModel): + def __init__(self, config: T5Config, weights): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack( + config=encoder_config, + prefix="encoder", + weights=weights, + embed_tokens=self.shared, + ) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack( + config=decoder_config, + prefix="decoder", + weights=weights, + embed_tokens=self.shared, + ) + + try: + self.lm_head = SpeculativeHead.load( + config, prefix="lm_head", weights=weights + ) + except RuntimeError: + # Some models like t5-small were saved with shared weights unlike flan + # Since they are declared as the same arch we have no choice but hope + # that this is OK instead of using a proper flag. + self.lm_head = SpeculativeHead.load( + config, prefix="shared", weights=weights + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + logits, speculative_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return ( + Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ), + speculative_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + decoder_attention_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) + return reordered_decoder_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py new file mode 100644 index 000000000..e5c44045a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py @@ -0,0 +1,48 @@ +def load_text_model(prefix, config, weights, name=None): + if config.model_type == "llama": + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + ) + + return FlashLlamaForCausalLM(prefix, config, weights) + elif config.model_type == "mistral": + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + + return FlashMistralForCausalLM(prefix, config, weights, name=name) + elif config.model_type == "gemma": + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + + return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + elif config.model_type == "paligemma": + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + + return FlashGemmaForCausalLM(prefix, config, weights) + else: + raise RuntimeError(f"Unsupported model type {config.model_type}") + + +def load_vision_model(prefix, config, weights): + if config.model_type == "clip_vision_model": + from text_generation_server.models.custom_modeling.clip import ( + CLIPVisionTransformer, + ) + + return CLIPVisionTransformer( + prefix=f"{prefix}.vision_model", config=config, weights=weights + ) + if config.model_type == "siglip_vision_model": + from text_generation_server.models.custom_modeling.siglip import ( + SiglipVisionTransformer, + ) + + return SiglipVisionTransformer( + prefix="vision_tower.vision_model", config=config, weights=weights + ) + else: + raise RuntimeError(f"Unsupported model type {config.model_type}") diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py new file mode 100644 index 000000000..bc9d44a0b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -0,0 +1,2003 @@ +from contextlib import nullcontext +import math +import os +import time +import torch +import torch.distributed + +import numpy as np + +from loguru import logger +from dataclasses import dataclass +from opentelemetry import trace +from transformers import ( + PreTrainedTokenizerBase, + AutoConfig, + AutoTokenizer, + GenerationConfig, +) +from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict + +from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models import Model +from text_generation_server.utils.log import log_master +from text_generation_server.utils.tokens import batch_top_tokens +from text_generation_server.utils.speculate import get_speculate +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.globals import ( + MEM_POOL, + ATTENTION, + BLOCK_SIZE, + CUDA_GRAPHS, + TGI_WIGGLE_ROOM, + get_adapter_to_index, +) +from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser +from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.quantization import get_loader +from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments + +from text_generation_server.utils.import_utils import ( + empty_cache, + synchronize, + get_free_memory, +) + +tracer = trace.get_tracer(__name__) + + +# Will be set in init +SLIDING_WINDOW: Optional[int] = None + + +def set_sliding_window(sliding_window: int): + global SLIDING_WINDOW + SLIDING_WINDOW = sliding_window + + +def get_sliding_windows() -> int: + global SLIDING_WINDOW + return SLIDING_WINDOW + + +def init_cpu_threads_env(rank_id: int, world_size: int): + import importlib.util + + if importlib.util.find_spec("numa") is not None: + import numa + import psutil + + nodes = numa.info.get_max_node() + 1 + rank_per_node = math.ceil(world_size / nodes) + num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) + node_id = int(rank_id / rank_per_node) + rank_offset_per_node = rank_id % rank_per_node + if os.getenv("OMP_NUM_THREADS") is None: + num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) + else: + num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) + if len(numa.memory.get_membind_nodes()) == nodes: + numa.memory.set_membind_nodes((node_id)) + torch.set_num_threads(num_cpus_per_rank) + if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True): + cpu_start = num_cpus_per_rank * rank_offset_per_node + numa.schedule.run_on_cpus( + 0, + *( + numa.info.node_to_cpus(node_id)[ + cpu_start : cpu_start + num_cpus_per_rank + ] + ), + ) + logger.info( + f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}" + ) + + +@dataclass +class FlashCausalLMBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + # request id -> idx in list mapping + requests_idx_mapping: Dict[int, int] + + # Decoder values + input_ids: torch.Tensor + position_ids: torch.Tensor + speculative_ids: Optional[torch.Tensor] + + # Flash Attention values + + # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + + # Paged Attention values + + # Set when creating the batch + # CPU tensor of length b indicating the start of each sequence in slots + start_slots: torch.Tensor + # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode + slot_indices: torch.Tensor + + # list of length b of list of length s_i // block_size + block_tables: List[List[int]] + # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences + block_tables_tensor: torch.Tensor + # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences + slots: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor + + max_seqlen: int + + # Prefill metadata tensors to efficiently compute logprobs + prefill_head_indices: Optional[torch.Tensor] + prefill_next_token_indices: Optional[torch.tensor] + prefill_cu_outlens: Optional[List[int]] + + # Prefixes + prefix_ids: List[List[int]] + + # All tokens + all_input_ids: List[List[int]] + all_input_ids_tensor: torch.Tensor + + # Lengths of all generations present in the batch + input_lengths: List[int] + input_lengths_tensor: torch.Tensor + prefix_offsets: List[Optional[int]] + read_offsets: List[Optional[int]] + + # Generation helpers + next_token_chooser: HeterogeneousNextTokenChooser + stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor + + # Adapter metadata for each request + adapter_meta: AdapterBatchMetadata + + # Number of blocks in this batch + num_blocks: int + # Maximum number of blocks + max_blocks: int + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.num_blocks * BLOCK_SIZE, + ) + + @classmethod + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer + ): + max_length = 0 + all_input_ids = [] + batch_size = 0 + for r in requests: + batch_size += 1 + inputs = concat_text_chunks(r.input_chunks.chunks) + input_ids = tokenizer( + inputs, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + all_input_ids.append(input_ids) + return all_input_ids + + @classmethod + def from_tokenized( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + batch_tokenized_inputs, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": + sliding_window = get_sliding_windows() + position_ids = [] + cu_seqlen_prefill = [0] + start_slots = [] + slot_indices = [] + prefill_cache_indices = [] + + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + prefix_ids = [] + requests_idx_mapping = {} + + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + next_token_chooser_parameters = [] + stopping_criterias = [] + top_n_tokens = [] + + adapter_indices_list = [] + adapter_set = set() + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + num_blocks = 0 + max_seqlen = 0 + max_length = 0 + max_blocks = 0 + + block_tables = [] + slots = [] + prefix_lens = [] + + # Parse batch + for i, (r, tokenized_input) in enumerate( + zip(pb.requests, batch_tokenized_inputs) + ): + # request id -> idx in list mapping + requests_idx_mapping[r.id] = i + + orig_input_length = len(tokenized_input) + + prefix_len = r.prefix_len + assert ( + prefix_len <= orig_input_length + ), f"Prefix {prefix_len} vs input {orig_input_length}" + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 + + # Commented as it's costly. + # log_master(logger.debug, "Tokenized input ids {tokenized_input}") + prefix_ids.append(tokenized_input[:prefix_len]) + tokenized_input = tokenized_input[prefix_len:] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + + prefix_offsets.append(input_length - 5) + read_offsets.append(input_length) + + all_input_ids.append(tokenized_input) + + # Position ids + request_position_ids = torch.arange( + prefix_len, orig_input_length, dtype=torch.int32 + ) + position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + input_length) + + next_token_chooser_parameters.append(r.parameters) + + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + max_new_tokens = stopping_criteria.max_new_tokens + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append(torch.full((input_length,), adapter_index)) + adapter_set.add(adapter_index) + + # Paged attention + # Remove one as the first token des not have a past + speculative_length = get_speculate() + speculative_length = 0 if speculative_length is None else speculative_length + + # Tokens that need to be mapped to blocks. + block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to slots. We don't need slots for the + # cached prefix (if present). + slot_tokens = input_length + max_new_tokens - 1 + speculative_length + + # blocks and slots can be empty (for example in warmup) + if not r.blocks: + needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) + request_blocks = [ + b for b in range(num_blocks, num_blocks + needed_blocks) + ] + request_slots = [ + s + for b in request_blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_blocks = r.blocks + request_slots = r.slots[ + prefix_len: #: orig_input_length + max_new_tokens + speculative_length + ] + + block_tables.append(request_blocks) + + slots.extend(request_slots) + prefix_lens.append(prefix_len) + num_blocks += len(request_blocks) + start_slots.append(cumulative_slot_tokens) + + request_slot_indices = torch.arange( + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) + slot_indices.append(request_slot_indices) + + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, + dtype=torch.int64, + ) + prefill_cache_indices.append(request_prefill_cache_indices) + + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs + + if r.prefill_logprobs: + prefill_head_indices.append(request_position_ids + cumulative_length) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], dtype=torch.int32 + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length + cumulative_slot_tokens += slot_tokens + max_seqlen = max(max_seqlen, input_length) + max_blocks = max(max_blocks, len(request_blocks)) + max_length = max( + max_length, input_length + max_new_tokens + speculative_length + ) + + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, dtype, device, tokenizer + ) + start_slots = torch.tensor(start_slots, dtype=torch.int64) + + # Padded all_input_ids_tensor + all_input_ids_tensor = np.zeros( + (len(all_input_ids), max_length), dtype=np.int64 + ) + for i, input_ids in enumerate(all_input_ids): + all_input_ids_tensor[i, : len(input_ids)] = input_ids + + # Create tensors on device + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) + + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + input_ids = all_input_ids[0] + position_ids = position_ids[0] + slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] + + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + position_ids = position_ids.to(device) + slot_indices = slot_indices.to(device) + prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + input_lengths_tensor = torch.tensor( + input_lengths, dtype=torch.int32, device=device + ) + + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + + slots = torch.tensor(slots, dtype=torch.int64, device=device) + + block_tables_tensor = torch.zeros( + (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + ) + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) + block_tables_tensor = block_tables_tensor.to(device) + prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + prefill_cache_indices=prefill_cache_indices, + start_slots=start_slots, + slot_indices=slot_indices, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + max_seqlen=max_seqlen, + prefill_head_indices=prefill_head_indices, + prefill_next_token_indices=prefill_next_token_indices, + prefill_cu_outlens=prefill_cu_outlens, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, + next_token_chooser=next_token_chooser, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + num_blocks=num_blocks, + max_blocks=max_blocks, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), + speculative_ids=None, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": + assert len(pb.requests) > 0 + batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) + return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + # We assume that if len(requests) == len(self) then the requests are the same + if len(request_ids) == len(self): + return self + + device = self.input_ids.device + + # New values after filtering + requests_idx_mapping = {} + + # Used to index into tensors + indices = [] + + # slots to keep after filtering + slot_filtering_indices = torch.zeros( + self.slots.shape[0], dtype=torch.bool, device=device + ) + + # Create on CPU to only move to GPU once instead of at every copy + slot_indices = torch.empty(len(request_ids), dtype=torch.int64) + max_seqlen = 0 + + requests = [] + start_slots = [] + block_tables = [] + all_input_ids = [] + prefix_ids = [] + + input_lengths = [] + prefix_lens = [] + prefix_offsets = [] + read_offsets = [] + + stopping_criterias = [] + top_n_tokens = [] + adapter_set = set() + + num_blocks = 0 + max_blocks = 0 + # Cumulative length + cumulative_max_length = 0 + + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + requests_idx_mapping[request_id] = i + + requests.append(self.requests[idx]) + + # Get length + request_input_length = self.input_lengths[idx] + prefix_len = self.prefix_lens[idx] + max_seqlen = max(max_seqlen, request_input_length) + + all_input_ids.append(self.all_input_ids[idx]) + prefix_ids.append(self.prefix_ids[idx]) + + input_lengths.append(request_input_length) + prefix_lens.append(prefix_len) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + + top_n_tokens.append(self.top_n_tokens[idx]) + + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) + adapter_set.add(adapter_index) + + remaining_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + + request_block_table = self.block_tables[idx] + num_blocks += len(request_block_table) + block_tables.append(request_block_table) + start_slots.append(cumulative_max_length) + + # Copy to tensor (CPU) + slot_indices[i] = cumulative_max_length + request_input_length - 1 + + # Set slice + slot_filtering_indices[ + self.start_slots[idx] : self.start_slots[idx] + + request_input_length + + remaining_tokens + - 1 + ] = True + + cumulative_max_length += request_input_length + remaining_tokens - 1 + + max_blocks = max(max_blocks, len(request_block_table)) + + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] + all_input_ids_tensor = self.all_input_ids_tensor[indices] + block_tables_tensor = self.block_tables_tensor[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + slots = self.slots[slot_filtering_indices] + prefix_lens_tensor = self.prefix_lens_tensor[indices] + next_token_chooser = self.next_token_chooser.filter(indices) + top_n_tokens_tensor = self.top_n_tokens_tensor[indices] + speculative_ids = ( + self.speculative_ids[indices] if self.speculative_ids is not None else None + ) + + start_slots = torch.tensor(start_slots, dtype=torch.int64) + + # Move to GPU now that we have the whole tensor + slot_indices = slot_indices.to(device) + + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + + return type(self)( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + start_slots=start_slots, + slot_indices=slot_indices, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, + max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, + next_token_chooser=next_token_chooser, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + num_blocks=num_blocks, + max_blocks=max_blocks, + speculative_ids=speculative_ids, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), + ) + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + # Batch attributes + requests = [] + requests_idx_mapping = {} + + num_blocks = 0 + total_batch_size = 0 + total_slots = 0 + max_blocks = 0 + max_length = 0 + max_seqlen = 0 + for b in batches: + total_batch_size += len(b) + total_slots += len(b.slots) + num_blocks += b.num_blocks + speculative_length = ( + b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 + ) + max_blocks = max(max_blocks, b.max_blocks) + max_seqlen = max(max_seqlen, b.max_seqlen) + max_length = max( + max_length, + max( + input_length + + stopping_criteria.max_new_tokens + + speculative_length + - stopping_criteria.current_tokens + for input_length, stopping_criteria in zip( + b.input_lengths, b.stopping_criterias + ) + ), + ) + + input_ids = batches[0].input_ids.new_empty(total_batch_size) + position_ids = batches[0].position_ids.new_empty(total_batch_size) + slots = batches[0].slots.new_empty(total_slots) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + block_tables_tensor = batches[0].block_tables_tensor.new_zeros( + (total_batch_size, max_blocks) + ) + prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) + all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( + (total_batch_size, max_length) + ) + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_set = set() + adapter_segment_builder = SegmentConcatBuilder() + + start_slots = [] + block_tables = [] + prefix_lens = [] + all_input_ids = [] + prefix_ids = [] + + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + + next_token_chooser_parameters = [] + fsm_grammar_states = [] + stopping_criterias = [] + top_n_tokens = [] + + # Cumulative length + cumulative_batch_size = 0 + cumulative_slots = 0 + cumulative_adapter_indices_size = 0 + + for i, batch in enumerate(batches): + requests.extend(batch.requests) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + cumulative_batch_size + + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + + # Copy tensors (GPU) + input_ids[start_index:end_index] = batch.input_ids + position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + slots[slots_start_index:slots_end_index] = batch.slots + + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + ) + + all_input_ids_tensor[ + start_index:end_index, : batch.all_input_ids_tensor.shape[1] + ] = batch.all_input_ids_tensor[:, :max_length] + + block_tables_tensor[ + start_index:end_index, : batch.block_tables_tensor.shape[1] + ] = batch.block_tables_tensor[:, :max_blocks] + + prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + + start_slots.append(batch.start_slots + cumulative_slots) + + block_tables.extend(batch.block_tables) + prefix_lens.extend(batch.prefix_lens) + all_input_ids.extend(batch.all_input_ids) + prefix_ids.extend(batch.prefix_ids) + + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) + fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) + stopping_criterias.extend(batch.stopping_criterias) + + top_n_tokens.extend(batch.top_n_tokens) + + # Update + cumulative_batch_size += len(batch) + cumulative_slots += len(batch.slots) + + start_slots = torch.concat(start_slots) + + # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, + dtype=batches[0].next_token_chooser.dtype, + device=batches[0].next_token_chooser.device, + tokenizer=batches[0].next_token_chooser.tokenizer, + fsm_grammar_states=fsm_grammar_states, + ) + + speculative_ids = ( + torch.cat([b.speculative_ids for b in batches], dim=0) + if batches[0].speculative_ids is not None + else None + ) + + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + start_slots=start_slots, + slot_indices=slot_indices, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + slots=slots, + max_seqlen=max_seqlen, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, + next_token_chooser=next_token_chooser, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + num_blocks=num_blocks, + max_blocks=max_blocks, + speculative_ids=speculative_ids, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), + ) + + def __len__(self): + return len(self.requests) + + +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + + +class FlashCausalLM(Model): + def __init__( + self, + model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], + tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + config_class: PreTrainedTokenizerBase = AutoConfig, + default_dtype=torch.float16, + aliases=None, + # Used for Santacoder override of config + num_kv_heads: Optional[int] = None, + # Deepseek V2 uses different QK and V dims. + head_size: Optional[int] = None, + skip_special_tokens: bool = True, + ): + self.quantize = quantize + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + init_cpu_threads_env(rank_id=rank, world_size=world_size) + else: + raise NotImplementedError(f"{model_class} is only available on GPU") + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + try: + generation_config = GenerationConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + if isinstance(generation_config.eos_token_id, (list, set)): + # TODO Huge hack + tokenizer._eos_token_ids = set(generation_config.eos_token_id) + except Exception: + pass + + config = config_class.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + + torch.distributed.barrier(group=self.process_group) + + weights_loader = get_loader(quantize, model_id, revision) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, + ) + + prefix = "" + model = model_class(prefix, config, weights) + torch.distributed.barrier(group=self.process_group) + + # VLM models define the config we care about in their text_config + text_config = getattr(config, "text_config", None) + if text_config is not None: + config = text_config + + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + + self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads // self.process_group.size() + # Validation is done in the model itself + if num_kv_heads is None: + num_kv_heads = getattr(config, "num_key_value_heads", None) + # GPT-2 workaround + if num_kv_heads is None: + num_kv_heads = getattr(config, "n_head", None) + if num_kv_heads is None: + raise ValueError("Cannot get the number of key/value heads") + self.num_kv_heads = ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 + + if head_size is None: + # Some models use GQA and different sizes for o_proj + # and q_proj, that allows for that. + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = config.hidden_size // config.num_attention_heads + else: + self.head_size = head_size + + self.cuda_graphs = {} + self.kv_cache = [] + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_prefill_state, + create_decode_state, + create_prefill_with_paged_kv_state, + ) + + self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) + + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + sliding_window=config.sliding_window, + ) + + @property + def batch_type(self) -> Type[FlashCausalLMBatch]: + return FlashCausalLMBatch + + def max_past(self) -> int: + return getattr(self.model, "max_past", None) + + def init_kv_cache( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.kv_cache = [] + empty_cache() + + element_size = torch.tensor([], dtype=dtype).element_size() + if SYSTEM == "ipex" and device.type == "xpu": + x = 1 + else: + x = BLOCK_SIZE // element_size + + if ATTENTION in {"flashdecoding", "flashinfer"}: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + elif SYSTEM == "ipex" and device == torch.device("cpu"): + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + ( + torch.zeros( + (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), + dtype=dtype, + device=device, + ), + torch.zeros( + (num_blocks, num_heads, head_size, BLOCK_SIZE), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths = [max_s] * bs + prefix_lengths = [0] * bs + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lengths, + ) + from text_generation_server.layers.attention.flashinfer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables, + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + else: + state = None + + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "position_ids": position_ids, + "kv_cache": self.kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths_tensor, + "prefix_lengths": prefix_lengths_tensor, + "state": state, + "graph": graph, + } + + torch.cuda.synchronize() + # Run once outside to warmup + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths_tensor=input_lengths_tensor, + state=state, + prefix_lens_tensor=prefix_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + del seqlen + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + torch.cuda.synchronize() + + def warmup(self, batch: FlashCausalLMBatch): + # The warmup batch is the biggest batch we could ever receive + empty_cache() + + try: + self.init_kv_cache( + batch.num_blocks, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + max_bt = batch.max_blocks + max_s = max_bt * BLOCK_SIZE + + if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): + torch.cuda.tunable.tuning_enable(False) + _, batch, _ = self.generate_token(batch) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"You need to decrease `--max-batch-prefill-tokens`" + ) from e + + synchronize(self.device) + + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the free memory + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + + free_memory = get_free_memory(self.device, MEMORY_FRACTION) + batch_num_blocks = batch.num_blocks if batch is not None else 0 + + num_blocks = ( + # Leave 5% for some wiggle room + int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) + # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. + + batch_num_blocks + ) + + del batch + + self.init_kv_cache( + num_blocks, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + + if SYSTEM == "rocm": + if ( + os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None + or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" + ): + torch.cuda.tunable.enable() + + if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0": + torch.cuda.tunable.tuning_enable(True) + + if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None: + tuning_sequences = [ + int(val) + for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",") + ] + elif CUDA_GRAPHS is not None: + tuning_sequences = CUDA_GRAPHS + else: + tuning_sequences = [1, 2, 3, 4, 5, 6, 7] + + tunableop_filepath = os.path.join( + HUGGINGFACE_HUB_CACHE, + f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", + ) + + log_master( + logger.info, + f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", + ) + + torch.cuda.tunable.set_filename( + tunableop_filepath, insert_device_ordinal=False + ) + + if os.path.isfile(tunableop_filepath): + log_master( + logger.info, + f"The file {tunableop_filepath} already exists and will be reused.", + ) + torch.cuda.tunable.read_file(tunableop_filepath) + + os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) + + for seqlen in tuning_sequences: + log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") + self.tunableop_warmup(seqlen) + torch.cuda.tunable.write_file(tunableop_filepath) + if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": + torch.cuda.tunable.tuning_enable(False) + else: + log_master( + logger.info, + "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", + ) + + if CUDA_GRAPHS: + try: + log_master( + logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" + ) + # Warmup cuda graphs + for bs in CUDA_GRAPHS: + if self.speculate is None or self.speculate + 1 <= bs: + self.cuda_graph_warmup(bs, max_s, max_bt) + except torch.cuda.OutOfMemoryError: + logger.exception("Decode cuda graph warmup failed") + else: + log_master( + logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." + ) + + return int(num_blocks * BLOCK_SIZE) + + def tunableop_warmup(self, seqlen: int): + input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) + + # Dummy value, some models (starcoder2) don't accept `None`. + input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.tensor( + [0, seqlen], device=self.device, dtype=torch.int32 + ) + max_s = seqlen + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=1, + max_k=seqlen, + ) + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=self.kv_cache, + block_tables=None, + seqlen=seqlen, + slots=slots, + max_s=max_s, + lm_head_indices=None, + prefill_cache_indices=None, + ) + + def forward( + self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) + + bs = input_ids.shape[0] + sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) + if sorted_padded_bs: + # Get associated cuda graph + cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + else: + cuda_graph = None + + if cu_seqlen_prefill is not None or cuda_graph is None: + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths_tensor=input_lengths, + prefix_lens_tensor=prefix_lens_tensor, + ): + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, speculative_logits + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + # assert block_tables.shape[0] >= slots.shape[0] + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. + cuda_graph["slots"].fill_(0) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["prefix_lengths"].zero_() + cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor + + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + prefix_lens_tensor=cuda_graph["prefix_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: FlashCausalLMBatch + ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + prefill = batch.cu_seqlen_prefill is not None + prefill_logprobs = batch.prefill_next_token_indices is not None + + # Update adapter indices for speculative tokens (if present) + adapter_meta = batch.adapter_meta + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = ( + adapter_meta.adapter_indices.unsqueeze(-1) + .expand(B, new_length) + .reshape(-1) + ) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, + ) + + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, + self.layer_to_adapter_weights, + prefill, + batch.prefill_head_indices, + ) + + out, speculative_logits = self.forward(batch, adapter_data) + + if prefill: + next_token_logits = ( + out[batch.prefill_next_token_indices] if prefill_logprobs else out + ) + if speculative_logits is not None: + speculative_logits = ( + speculative_logits[batch.prefill_next_token_indices] + if prefill_logprobs + else speculative_logits + ) + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( + len(batch) + ) + + else: + next_token_logits = out + next_adapter_indices = batch.adapter_meta.adapter_indices + + speculate = get_speculate() + ( + next_input_ids, + next_token_logprobs, + logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_seqlen], + next_token_logits, + speculate, + batch.speculative_ids, + speculative_logits, + ) + + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids + ) + + if prefill: + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) + + next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] + # We do not need cu_seqlen_prefill anymore + batch.cu_seqlen_prefill = None + else: + prefill_logprobs = None + next_position_ids = batch.position_ids + + # Cumulative length + cumulative_length = 0 + + # Results + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids) + + # We do two for loops as the first one can run completely asynchronously from the GPU while for the second + # one, we need to first do a GPU <-> CPU sync + # It is faster if we delay this sync for the maximum amount of time + + # For each member of the batch + index = 0 + for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): + # Indexing metadata + start_index = cumulative_length + end_index = cumulative_length + input_length + + if prefill: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + out_length = out_end_index - out_start_index + + # Initialize position_ids + # In decode, we do not need this as we can just increment position ids + next_position_ids[i] = batch.position_ids[end_index - 1] + + # Initialize adapter indices + # In decode, we only have one token per row in the batch, so grab last index + next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ + end_index - 1 + ] + + # Used to gather prefill logprobs + # Copy batch.input_ids to prefill_token_indices + if prefill_logprobs: + if len(batch) > 1: + prefill_tokens_indices[out_start_index : out_end_index - 1] = ( + batch.input_ids[start_index + 1 : start_index + out_length] + ) + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = batch.input_ids[ + start_index + 1 : start_index + out_length + ] + + for j in range(n_accepted_ids): + batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] + index += 1 + + cumulative_length += input_length + + # Update values + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.input_lengths_tensor += accepted_ids + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if prefill: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) + + if prefill and prefill_logprobs: + # Get prefill logprobs + prefill_logprobs_tensor = torch.log_softmax(out, -1) + prefill_logprobs = torch.gather( + prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) + ) + # GPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + # GPU <-> CPU sync + next_token_logprobs = next_token_logprobs.tolist() + next_token_ids = next_input_ids.tolist() + accepted_ids = accepted_ids.tolist() + start_decode = time.time_ns() + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + batch.stopping_criterias, + batch.all_input_ids, + batch.prefix_ids, + batch.next_token_chooser.do_sample, + batch.next_token_chooser.seeds, + batch.top_n_tokens, + accepted_ids, + batch_top_token_ids, + batch_top_token_logprobs, + ) + + # For each member of the batch + index = 0 + for i, ( + request, + input_length, + prefix_offset, + read_offset, + stopping_criteria, + all_input_ids, + prefix_ids, + do_sample, + seed, + top_n_tokens, + n_accepted_ids, + top_token_ids, + top_token_logprobs, + ) in enumerate(iterator): + # Append next token to all tokens + next_token_texts = [] + left = 0 + + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") + + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) + + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + index += n_accepted_ids + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # Prefill + if prefill and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + # Remove generated token to only have prefill and add nan for first prompt token + request_prefill_logprobs = ( + [float("nan")] * (len(prefix_ids) + 1) + ) + prefill_logprobs[out_start_index : out_end_index - 1] + prefill_token_ids = all_input_ids[:-1] + prefill_texts = self.tokenizer.batch_decode( + prefix_ids + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + + prefill_tokens = Tokens( + prefix_ids + prefill_token_ids, + request_prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single(i, next_token_id) + ) + + # Update values + batch.input_lengths[i] = input_length + n_accepted_ids + if batch.input_lengths[i] > batch.max_seqlen: + batch.max_seqlen = batch.input_lengths[i] + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.all_input_ids[i] = all_input_ids + + if stopped: + # No need to return a batch if we know that all requests stopped + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) + + def _forward_context( + self, + *, + block_tables: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + input_lengths_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, + state: Optional[Any] = None, + ) -> ContextManager: + if ATTENTION != "flashinfer": + return nullcontext() + + from text_generation_server.layers.attention.flashinfer import ( + use_decode_state, + use_prefill_with_paged_kv_state, + ) + + # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) + + if cu_seqlen_prefill is not None: + return use_prefill_with_paged_kv_state( + state=( + state if state is not None else self.prefill_with_paged_kv_state + ), + # block_tables=block_tables_to_ragged( + # block_tables=block_tables, + # input_lengths=input_lengths, + # prefix_lens=prefix_lens, + # ), + block_tables=block_tables, + cu_seqlens=cu_seqlen_prefill, + input_lengths=input_lengths_tensor + prefix_lens_tensor, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + dtype=self.dtype, + window_left=self.sliding_window, + ) + else: + assert input_lengths_tensor is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths_tensor + prefix_lens_tensor, + block_tables=block_tables, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + dtype=self.dtype, + window_left=self.sliding_window, + ) + + +def block_tables_to_ragged( + *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(prefix_lens) + + total_len = sum(input_lengths) + sum(prefix_lens) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + offset = 0 + for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): + seq_len = prefix_len + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged diff --git a/backends/gaudi/server/text_generation_server/models/galactica.py b/backends/gaudi/server/text_generation_server/models/galactica.py new file mode 100644 index 000000000..7c4e462c7 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/galactica.py @@ -0,0 +1,156 @@ +import re +import torch +import torch.distributed + + +from transformers import ( + PreTrainedTokenizerBase, +) +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + NextTokenChooser, + StoppingCriteria, +) +from text_generation_server.utils.chunks import concat_text_chunks + +# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py + +# we split individual characters inside special tokens like [START_DNA] +CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])") + +# token added to implement a custom sequence tokenization. This token is added at +# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance +# that they do not occur in the corpus. The digits are escaped so that the token does not appear +# literally in the source code in case we ever include it in the training data. +SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E" + + +def _insert_split_marker(m: re.Match): + """ + Applies split marker based on a regex match of special tokens such as + [START_DNA]. + Parameters + ---------- + n : str + Input text to split + Returns + ---------- + str - the text with the split token added + """ + start_token, _, sequence, end_token = m.groups() + sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL) + return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}" + + +def escape_custom_split_sequence(text): + """ + Applies custom splitting to the text for GALILEO's tokenization + Parameters + ---------- + text : str + Input text to split + Returns + ---------- + str - the text with the split token added + """ + return CUSTOM_SEQ_RE.sub(_insert_split_marker, text) + + +# END CREDIT + + +class GalacticaCausalLMBatch(CausalLMBatch): + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "GalacticaCausalLMBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + prefix_offsets = [] + top_n_tokens = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + # Add escape_custom_split_sequence to the CausalLMBatch logic + inputs.append( + escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks)) + ) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) + + tokenized_inputs = tokenizer( + inputs, + return_tensors="pt", + padding=True, + return_token_type_ids=False, + truncation=True, + max_length=max_truncation, + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append(0) + read_offsets.append(input_len) + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + + input_ids = tokenized_inputs["input_ids"] + # Allocate maximum attention_mask + attention_mask = input_ids.new_zeros( + (pb.size, max_input_length + padding_right_offset) + ) + # Copy tokenizer attention_mask into fully allocated attention_mask + attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] + + position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 + position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) + all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + + max_tokens = len(inputs) * max_input_length + max_decode_tokens + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py new file mode 100644 index 000000000..a48ea6de5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -0,0 +1,73 @@ +import torch +import os +from typing import Dict, Optional +from loguru import logger +from text_generation_server.utils.log import log_master + +ATTENTION = os.getenv("ATTENTION", "default") +# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" +PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in { + "1", + "true", +} +PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"} +log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") +_expected = {"paged", "flashdecoding", "flashinfer", "default"} +assert ( + ATTENTION in _expected +), f"Attention is not valid {ATTENTION}, expected {_expected}" +log_master(logger.info, f"Using Attention = {ATTENTION}") + +if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: + raise RuntimeError("Prefix caching is only supported with flashinfer") + +MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None +TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90")) +assert TGI_WIGGLE_ROOM > 0 +assert TGI_WIGGLE_ROOM < 1 + +# This is overridden by the cli +BLOCK_SIZE: int +if ATTENTION == "flashdecoding": + BLOCK_SIZE = 256 +elif ATTENTION == "flashinfer": + BLOCK_SIZE = 1 +else: + BLOCK_SIZE = 16 + +# This is overridden by the cli +cuda_graphs = os.getenv("CUDA_GRAPHS") +if cuda_graphs is not None: + try: + cuda_graphs = [int(item) for item in cuda_graphs.split(",")] + except Exception as e: + raise RuntimeError( + f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" + ) +else: + cuda_graphs = None + +CUDA_GRAPHS = cuda_graphs + +# This is overridden at model loading. +global MODEL_ID +MODEL_ID = None + + +def set_model_id(model_id: str): + global MODEL_ID + MODEL_ID = model_id + +# NOTE: eventually we should move this into the router and pass back the +# index in all cases. +ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None + + +def set_adapter_to_index(adapter_to_index: Dict[str, int]): + global ADAPTER_TO_INDEX + ADAPTER_TO_INDEX = adapter_to_index + + +def get_adapter_to_index(): + global ADAPTER_TO_INDEX + return ADAPTER_TO_INDEX diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py new file mode 100644 index 000000000..9a7a6fe15 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py @@ -0,0 +1,899 @@ +from io import BytesIO +from PIL import Image +import torch +import time + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizerBase, + ProcessorMixin, +) +from typing import Optional, Tuple, List, Type, Dict + +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +import torch.distributed +from text_generation_server.models.custom_modeling.idefics_modeling import ( + IdeficsForVisionText2Text, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +from text_generation_server.utils.quantization import get_loader + +from text_generation_server.utils.import_utils import SYSTEM + + +tracer = trace.get_tracer(__name__) + + +@dataclass +class IdeficsCausalLMBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] + + # Decoder values + input_ids: torch.Tensor + attention_mask: torch.Tensor + position_ids: torch.Tensor + pixel_values: Optional[torch.Tensor] + image_hidden_states: Optional[torch.Tensor] + image_attention_mask: Optional[torch.Tensor] + past_key_values: Optional[List[Tuple]] + + # All tokens + all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + # Metadata used for padding + max_input_length: int + padding_right_offset: int + + # Maximum number of tokens this batch will grow to + max_tokens: int + + # Past metadata + keys_head_dim_last: bool = True + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "IdeficsCausalLMBatch": + raise NotImplementedError + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor: ProcessorMixin, # Hack + config, + dtype: torch.dtype, + device: torch.device, + ) -> "IdeficsCausalLMBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + inputs.append(r.input_chunks.chunks) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) + + # TODO Check impact on idefics + prompts = [] + for inp in inputs: + # Each input is encoded into a list, where each element of this input list is either a string or a URL + prompt = [] + for chunk in inp: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + prompt.append(chunk.text) + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + prompt.append(image) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + prompts.append(prompt) + + # The processor replaces the call to tokenizer, and + # a/ takes care of fetching images from the URL + # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model + tokenized_inputs = processor( + prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_truncation, + # TODO Check impact on idefics + # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append( + input_len - 5 + ) # To decode without potential fallbacks errors + read_offsets.append( + input_len + ) # To decode without potential fallbacks errors + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + + input_ids = tokenized_inputs["input_ids"] + pixel_values = tokenized_inputs.get("pixel_values", None) + image_hidden_states = None + # Allocate maximum attention_mask + attention_mask = input_ids.new_zeros( + (pb.size, max_input_length + padding_right_offset) + ) + # Copy tokenizer attention_mask into fully allocated attention_mask + attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] + # Do the same for image_attention_mask + if pixel_values is None: + image_attention_mask = None + else: + image_attention_mask = input_ids.new_zeros( + ( + pb.size, + max_input_length + padding_right_offset, + pixel_values.size(1), + ) + ) + image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ + "image_attention_mask" + ] + + position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 + position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) + all_input_ids = tokenized_inputs["input_ids"].T.split( + 1, dim=1 + ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list + + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=None, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: + # It deletes requests from the batch. For instance when client lost connection + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + max_input_length = 0 + + next_token_choosers = [] + stopping_criterias = [] + + total_remaining_decode_tokens = 0 + new_padding_right_offset = 0 + + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + all_input_ids.append(self.all_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + input_ids = self.input_ids[keep_indices] + position_ids = self.position_ids[keep_indices] + self.attention_mask = self.attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + ] + # Do the same for pixel_values and image_attention_mask + pixel_values = self.pixel_values[keep_indices] + self.image_attention_mask = self.image_attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.image_attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + :, + ] + if self.image_hidden_states is None: + image_hidden_states = None + else: + image_hidden_states = self.image_hidden_states[keep_indices] + + # Ensure that past_key_values tensors can be updated in-place + if type(self.past_key_values[0]) is tuple: + self.past_key_values = [list(layer) for layer in self.past_key_values] + + # Update tensors in-place to allow incremental garbage collection + past_kv_length = max_input_length - 1 + for layer in self.past_key_values: + past_keys, past_values = layer + if len(past_keys.shape) == 3: + # Force past to be of dim [self_size, num_heads, ...] for easy indexing + past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) + past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) + if self.keys_head_dim_last: + layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] + else: + layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] + del past_keys + layer[1] = past_values[keep_indices, :, -past_kv_length:, :] + del past_values + + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.pixel_values = pixel_values + self.image_hidden_states = image_hidden_states + self.position_ids = position_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + + return self + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate( + cls, batches: List["IdeficsCausalLMBatch"] + ) -> "IdeficsCausalLMBatch": + # It adds new requests to the batch + # Used for padding + total_batch_size = 0 + max_input_length = 0 + max_num_images = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + max_num_images = max(max_num_images, batch.pixel_values.size(1)) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + max_tokens = 0 + + # Batch tensors + input_ids = None + attention_mask = None + position_ids = None + pixel_values = None + image_hidden_states = None + image_attention_mask = None + past_key_values = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # We only concatenate batches that did at least one step + if batch.past_key_values is None: + raise ValueError("only concatenate prefilled batches") + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + # Create padded tensor + if attention_mask is None: + attention_mask = batch.attention_mask.new_zeros( + (total_batch_size, max_input_length + padding_right_offset), + ) + + curr_batch_max_num_images = batch.pixel_values.size(1) + if pixel_values is None: + pixel_values = batch.pixel_values.new_zeros( + (total_batch_size, max_num_images, 3, 224, 224) + ) + pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( + batch.pixel_values + ) + + if image_attention_mask is None: + image_attention_mask = batch.image_attention_mask.new_zeros( + ( + total_batch_size, + max_input_length + padding_right_offset, + max_num_images, + ) + ) + + # We need to slice the attention mask to remove padding from previous steps + # and to remove unused allocated space + left_offset = max_input_length - batch.max_input_length + batch_left_offset = ( + batch.attention_mask.shape[1] + - batch.max_input_length + - batch.padding_right_offset + ) + attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + ] = batch.attention_mask[ + :, + batch_left_offset : -batch.padding_right_offset, + ] + image_attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + :curr_batch_max_num_images, + ] = batch.image_attention_mask[ + :, batch_left_offset : -batch.padding_right_offset, : + ] + + # Create empty tensor + # position_ids is always of shape [batch_size, 1] + if position_ids is None: + position_ids = batch.position_ids.new_empty((total_batch_size, 1)) + position_ids[start_index:end_index] = batch.position_ids + + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] + # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] + # And ensure that we can update tensors in-place + if isinstance(batch.past_key_values[0], tuple): + batch.past_key_values = [ + [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] + for layer in batch.past_key_values + ] + elif len(batch.past_key_values[0][0].shape) == 3: + for layer in batch.past_key_values: + for k, t in enumerate(layer): + layer[k] = t.view(len(batch), -1, *t.shape[-2:]) + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + start_index = end_index + + first_past_kvs = batches[0].past_key_values + _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape + + padded_past_values_shape = ( + total_batch_size, + num_heads, + max_input_length - 1, + head_dim, + ) + + if batches[0].keys_head_dim_last: + padded_past_keys_shape = padded_past_values_shape + else: + # seq_length is last for BLOOM + padded_past_keys_shape = ( + total_batch_size, + num_heads, + head_dim, + max_input_length - 1, + ) + + # Iterate over attention layers + # Concatenate past key values layer by layer to allow incremental garbage collection + for j in range(len(first_past_kvs)): + padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) + start_index = 0 + for batch in batches: + past_keys = batch.past_key_values[j][0] + # Clear reference to the original tensor + batch.past_key_values[j][0] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + if batch.keys_head_dim_last: + padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( + past_keys[:, :, -past_seq_len:, :] + ) + else: + # BLOOM case + padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( + past_keys[:, :, :, -past_seq_len:] + ) + del past_keys + + start_index = end_index + + padded_past_values = first_past_kvs[j][1].new_zeros( + padded_past_values_shape + ) + start_index = 0 + for batch in batches: + past_values = batch.past_key_values[j][1] + # Clear reference to the original tensor + batch.past_key_values[j][1] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past values to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( + past_values[:, :, -past_seq_len:, :] + ) + del past_values + + # Update values + start_index = end_index + + past_key_values.append([padded_past_keys, padded_past_values]) + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=past_key_values, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + ) + + def __len__(self): + return len(self.requests) + + +class IdeficsCausalLM(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.quantize = quantize + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + # 9b seems to work correctly enough in float16, but 80b seems + # to be really saturating for f16. + dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + self.device, self.dtype = device, dtype + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + config.vision_config.quantize = quantize + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) + + model = IdeficsForVisionText2Text(config, weights) + + self.config = config + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @property + def batch_type(self) -> Type[IdeficsCausalLMBatch]: + return IdeficsCausalLMBatch + + def forward( + self, + input_ids, + attention_mask, + position_ids, + pixel_values, + image_hidden_states, + image_attention_mask, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_hidden_states": image_hidden_states, + "image_attention_mask": image_attention_mask, + "past_key_values": past_key_values, + "use_cache": True, + "return_dict": True, + } + if self.has_position_ids: + kwargs["position_ids"] = position_ids + + outputs, speculative_logits = self.model.forward(**kwargs) + return ( + outputs.logits, + speculative_logits, + outputs.past_key_values, + outputs.image_hidden_states, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: IdeficsCausalLMBatch + ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + # slice the attention mask to the correct shape + attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + if batch.image_attention_mask is None: + image_attention_mask = None + else: + if batch.input_ids.size(1) == 1: + # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), + # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension + # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated + # token need to attend to the encoder hidden states (i.e. the vision encoder) + # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic + image_attention_mask = batch.image_attention_mask[ + :, -(batch.padding_right_offset + 1) + ].unsqueeze(1) + else: + image_attention_mask = batch.image_attention_mask[ + :, : -batch.padding_right_offset + ] + + logits, speculative_logits, past, image_hidden_states = self.forward( + input_ids=batch.input_ids, + attention_mask=attention_mask, + position_ids=batch.position_ids, + pixel_values=batch.pixel_values, + image_hidden_states=batch.image_hidden_states, + image_attention_mask=image_attention_mask, + past_key_values=batch.past_key_values, + ) + # Hardcoded remove image tokens + logits[:, 32000:32001] = torch.finfo(logits.dtype).min + + start_decode = time.time_ns() + + # Results + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask[:, -batch.padding_right_offset] = 1 + batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( + batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] + ) + # Decrease right offset + batch.padding_right_offset -= 1 + + # Update position_ids + batch.position_ids = batch.position_ids[:, -1:] + 1 + + # Update past key values + batch.past_key_values = past + batch.image_hidden_states = image_hidden_states + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/mamba.py b/backends/gaudi/server/text_generation_server/models/mamba.py new file mode 100644 index 000000000..f6dcde68a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/mamba.py @@ -0,0 +1,814 @@ +import torch +import torch.distributed +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from typing import Optional +from text_generation_server.models.custom_modeling.mamba_modeling import ( + MambaConfig, +) +from loguru import logger +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL +import time +from text_generation_server.models.custom_modeling.mamba_modeling import ( + MambaModel, + InferenceParams, +) +from text_generation_server.models import Model +from typing import Any, List, Tuple, Type, Dict +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader +from text_generation_server.utils.tokens import batch_top_tokens, Sampling +from dataclasses import dataclass +from text_generation_server.utils import NextTokenChooser, StoppingCriteria + + +def new_inference_params( + n_blocks: int, + batch_size: int, + d_inner: int, + d_conv: int, + d_state: int, + seqlen_offset: int, + dtype: torch.dtype, + device: torch.device, +): + max_seqlen = 0 + conv_states = torch.zeros( + ( + n_blocks, + batch_size, + d_inner, + d_conv, + ), + device=device, + dtype=dtype, + ) + ssm_states = torch.zeros( + ( + n_blocks, + batch_size, + d_inner, + d_state, + ), + device=device, + dtype=dtype, + ) + inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_offset, + conv_states=conv_states, + ssm_states=ssm_states, + ) + return inference_params + + +@dataclass +class MambaBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] + + # Decoder values + input_ids: torch.Tensor + + # All tokens + all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor + + # Metadata used for padding + max_input_length: int + padding_right_offset: int + + # Maximum number of tokens this batch will grow to + max_tokens: int + + # Past metadata + keys_head_dim_last: bool = True + + # Inference params + inference_params: Optional[Dict[str, Any]] = None + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "MambaBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + inputs.append(concat_text_chunks(r.input_chunks.chunks)) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) + + tokenized_inputs = tokenizer( + inputs, + return_tensors="pt", + padding=True, + return_token_type_ids=False, + truncation=True, + max_length=max_truncation, + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append(input_len - 5) + read_offsets.append(input_len) + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + input_ids = tokenized_inputs["input_ids"] + all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + # past_input_ids=None, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + max_input_length = 0 + + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + + total_remaining_decode_tokens = 0 + new_padding_right_offset = 0 + + indices = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + all_input_ids.append(self.all_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + indices.append(idx) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + input_ids = self.input_ids[keep_indices] + + top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens + self.top_n_tokens_tensor = top_n_tokens_tensor + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + + # TODO + # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. + self.inference_params.conv_states = self.inference_params.conv_states[ + :, indices + ] + self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] + return self + + @classmethod + def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": + # Used for padding + total_batch_size = 0 + max_input_length = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + max_tokens = 0 + seqlen_offset = 0 + + (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape + (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape + dtype = batches[0].inference_params.conv_states.dtype + device = batches[0].inference_params.conv_states.device + inference_params = new_inference_params( + n_blocks=n_blocks, + batch_size=total_batch_size, + d_state=d_state, + d_conv=d_conv, + d_inner=d_inner, + seqlen_offset=seqlen_offset, + device=device, + dtype=dtype, + ) + + # Batch tensors + input_ids = None + top_n_tokens_tensor = None + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + inference_params.max_seqlen = max( + inference_params.max_seqlen, batch.inference_params.max_seqlen + ) + assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset" + inference_params.seqlen_offset = max( + inference_params.seqlen_offset, batch.inference_params.seqlen_offset + ) + + inference_params.conv_states[:, start_index:end_index] = ( + batch.inference_params.conv_states + ) + inference_params.ssm_states[:, start_index:end_index] = ( + batch.inference_params.ssm_states + ) + + start_index = end_index + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + inference_params=inference_params, + ) + + def __len__(self): + return len(self.requests) + + +class Mamba(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.quantize = quantize + self.process_group, _rank, world_size = initialize_torch_distributed() + if world_size > 1: + raise RuntimeError("Mamba does not support Tensor Parallelism (TP)") + self.cuda_graphs = {} + if torch.cuda.is_available(): + device = torch.device("cuda") + # Bf16 is important. In f16 accumulations in the matmul are causing + # differences while the server is under load. + # This is detectable by the integration load test + dtype = torch.bfloat16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/gpt-neox-20b", + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + config = MambaConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + tokenizer.bos_token_id = config.bos_token_id + tokenizer.eos_token_id = config.eos_token_id + tokenizer.pad_token = tokenizer.eos_token + + config.quantize = quantize + config.speculator = speculator + torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) + model = MambaModel(config, weights) + torch.distributed.barrier(group=self.process_group) + super(Mamba, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def batch_type(self) -> Type[MambaBatch]: + return MambaBatch + + def warmup(self, batch) -> Optional[int]: + # TODO: implement warmup for Mamba if needed + if CUDA_GRAPHS: + if self.speculate is None or self.speculate == 0: + try: + logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") + # Warmup cuda graphs + for bs in CUDA_GRAPHS: + self.cuda_graph_warmup(bs) + except Exception: + logger.exception("Decode cuda graph warmup failed") + else: + logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") + + return None + + def cuda_graph_warmup(self, batch_size: int): + input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) + n_blocks = len(self.model.blocks) + + d_state = self.model.config.d_state + d_conv = self.model.config.d_conv + # Inner takes the expand multiplication + d_inner = self.model.config.d_inner + + # Important seqlen_offset to go through the update mecanism with the state + seqlen_offset = 1 + inference_params = new_inference_params( + n_blocks=n_blocks, + batch_size=batch_size, + d_state=d_state, + d_conv=d_conv, + d_inner=d_inner, + seqlen_offset=seqlen_offset, + device=self.device, + dtype=self.dtype, + ) + + graph = torch.cuda.CUDAGraph() + + torch.cuda.synchronize() + # Run once outside to warmup + self.model.forward(input_ids=input_ids, inference_params=inference_params) + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + logits, speculative_logits = self.model.forward( + input_ids=input_ids, inference_params=inference_params + ) + torch.cuda.synchronize() + graph_dict = { + "input_ids": input_ids, + "inference_params": inference_params, + "graph": graph, + "logits": logits, + "speculative_logits": speculative_logits, + } + self.cuda_graphs[batch_size] = graph_dict + + def tunableop_warmup(self, batch_size: int, seqlen: int): + input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) + n_blocks = len(self.model.blocks) + + d_state = self.model.config.d_state + d_conv = self.model.config.d_conv + # Inner takes the expand multiplication + d_inner = self.model.config.d_inner + + # Important seqlen_offset to go through the update mecanism with the state + seqlen_offset = 1 + inference_params = new_inference_params( + n_blocks=n_blocks, + batch_size=seqlen, + d_state=d_state, + d_conv=d_conv, + d_inner=d_inner, + seqlen_offset=seqlen_offset, + device=self.device, + dtype=self.dtype, + ) + + self.model.forward(input_ids=input_ids, inference_params=inference_params) + + def forward( + self, input_ids: torch.Tensor, inference_params: Any + ) -> Tuple[torch.Tensor, torch.Tensor]: + bs = input_ids.shape[0] + padded_bs = bs + if bs == 3: + padded_bs = 4 + elif 3 < bs <= 8: + padded_bs = 8 + elif bs > 8: + padded_bs = (bs + 7) // 8 * 8 + + # Try to find an associated cuda graph + cuda_graph = self.cuda_graphs.get(padded_bs, None) + is_prefill = inference_params is None or inference_params.seqlen_offset == 0 + + if is_prefill or cuda_graph is None: + return self.model( + input_ids, + inference_params=inference_params, + ) + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][:bs] = input_ids + cuda_graph["inference_params"].conv_states[ + :, :bs + ] = inference_params.conv_states + cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states + + # Replay the graph + cuda_graph["graph"].replay() + + inference_params.conv_states.copy_( + cuda_graph["inference_params"].conv_states[:, :bs] + ) + inference_params.ssm_states.copy_( + cuda_graph["inference_params"].ssm_states[:, :bs] + ) + # Slice output to the correct shape + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits + + def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: + start = time.time_ns() + input_ids = ( + batch.input_ids + ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids + + batch_size, max_seqlen = input_ids.shape + # Inference params + + if batch.inference_params is None: + # 0 is important here + seqlen_offset = 0 + n_blocks = len(self.model.blocks) + d_state = self.model.config.d_state + d_conv = self.model.config.d_conv + d_inner = self.model.config.d_inner + inference_params = new_inference_params( + n_blocks=n_blocks, + batch_size=batch_size, + d_state=d_state, + d_conv=d_conv, + d_inner=d_inner, + seqlen_offset=seqlen_offset, + device=self.device, + dtype=self.dtype, + ) + batch.inference_params = inference_params + + # Forward pass + logits, speculative_logits = self.forward( + input_ids, inference_params=batch.inference_params + ) + + # batch.inference_params = new_inference_params + # Results + generations: List[Generation] = [] + stopped = True + + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + torch.log_softmax(logits[:, -1], -1), + accepted_ids, + ) + + start_decode = time.time_ns() + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[ + i + ].advance_grammar(next_token_id_squeezed.item()) + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py new file mode 100644 index 000000000..9e19e1715 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -0,0 +1,357 @@ +from io import BytesIO +from PIL import Image +import torch +from typing import Iterable, Optional, Tuple, List, Dict +from text_generation_server.pb.generate_pb2 import Request + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import ( + PreTrainedTokenizerBase, +) +from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.flash_causal_lm import ( + block_tables_to_ragged, +) +from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION +from text_generation_server.layers.attention import Seqlen + + +tracer = trace.get_tracer(__name__) + + +@dataclass +class MllamaCausalLMBatch(VlmCausalLMBatch): + image_indices: List[int] = 42 + aspect_ratio_ids: Optional[torch.Tensor] = None + aspect_ratio_mask: Optional[torch.Tensor] = None + cross_attention_states: Optional[torch.Tensor] = None + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches): + batch = super().concatenate(batches) + batch.pixel_values = None + batch.pixel_attention_mask = None + + offset = 0 + image_indices = [] + attention_states = [] + for b in batches: + if b.cross_attention_states is not None: + attention_states.append(b.cross_attention_states) + image_indices.extend([i + offset for i in b.image_indices]) + offset += len(b.image_indices) + if len(attention_states) > 0: + assert len(image_indices) > 0 + batch.cross_attention_states = torch.cat(attention_states, dim=0) + batch.image_indices = image_indices + else: + batch.cross_attention_states = None + batch.image_indices = [] + return batch + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]): + assert self.image_indices is not None + batch = super().filter(request_ids) + assert self.image_indices is not None + indices = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + + offset = 0 + new_image_indices = [] + prev_i = None + for i in self.image_indices: + if i in indices: + new_image_indices.append(offset) + if i != prev_i: + offset += 1 + prev_i = i + + batch.image_indices = new_image_indices + if len(new_image_indices) > 0: + assert max(new_image_indices) < self.cross_attention_states.shape[0] + assert offset <= self.cross_attention_states.shape[0] + batch.cross_attention_states = self.cross_attention_states[ + new_image_indices + ] + else: + batch.cross_attention_states = None + return batch + + @classmethod + def batch_tokenized_inputs( + cls, requests: Iterable[Request], tokenizer, processor, config + ): + image_inputs = [] + texts = [] + image_indices = [] + batch_tokenized_inputs = [] + + for i, r in enumerate(requests): + # Each input is encoded into a list, where each element of this input list is either a string or a URL + curr_text = "" + curr_image = None + curr_i = None + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + curr_text += chunk.text + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + # TODO unsure about BOS + curr_text += "<|image|>" + image_input = processor.image_processor(image, return_tensors="pt") + curr_image = image_input + curr_i = i + # image_inputs.append(image_input) + # image_indices.append(i) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + texts.append(curr_text) + if curr_image is not None: + image_inputs.append(curr_image) + image_indices.append(curr_i) + + input_ids = tokenizer( + curr_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + batch_tokenized_inputs.append(input_ids) + if image_inputs: + image_input = image_inputs[0] + new_image_inputs = { + "pixel_values": torch.cat( + [img["pixel_values"] for img in image_inputs], dim=0 + ), + } + if "aspect_ratio_ids" in image_input: + new_image_inputs["aspect_ratio_ids"] = torch.cat( + [img["aspect_ratio_ids"] for img in image_inputs], dim=0 + ) + if "aspect_ratio_mask" in image_input: + new_image_inputs["aspect_ratio_mask"] = torch.cat( + [img["aspect_ratio_mask"] for img in image_inputs], dim=0 + ) + image_inputs = new_image_inputs + image_inputs["image_indices"] = image_indices + else: + image_inputs = None + + if image_inputs is not None: + assert len(image_indices) == image_inputs["pixel_values"].shape[0] + + return batch_tokenized_inputs, image_inputs + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor, + config, + dtype: torch.dtype, + device: torch.device, + ) -> "VlmCausalLMBatch": + batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( + pb.requests, tokenizer, processor, config + ) + batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + # XXX: <|image|> token is actually out of bounds and bugs out the logit processors. + batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( + max=config.text_config.vocab_size - 1 + ) + batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) + + if image_inputs is not None: + batch.pixel_values = image_inputs["pixel_values"].to( + device=device, dtype=dtype + ) + batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device) + batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to( + device=device + ) + batch.image_indices = image_inputs["image_indices"] + else: + batch.pixel_values = None + batch.aspect_ratio_ids = None + batch.aspect_ratio_mask = None + batch.image_indices = [] + assert batch.image_indices is not None + return batch + + +class MllamaCausalLM(VlmCausalLM): + def forward( + self, + batch: VlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) + + bs = input_ids.shape[0] + # Try to find an associated cuda graph + bs = input_ids.shape[0] + sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) + if sorted_padded_bs: + # Get associated cuda graph + cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + else: + cuda_graph = None + if ( + cu_seqlen_prefill is not None + or cuda_graph is None + # Only run cuda graphs when there's no images. + or batch.cross_attention_states is not None + ): + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths_tensor=input_lengths, + prefix_lens_tensor=prefix_lens_tensor, + ): + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) + + if batch.pixel_values is not None: + cross_attention_states = self.model.vision_forward( + pixel_values=batch.pixel_values, + aspect_ratio_ids=batch.aspect_ratio_ids, + aspect_ratio_mask=batch.aspect_ratio_mask, + ) + batch.cross_attention_states = cross_attention_states + + cross_attention_states = batch.cross_attention_states + + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + cross_attention_states=cross_attention_states, + adapter_data=adapter_data, + image_indices=batch.image_indices[:], + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + return logits, speculative_logits + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + cuda_graph["slots"].fill_(0) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) + + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py new file mode 100644 index 000000000..4568b8dde --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/model.py @@ -0,0 +1,137 @@ +import inspect +from loguru import logger +import torch + +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional, TypeVar, Type, Dict +from collections import defaultdict +from transformers import PreTrainedTokenizerBase + +from text_generation_server.models.types import Batch, Generation +from text_generation_server.utils.speculate import get_speculate +from text_generation_server.pb.generate_pb2 import InfoResponse +from text_generation_server.adapters.weights import LayerAdapterWeights +import time +BASE_MODEL_ADAPTER_ID = "__base_model__" + + +B = TypeVar("B", bound=Batch) + + +class Model(ABC): + def __init__( + self, + model_id: str, + model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase, + requires_padding: bool, + dtype: torch.dtype, + device: torch.device, + rank: int = 0, + world_size: int = 1, + sliding_window: Optional[int] = None, + speculate: Optional[int] = None, + adapter_id: str = BASE_MODEL_ADAPTER_ID, + ): + self.model_id = model_id + self.model = model.eval() + self.tokenizer = tokenizer + + # all_special_ids is not set correctly if the rust tokenizer is unpacked + # TODO report this to transformers. + other_special_ids = { + id for id, token in tokenizer.added_tokens_decoder.items() if token.special + } + self.all_special_ids = set(tokenizer.all_special_ids) + self.all_special_ids.update(other_special_ids) + self.requires_padding = requires_padding + self.dtype = dtype + self.device = device + self.rank = rank + self.world_size = world_size + self.sliding_window = sliding_window if sliding_window != -1 else None + + self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( + LayerAdapterWeights + ) + self.loaded_adapters = set() + self.static_adapter_id = adapter_id + + if speculate is None: + speculate = get_speculate() + self.speculate = speculate + + self.has_position_ids = ( + inspect.signature(model.forward).parameters.get("position_ids", None) + is not None + ) + + self.check_initialized() + + @property + def info(self) -> InfoResponse: + if self.requires_padding and self.sliding_window is not None: + raise NotImplementedError("sliding_window is not implemented with padding") + + return InfoResponse( + requires_padding=self.requires_padding, + dtype=str(self.dtype), + device_type=self.device.type, + window_size=self.sliding_window, + speculate=self.speculate, + ) + + @property + @abstractmethod + def batch_type(self) -> Type[B]: + raise NotImplementedError + + @abstractmethod + def generate_token( + self, batch: B + ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]: + raise NotImplementedError + + def warmup(self, batch: B) -> Optional[int]: + self.generate_token(batch) + return None + + def decode_token( + self, + all_input_ids: List[int], + prefix_offset: int = 0, + read_offset: int = 0, + skip_special_tokens: bool = False, + ) -> Tuple[str, int, int]: + """Hack to hopefully support generate_stream for the maximum number of tokenizers""" + + # The prefix text is necessary only to defeat cleanup algorithms in the decode + # which decide to add a space or not depending on the surrounding ids. + prefix_text = self.tokenizer.decode( + all_input_ids[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + ) + + new_text = self.tokenizer.decode( + all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens + ) + + if len(new_text) > len(prefix_text) and not new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + new_text = new_text[len(prefix_text) :] + return new_text, read_offset, len(all_input_ids) + else: + return "", prefix_offset, read_offset + + def check_initialized(self): + uninitialized_parameters = [] + for n, p in self.model.named_parameters(): + if p.data.device == torch.device("meta"): + uninitialized_parameters.append(n) + if uninitialized_parameters: + raise RuntimeError( + f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" + ) diff --git a/backends/gaudi/server/text_generation_server/models/pali_gemma.py b/backends/gaudi/server/text_generation_server/models/pali_gemma.py new file mode 100644 index 000000000..fe75570ea --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/pali_gemma.py @@ -0,0 +1,71 @@ +from io import BytesIO +from PIL import Image +import torch +import torch.distributed +from opentelemetry import trace +from typing import Iterable +from text_generation_server.models.vlm_causal_lm import ( + VlmCausalLMBatch, + image_text_replacement, +) + +from text_generation_server.pb.generate_pb2 import Request + +tracer = trace.get_tracer(__name__) + + +class PaliGemmaBatch(VlmCausalLMBatch): + @classmethod + def batch_tokenized_inputs( + cls, requests: Iterable[Request], tokenizer, processor, config + ): + batch_inputs = [] + image_inputs = [] + max_truncation = 0 + for r in requests: + full_text = "" + image_id = 0 + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += "" + chunk.text + "\n" + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + # TODO do_convert_RGB should be on by default ? + image = image.convert("RGB") + image_input = processor.image_processor(image, return_tensors="pt") + full_text += image_text_replacement( + processor, image_input, config, image_id + ) + image_inputs.append(image_input) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + + batch_inputs.append(full_text) + max_truncation = max(max_truncation, r.truncate) + + batch_tokenized_inputs = tokenizer( + batch_inputs, + truncation=True, + max_length=max_truncation, + add_special_tokens=False, + )["input_ids"] + if image_inputs: + image_input = image_inputs[0] + new_image_inputs = { + "pixel_values": torch.cat( + [img["pixel_values"] for img in image_inputs], dim=0 + ), + } + if "pixel_attention_mask" in image_input: + new_image_inputs["pixel_attention_mask"] = torch.cat( + [img["pixel_attention_mask"] for img in image_inputs], dim=0 + ) + if "image_sizes" in image_input: + new_image_inputs["image_sizes"] = torch.cat( + [img["image_sizes"] for img in image_inputs], dim=0 + ) + image_inputs = new_image_inputs + else: + image_inputs = None + return batch_tokenized_inputs, image_inputs diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py new file mode 100644 index 000000000..04d4c28ba --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py @@ -0,0 +1,932 @@ +import torch +import torch.distributed +import time +from dataclasses import dataclass +from opentelemetry import trace +from transformers import ( + AutoTokenizer, + AutoModelForSeq2SeqLM, + PreTrainedTokenizerBase, + AutoConfig, +) +from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader +from text_generation_server.utils.tokens import batch_top_tokens +from text_generation_server.models import Model +from text_generation_server.models.types import ( + GeneratedText, + Batch, + Generation, + Tokens, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling + +tracer = trace.get_tracer(__name__) + + +@dataclass +class Seq2SeqLMBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] + + # Encoder values + input_ids: Optional[torch.Tensor] + attention_mask: torch.Tensor + + # Decoder values + decoder_input_ids: torch.Tensor + decoder_attention_mask: Optional[torch.Tensor] + encoder_last_hidden_state: Optional[torch.Tensor] + + # All tokens + all_decoder_input_ids: List[torch.Tensor] + + # Seq2SeqLM keeps track of both encoder and decoder attention keys and values + past_key_values: Optional[List[Tuple]] + + # Lengths of all generations present in the batch + input_lengths: List[int] + decoder_input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor + + # Metadata used for padding + max_input_length: int + max_decoder_input_length: int + padding_right_offset: int + + # Maximum number of tokens this batch will grow to + max_tokens: int + + def to_pb(self) -> generate_pb2.CachedBatch: + """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf""" + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "Seq2SeqLMBatch": + """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch""" + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + decoder_input_lengths = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + inputs.append(concat_text_chunks(r.input_chunks.chunks)) + requests_idx_mapping[r.id] = i + decoder_input_lengths.append(1) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) + + # Tokenize batch + tokenized_inputs = tokenizer( + inputs, + return_tensors="pt", + padding=True, + return_token_type_ids=False, + truncation=True, + max_length=max_truncation, + ).to(device) + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + + # Decoder sequence only contains the bos_token + decoder_input_ids = ( + torch.tensor(tokenizer.bos_token_id, device=device) + .repeat(len(pb.requests)) + .view(-1, 1) + ) + for _ in pb.requests: + prefix_offsets.append(0) + read_offsets.append(1) + all_decoder_input_ids = decoder_input_ids.view(-1).split(1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=tokenized_inputs["input_ids"], + attention_mask=tokenized_inputs["attention_mask"], + decoder_input_ids=decoder_input_ids, + all_decoder_input_ids=list(all_decoder_input_ids), + decoder_attention_mask=None, + encoder_last_hidden_state=None, + past_key_values=None, + input_lengths=input_lengths.tolist(), + decoder_input_lengths=decoder_input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length.item(), + max_decoder_input_length=1, + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + decoder_input_lengths = [] + prefix_offsets = [] + read_offsets = [] + + all_decoder_input_ids = [] + + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + + max_input_length = 0 + max_decoder_input_length = 0 + padding_right_offset = 0 + + total_remaining_decode_tokens = 0 + + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + + all_decoder_input_ids.append(self.all_decoder_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + + request_decoder_input_length = self.decoder_input_lengths[idx] + decoder_input_lengths.append(request_decoder_input_length) + max_decoder_input_length = max( + max_decoder_input_length, request_decoder_input_length + ) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + padding_right_offset = max(padding_right_offset, remaining_decode_tokens) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + self.decoder_input_ids = self.decoder_input_ids[keep_indices] + self.attention_mask = self.attention_mask[keep_indices, -max_input_length:] + if self.decoder_attention_mask is not None: + self.decoder_attention_mask = self.decoder_attention_mask[ + keep_indices, + -(self.padding_right_offset + max_decoder_input_length) : ( + self.decoder_attention_mask.shape[1] - self.padding_right_offset + ) + + padding_right_offset, + ] + + self.encoder_last_hidden_state = self.encoder_last_hidden_state[ + keep_indices, -max_input_length: + ] + + # Ensure that past_key_values tensors can be updated in-place + if type(self.past_key_values[0]) is tuple: + self.past_key_values = [ + [t for t in layer] for layer in self.past_key_values + ] + + decoder_past_seq_len = max_decoder_input_length - 1 + for layer in self.past_key_values: + layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:] + layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:] + layer[2] = layer[2][keep_indices, :, -max_input_length:] + layer[3] = layer[3][keep_indices, :, -max_input_length:] + + top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] + max_tokens = ( + len(request_ids) * (max_input_length + max_decoder_input_length) + + remaining_decode_tokens + ) + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = None + self.all_decoder_input_ids = all_decoder_input_ids + self.input_lengths = input_lengths + self.decoder_input_lengths = decoder_input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens + self.top_n_tokens_tensor = top_n_tokens_tensor + self.max_input_length = max_input_length + self.max_decoder_input_length = max_decoder_input_length + self.padding_right_offset = padding_right_offset + self.max_tokens = max_tokens + + return self + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": + """Concatenate multiple batches together by padding internal torch tensors""" + + # Used for padding + total_batch_size = 0 + max_input_length = 0 + max_decoder_input_length = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + max_decoder_input_length = max( + max_decoder_input_length, batch.max_decoder_input_length + ) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + all_decoder_input_ids = [] + input_lengths = [] + decoder_input_lengths = [] + prefix_offsets = [] + read_offsets = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + max_tokens = 0 + + # Batch tensors + attention_mask = None + decoder_input_ids = None + decoder_attention_mask = None + encoder_last_hidden_state = None + top_n_tokens_tensor = None + past_key_values = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + + for i, batch in enumerate(batches): + # Extend all list attributes + requests.extend(batch.requests) + all_decoder_input_ids.extend(batch.all_decoder_input_ids) + input_lengths.extend(batch.input_lengths) + decoder_input_lengths.extend(batch.decoder_input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # We only concatenate batches that did at least one step + if batch.encoder_last_hidden_state is None: + raise ValueError("Batch encoder_last_hidden_state cannot be None") + + # Create padded tensor + if attention_mask is None: + attention_mask = batch.attention_mask.new_zeros( + (total_batch_size, max_input_length), + ) + # Copy to correct indices + attention_mask[start_index:end_index, -batch.max_input_length :] = ( + batch.attention_mask[:, -batch.max_input_length :] + ) + + # Create padded tensor + if decoder_input_ids is None: + decoder_input_ids = batch.decoder_input_ids.new_zeros( + (total_batch_size, 1), + ) + # Copy to correct indices + decoder_input_ids[start_index:end_index] = batch.decoder_input_ids + + # Create padded tensor + if decoder_attention_mask is None: + # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here + decoder_attention_mask = batch.attention_mask.new_zeros( + (total_batch_size, max_decoder_input_length + padding_right_offset), + ) + # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated + # this batch. All generations are of length `batch.max_decoder_input_length`. + left_offset = max_decoder_input_length - batch.max_decoder_input_length + if batch.decoder_attention_mask is None: + decoder_attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + ] = 1 + # If it exists, we need to index + else: + batch_left_offset = ( + batch.decoder_attention_mask.shape[1] + - batch.max_decoder_input_length + - batch.padding_right_offset + ) + decoder_attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + ] = batch.decoder_attention_mask[ + :, + batch_left_offset : -batch.padding_right_offset, + ] + + # Create padded tensor + if encoder_last_hidden_state is None: + encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros( + ( + total_batch_size, + max_input_length, + batch.encoder_last_hidden_state.shape[-1], + ), + ) + + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + + # Copy to correct indices + encoder_last_hidden_state[ + start_index:end_index, -batch.max_input_length :, : + ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :] + batch.encoder_last_hidden_state = None + + # Ensure that we can update tensors in-place + if isinstance(batch.past_key_values[0], tuple): + batch.past_key_values = [ + [t for t in layer] for layer in batch.past_key_values + ] + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length + - batch.max_input_length + + max_decoder_input_length + - batch.max_decoder_input_length + ) * len(batch) + + start_index = end_index + + # Determine shapes for new past kv tensors + first_past_kvs = batches[0].past_key_values + _, num_heads, _, head_dim = first_past_kvs[0][0].shape + + padded_dec_t_shape = ( + total_batch_size, + num_heads, + (max_decoder_input_length - 1), + head_dim, + ) + + padded_enc_t_shape = ( + total_batch_size, + num_heads, + max_input_length, + head_dim, + ) + + # Iterate over attention layers + for j in range(len(first_past_kvs)): + past_key_values.append([]) + + # Decoder past + for k in range(0, 2): + # Initialize tensors + padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape) + past_key_values[j].append(padded_past_values) + + start_index = 0 + for batch in batches: + t = batch.past_key_values[j][k] + # Clear reference to the original tensor + batch.past_key_values[j][k] = None + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past keys and values to remove the padding from previous batches + past_seq_len = batch.max_decoder_input_length - 1 + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[ + :, :, -past_seq_len:, : + ] + del t + + start_index = end_index + + # Encoder past + for k in range(2, 4): + # Initialize tensors + padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape) + past_key_values[j].append(padded_past_values) + + start_index = 0 + for batch in batches: + t = batch.past_key_values[j][k] + # Clear reference to the original tensor + batch.past_key_values[j][k] = None + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past keys and values to remove the padding from previous batches + padded_past_values[ + start_index:end_index, :, -batch.max_input_length :, : + ] = t[:, :, -batch.max_input_length :, :] + del t + + start_index = end_index + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=None, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + all_decoder_input_ids=all_decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_last_hidden_state=encoder_last_hidden_state, + past_key_values=past_key_values, + input_lengths=input_lengths, + decoder_input_lengths=decoder_input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length, + max_decoder_input_length=max_decoder_input_length, + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + def __len__(self): + return len(self.requests) + + +class Seq2SeqLM(Model): + def __init__( + self, + model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + config_class=AutoConfig, + tokenizer_class=AutoTokenizer, + aliases=None, + ): + self.quantize = quantize + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + tokenizer.bos_token_id = config.decoder_start_token_id + + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") + + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + model = AutoModelForSeq2SeqLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map=( + "auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None + ), + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + tokenizer.bos_token_id = model.config.decoder_start_token_id + + self = cls.__new__( + cls, + ) + super().__init__( + self, + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + self.quantize = quantize + return self + + @property + def batch_type(self) -> Type[Seq2SeqLMBatch]: + return Seq2SeqLMBatch + + def forward( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask: Optional, + encoder_last_hidden_state: Optional, + past_key_values: Optional = None, + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + ]: + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_last_hidden_state, + past_key_values=past_key_values, + use_cache=True, + ) + if isinstance(outputs, tuple): + # Our custom models + outputs, speculative_logits = outputs + else: + # Generic transformers models + speculative_logits = None + return ( + outputs.logits, + speculative_logits, + outputs.encoder_last_hidden_state, + outputs.past_key_values, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: Seq2SeqLMBatch + ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]: + start = time.time_ns() + if batch.decoder_attention_mask is not None: + # slice to the correct shape + decoder_attention_mask = batch.decoder_attention_mask[ + :, : -batch.padding_right_offset + ] + else: + decoder_attention_mask = None + + # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]` + # internally... + if batch.encoder_last_hidden_state is not None: + encoder_last_hidden_state = [batch.encoder_last_hidden_state] + else: + encoder_last_hidden_state = None + + logits, speculative_logits, encoder_last_hidden_state, past = self.forward( + batch.input_ids, + batch.attention_mask, + batch.decoder_input_ids, + decoder_attention_mask, + encoder_last_hidden_state, + batch.past_key_values, + ) + + # Speculation is not active for seq2seq + accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + torch.log_softmax(logits[:, -1], -1), + accepted_ids, + ) + + start_decode = time.time_ns() + + # Finished requests + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + batch.decoder_input_lengths, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_decoder_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + decoder_input_length, + logits, + next_token_chooser, + stopping_criteria, + all_decoder_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_decoder_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to decoder tokens + all_decoder_input_ids = torch.cat( + [all_decoder_input_ids, next_token_id.squeeze(1)] + ) + new_decoder_input_length = decoder_input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_decoder_input_ids, prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria(next_token_id, next_token_text) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Slice with decoder_input_length to remove padding + # Decode all tokens + output_text, _, _ = self.decode_token( + all_decoder_input_ids, + prefix_offset=len(all_decoder_input_ids) + - decoder_input_length + - 1, + read_offset=len(all_decoder_input_ids) - decoder_input_length, + skip_special_tokens=True, + ) + + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + prefill_tokens = Tokens( + [self.tokenizer.bos_token_id], + [float("nan")], + [self.tokenizer.bos_token], + [False], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) + batch.decoder_input_ids[i] = next_token_id + batch.all_decoder_input_ids[i] = all_decoder_input_ids + batch.input_lengths[i] = input_length + batch.decoder_input_lengths[i] = new_decoder_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, input_length) + batch.max_decoder_input_length = max( + batch.max_decoder_input_length, new_decoder_input_length + ) + + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + # We don't need input_ids after the prefill forward + batch.input_ids = None + batch.encoder_last_hidden_state = encoder_last_hidden_state + batch.past_key_values = past + # Update decoder_attention_mask as we added a new token to input_ids + if batch.decoder_attention_mask is not None: + batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1 + batch.padding_right_offset -= 1 + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/starcoder.py b/backends/gaudi/server/text_generation_server/models/starcoder.py new file mode 100644 index 000000000..98e7939a6 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/starcoder.py @@ -0,0 +1,47 @@ +from loguru import logger +import torch +from dataclasses import dataclass +import os +from typing import List, Optional, Type + +from text_generation_server.models import CausalLM +from text_generation_server.models.causal_lm import CausalLMBatch + + +@dataclass +class StarCoderCausalLMBatch(CausalLMBatch): + past_key_values: Optional[List[torch.Tensor]] + + def detach_kv_cache(self): + past_keys = [] + past_values = [] + last_dim = int(self.past_key_values[0].size(dim=-1)/2) + for key_value in self.past_key_values: + past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0]) + past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1]) + del self.past_key_values + + return past_keys, past_values + + def attach_kv_cache(self, past_keys, past_values): + self.past_key_values = [ + torch.cat((key, value), dim=-1) for key, value in zip(past_keys, past_values)] + + +class StarCoder(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + ): + + super(StarCoder, self).__init__( + model_id=model_id, + revision=revision, + dtype=dtype, + ) + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return StarCoderCausalLMBatch \ No newline at end of file diff --git a/backends/gaudi/server/text_generation_server/models/types.py b/backends/gaudi/server/text_generation_server/models/types.py new file mode 100644 index 000000000..d4e7cca75 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/types.py @@ -0,0 +1,102 @@ +import torch + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +from transformers import PreTrainedTokenizerBase + +from text_generation_server.pb import generate_pb2 +from text_generation_server.pb.generate_pb2 import FinishReason + + +class Batch(ABC): + @abstractmethod + def to_pb(self) -> generate_pb2.CachedBatch: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "Batch": + raise NotImplementedError + + @abstractmethod + def filter(self, request_ids: List[int]) -> "Batch": + raise NotImplementedError + + @classmethod + @abstractmethod + def concatenate(cls, batches: List["Batch"]) -> "Batch": + raise NotImplementedError + + @abstractmethod + def __len__(self): + raise NotImplementedError + + +@dataclass +class GeneratedText: + text: str + generated_tokens: int + finish_reason: FinishReason + seed: Optional[int] + + def to_pb(self) -> generate_pb2.GeneratedText: + return generate_pb2.GeneratedText( + text=self.text, + generated_tokens=self.generated_tokens, + finish_reason=self.finish_reason, + seed=self.seed, + ) + + +@dataclass +class Tokens: + token_ids: List[int] + logprobs: List[float] + texts: List[str] + is_special: List[bool] + + def to_pb(self) -> generate_pb2.Tokens: + return generate_pb2.Tokens( + ids=self.token_ids, + logprobs=self.logprobs, + texts=self.texts, + is_special=self.is_special, + ) + + def __len__(self): + return len(self.token_ids) + + +@dataclass +class Generation: + request_id: int + prefill_tokens: Optional[Tokens] + tokens: Tokens + generated_text: Optional[GeneratedText] + # Optional for now, since it's not yet supported for every model. + top_tokens: Optional[List[Tokens]] + + def to_pb(self) -> generate_pb2.Generation: + return generate_pb2.Generation( + request_id=self.request_id, + prefill_tokens=( + self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None + ), + tokens=self.tokens.to_pb(), + generated_text=( + self.generated_text.to_pb() if self.generated_text is not None else None + ), + top_tokens=( + [top_tokens.to_pb() for top_tokens in self.top_tokens] + if self.top_tokens is not None + else None + ), + ) diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py new file mode 100644 index 000000000..5c9955c24 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -0,0 +1,1227 @@ +import re +import torch +import os +import time +import math +from PIL import Image +from io import BytesIO +import base64 +import numpy +from opentelemetry import trace +from loguru import logger +from typing import Iterable, Optional, Tuple, List, Type, Dict +import itertools +import tempfile +import copy +from text_generation_server.models import Model +from transformers import PreTrainedTokenizerBase +from transformers.image_processing_utils import select_best_resolution +from text_generation_server.utils.tokens import batch_top_tokens +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import ( + CausalLMBatch, + CausalLMRequest, + remove_kv_cache_from_output, + biggest_single_chunk, +) + +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, +) + +from transformers import AutoProcessor +import text_generation_server.habana_quantization_env as hq_env +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +from text_generation_server.utils import ( + HeterogeneousNextTokenChooser, + StoppingCriteria, + make_tokenizer_optional, + is_tokenizer_transparent, + pad_next_token_chooser_parameters, +) +import habana_frameworks.torch as htorch +from optimum.habana.utils import HabanaProfile +from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES +from optimum.habana.utils import get_hpu_memory_stats +from optimum.habana.checkpoint_utils import get_ds_injection_policy + +from transformers import ( + AutoTokenizer, + AutoModel, + PreTrainedTokenizerBase, + AutoConfig, +) +from optimum.habana.checkpoint_utils import ( + get_repo_root, + model_on_meta, + write_checkpoints_json, +) + +from text_generation_server.utils.speculate import get_speculate +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.utils.debug import dbg_trace + +tracer = trace.get_tracer(__name__) + +IDEFICS2_FAKE_TOKEN = "" +IDEFICS2_IMAGE_TOKEN = "" + + +IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") +BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048)) +MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) +MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072)) +PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) +CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) + +PREFILL_WARMUP_BATCH_SIZE_LIST = [] +PREFILL_WARMUP_SEQLEN_LIST = [] +DECODE_WARMUP_BATCH_SIZE_LIST = [] +def round_up(warmup_list:list, num) : + i = 0 + for i in warmup_list: + if num <= i : + break + return i + +def split(string) -> List[Dict[str, str]]: + parts = [] + cursor = 0 + for pattern in IMAGES.finditer(string): + start = pattern.start() + if start != cursor: + parts.append({"type": "text", "content": string[cursor:start]}) + + parts.append({"type": "image", "content": pattern.group(1)}) + cursor = pattern.end() + + if cursor != len(string): + parts.append({"type": "text", "content": string[cursor:]}) + + return parts + +def image_text_replacement(processor, image_input, config, image_id: int) -> str: + if config.model_type == "idefics2": + image_seq_len = 64 + image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" + if processor.image_processor.do_image_splitting: + image_str *= 5 + return image_str + elif config.model_type == "llava_next": + height, width = image_input["image_sizes"][image_id] + num_features = get_number_of_features(height, width, config) + from loguru import logger + return "" * num_features + + elif config.model_type == "paligemma": + return "" * config.text_config.num_image_tokens + else: + raise RuntimeError(f"Unknown config {config.model_type} for multimodal") + + +def image_text_replacement_fixup(config, text: str) -> str: + if config.model_type == "idefics2": + return text.replace( + f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN + ) + return text + + +def get_unpadded_features( + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, +) -> Tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio: float = original_width / original_height + current_aspect_ratio: float = current_width / current_height + + if aspect_ratio > current_aspect_ratio: + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) + else: + new_width = (original_width * current_height) // original_height + padding = (current_width - new_width) // 2 + current_width = current_width - (2 * padding) + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + +def get_number_of_features(height: int, width: int, config) -> int: + # From config + # Hardcoded for CLIP for now + # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + image_grid_pinpoints = config.image_grid_pinpoints + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + + assert image_size % patch_size == 0 + + npatches = image_size // patch_size + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + [height, width], + image_grid_pinpoints, + image_size, + ) + + unpadded_features, newline_features = get_unpadded_features( + height, width, npatches, num_patch_height, num_patch_width + ) + # The base patch covers the entire image + base_features = npatches**2 + return unpadded_features + newline_features + base_features + + +class VlmCausalLMBatch(CausalLMBatch): + pixel_values: Optional[List[torch.Tensor]] + pixel_attention_mask: Optional[List[torch.Tensor]] + image_sizes: Optional[List[Tuple[int, int]]] + + @classmethod + def from_tokenized( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + batch_tokenized_inputs, + dtype: torch.dtype, + device: torch.device, + is_warmup: bool = False, + ) -> "VlmCausalLMBatch": + + dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') + requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] + + max_input_length = max(r.data.truncate for r in requests) + max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) + # TODO: Add support for sparse batches + top_n_tokens = [r.top_n_tokens for r in pb.requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + + # TODO: by tokenizing all inputs at once we loose information on actual input lengths + # this means that we cannot shift inputs to the left after a long input sequence + # was filtered out + new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) + parameters = [r.parameters for r in pb.requests] + # append the dummy parameters for dummy request + parameters = pad_next_token_chooser_parameters(parameters, new_bs) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + pb=parameters, + dtype=dtype, + device=device, + tokenizer=tokenizer, + quantization_enabled=hq_env.is_quantization_enabled, + ) + tokenized_inputs = batch_tokenized_inputs + input_len = tokenized_inputs["input_ids"].shape[1] + + bucket_size = max_input_length + left_padding = max_input_length - input_len + if is_warmup is False: + if input_len < max_input_length : + rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) + if rounded_seq_len <= max_input_length: + bucket_size = rounded_seq_len - 1 + else: + bucket_size = max_input_length - 1 + left_padding = bucket_size - input_len + + input_ids = tokenized_inputs["input_ids"] + attention_mask = tokenized_inputs["attention_mask"] + # Allocate space for first token + if left_padding > 0: + input_ids = torch.nn.functional.pad( + input_ids, (left_padding, 1), value=tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (left_padding, 1), value=0 + ) + all_input_ids = torch.nn.functional.pad( + input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id + ).T.split(1, dim=1) + + # New input length after left padding + input_len = bucket_size + for r in requests: + r.input_length = input_len + r.prefix_offset = input_len - 5 + r.read_offset = input_len + r.all_input_ids = all_input_ids[r.idx] + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + htorch.core.mark_step() + + return cls( + batch_id=pb.id, + requests=requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_len, + ) + + + @classmethod + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config, is_warmup + ): + # Process images first. We need all of them so that the processor + # can make the image splits the same size. And we need the final + # sizes to insert correct number of image tokens. + images = [] + for r in requests: + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + pass + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + if config.model_type == "llava_next": + images.append(image) + else: + images.append([image]) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + + image_inputs = None + if images: + image_inputs = processor.image_processor(images, return_tensors="pt") + + batch_inputs = [] + max_truncation = 0 + image_id = 0 + for r in requests: + full_text = "" + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += chunk.text + elif chunk_type == "image": + full_text += image_text_replacement( + processor, image_inputs, config, image_id + ) + image_id += 1 + full_text = image_text_replacement_fixup(config, full_text) + + batch_inputs.append(full_text) + max_truncation = max(max_truncation, r.truncate) + + missing_inputs = 0 + dummy_images = None + if is_warmup is False: + new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) + missing_inputs = new_bs - len(requests) + if missing_inputs > 0: + dummy_inputs = [] + if len(batch_inputs) > 0: + dummy_inputs = [batch_inputs[0]] * missing_inputs + + batch_inputs += dummy_inputs + + batch_tokenized_inputs = tokenizer( + batch_inputs, + truncation=True, + max_length=max_truncation, + add_special_tokens=not config.model_type == "paligemma", + return_tensors="pt", + padding="longest", + return_token_type_ids=False, + ) + + if missing_inputs > 0 and image_inputs is not None: + dummy_shape = list(image_inputs['pixel_values'].shape) + dummy_shape[0] = missing_inputs + dummy_images = torch.rand(dummy_shape) + new_image_inputs = { + "pixel_values": torch.cat( + (image_inputs['pixel_values'], dummy_images), dim=0 + ), + } + if "pixel_attention_mask" in image_inputs: + dummy_shape = list(image_inputs['pixel_attention_mask'].shape) + dummy_shape[0] = missing_inputs + dummy_attention = torch.zeros(dummy_shape) + new_image_inputs["pixel_attention_mask"] = torch.cat( + (image_inputs["pixel_attention_mask"], dummy_attention), dim=0 + ) + if "image_sizes" in image_inputs: + dummy_shape = list(list(image_inputs['image_sizes'])[0]) + dummy_shape = missing_inputs*[dummy_shape] + dummy_sizes = torch.IntTensor(dummy_shape) + new_image_inputs["image_sizes"] = torch.cat( + (image_inputs["image_sizes"], dummy_sizes), dim=0 + ) + image_inputs = new_image_inputs + + return batch_tokenized_inputs, image_inputs + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor, + config, + dtype: torch.dtype, + device: torch.device, + is_warmup: bool = False, + ) -> "VlmCausalLMBatch": + batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( + pb.requests, tokenizer, processor, config, is_warmup + ) + batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + if image_inputs is not None: + batch.pixel_values = image_inputs["pixel_values"].to(device=device) + if "pixel_attention_mask" in image_inputs: + batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( + device=device + ) + else: + batch.pixel_attention_mask = None + if "image_sizes" in image_inputs: + batch.image_sizes = image_inputs["image_sizes"].to(device=device) + else: + batch.image_sizes = None + else: + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch": + return cls.recombine(batches, pad_token_id, is_warmup) + + + + @classmethod + def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch": + if not all(b.past_key_values is not None for b in batches): + raise ValueError("KV cache not allocated! Cannot recombine before prefill!") + + total_requests = sum(len(b) for b in batches) + new_bs = total_requests + if is_warmup is False : + new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) + batch_id = batches[0].batch_id + device = batches[0].input_ids.device + + input_lengths = [b.input_length for b in batches] + max_input_length = max(input_lengths) + offsets = [max_input_length - b.input_length for b in batches] + + cur_padding = [b.right_padding for b in batches] + # For prefill there is a space allocated only for first token + # Need to add padding to the max total tokens before first decode + + moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] + dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] + reshape = (batches[dst_batch_idx].batch_size < new_bs) + + # TODO: Add support for changing max seq len, i.e. due to output length bucketing + # FIXME: max_seq_len for non optimized code + if len(batches) > 1: + scenario = 'CONCAT' + elif reshape: + scenario = 'RESHAPE' + elif cur_padding[dst_batch_idx] <= 0: + scenario = 'SHIFT' + offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] + max_input_length = max_input_length + offsets[dst_batch_idx] + else: + # Nothing to do + return batches[0] + + dbg_trace( + scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' + f' reqs:{[len(b) for b in batches]}' + f' offsets:{offsets}' + f' input_lengths:{input_lengths}' + f' cur_padding:{cur_padding}' + f' dst_batch:{dst_batch_idx}') + + grouped_requests = [[req for req in batch.requests] for batch in batches] + flat_requests = list(itertools.chain(*grouped_requests)) + + for i in range(len(batches)): + target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size + batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) + batches[i].realign(target_bs, offsets[i], pad_token_id) + batches[i].split_kv_cache_if_needed(i == dst_batch_idx) + batches[dst_batch_idx].expand_bs(new_bs) + batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) + + top_n_tokens = [r.data.top_n_tokens for r in flat_requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + + parameters = [r.data.parameters for r in flat_requests] + # append the dummy parameters for dummy requests + batch_size = batches[dst_batch_idx].batch_size + parameters = pad_next_token_chooser_parameters(parameters, batch_size) + + # update past grammar states + fsm_grammar_states = [0] * batch_size + for batch in batches: + for i, req in enumerate(batch.requests): + fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + parameters, + batches[dst_batch_idx].next_token_chooser.dtype, + batches[dst_batch_idx].next_token_chooser.device, + batches[dst_batch_idx].next_token_chooser.tokenizer, + fsm_grammar_states, + quantization_enabled=hq_env.is_quantization_enabled, + ) + + input_ids = batches[dst_batch_idx].input_ids + attention_mask = batches[dst_batch_idx].attention_mask + position_ids = batches[dst_batch_idx].position_ids + past_key_values = batches[dst_batch_idx].past_key_values + input_length = max_input_length + + htorch.core.mark_step() + + return cls( + batch_id=batch_id, + requests=flat_requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_length, + ) + +class VlmCausalLM(Model): + def __init__( + self, + model_class, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=VlmCausalLMBatch, + revision, + quantize: Optional[str] = None, + dtype, + trust_remote_code: bool, + **kwargs, + ): + adapt_transformers_to_gaudi() + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + self.prev_bs = 0 + self.quantize = quantize + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + make_tokenizer_optional(tokenizer) + + # Create model + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("hpu") + + if hq_env.is_quantization_enabled: + htorch.core.hpu_set_env() + + if world_size > 1: + model = self.get_deepspeed_model( + model_class, model_id, dtype, revision + ) + model = hq_env.prepare_model_for_quantization(model) + else: + get_repo_root(model_id) + + # Check support for rope scaling + model_kwargs = {} + config = AutoConfig.from_pretrained( + model_id + ) + if hasattr(config, "rope_scaling"): + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + model = model_class.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + trust_remote_code=trust_remote_code, + **model_kwargs + ) + model = hq_env.prepare_model_for_quantization(model) + model = model.eval().to(device) + + self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 + self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + model = remove_kv_cache_from_output(model) + if self.enable_hpu_graph: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) + else: + if LAZY_MODE == 0: + # It is said that "keep_input_mutations" is safe for inference to be done + dbg_trace( + "TORCH COMPILE", f'Torch compiling of model') + model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) + + model = hq_env.setup_quantization(model) + + if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: + raise ValueError(f"Model type {model.config.model_type} is not supported!") + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + if isinstance(model.config.eos_token_id, int): + tokenizer.pad_token_id = model.config.eos_token_id + elif isinstance(model.config.eos_token_id, list): + tokenizer.pad_token_id = model.config.eos_token_id[0] + else: + raise ValueError( + f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" + ) + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + self.kwargs = { + "use_cache": True, + "return_dict": True, + } + + if model.config.model_type in ["llava_next"]: + self.kwargs["attn_softmax_bf16"] = True + self.kwargs["trim_logits"] = True + + if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": + self.kwargs["use_flash_attention"] = True + if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true": + self.kwargs["flash_attention_recompute"] = True + + self.speculate = get_speculate() + super(VlmCausalLM, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + ) + + # Create profiler + ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] + record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" + output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") + self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) + if self.profiling_steps > 0: + self.hb_profiler = HabanaProfile( + wait=self.profiling_wait_steps, + warmup=self.profiling_warmup_steps, + active=self.profiling_steps, + output_dir=output_dir, + record_shapes=record_shapes + ) + self.hb_profiler.start() + else: + self.hb_profiler = None + self.step = 0 + + + @property + def batch_type(self) -> Type[VlmCausalLMBatch]: + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) + + def get_deepspeed_model( + self, + model_class, + model_id: str, + dtype: torch.dtype, + revision: Optional[str] = None + ) -> torch.nn.Module: + import deepspeed + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + + world_size, rank, local_rank = initialize_distributed_hpu() + model_kwargs = { + "revision": revision + } + + # Initialize process(es) for DeepSpeed + deepspeed.init_distributed(dist_backend="hccl") + logger.info( + "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) + ) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + load_to_meta = model_on_meta(config) + + # Check support for rope scaling + if hasattr(config, "rope_scaling"): + config.rope_scaling = self.get_rope_scaling() + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + if load_to_meta: + # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load + with deepspeed.OnDevice(dtype=dtype, device="meta"): + model = model_class.from_config(config, torch_dtype=dtype) + else: + get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) + # TODO: revisit placement on CPU when auto-injection is possible + with deepspeed.OnDevice(dtype=dtype, device="cpu"): + model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) + model = model.eval() + + # Initialize the model + ds_inference_kwargs = {"dtype": dtype} + ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} + ds_inference_kwargs["enable_cuda_graph"] = False + ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config) + + if load_to_meta: + # model loaded to meta is managed differently + checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") + write_checkpoints_json(model_id, local_rank, checkpoints_json) + ds_inference_kwargs["checkpoint"] = checkpoints_json.name + model = deepspeed.init_inference(model, **ds_inference_kwargs) + + return model.module + + def get_rope_scaling(self) -> Optional[Dict]: + rope_scaling = os.getenv("ROPE_SCALING", None) + if rope_scaling is None: + return None + + rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) + return { + 'type': rope_scaling, 'factor': float(rope_factor) + } + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + def decode_token( + self, + all_input_ids: List[int], + prefix_offset: int = 0, + read_offset: int = 0, + ) -> Tuple[str, int, int]: + if is_tokenizer_transparent(self.tokenizer): + new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) + return new_text, read_offset, len(all_input_ids) + else: + return super().decode_token(all_input_ids, prefix_offset, read_offset) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + token_idx, + past_key_values: Optional[List[Tuple]] = None, + pixel_values: Optional[List[torch.Tensor]] = None, + image_sizes: Optional[List[Tuple[int, int]]] = None, + bypass_hpu_graph: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "token_idx": token_idx, + "pixel_values": pixel_values, + "image_sizes": image_sizes, + } + + hpu_kwargs = {} + # Optimum Habana got "lazy_mode" key-val only supported for llama type of models + if self.model.config.model_type == "llama" : + hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 + + if self.has_position_ids: + kwargs["position_ids"] = position_ids + + if bypass_hpu_graph != None: + hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph + + kwargs.update(self.kwargs) + model_inputs = self.model.prepare_inputs_for_generation(**kwargs) + if past_key_values is not None: + return self.model.forward(**model_inputs, **hpu_kwargs) + else: + outputs = self.model.forward(**model_inputs, **hpu_kwargs) + return outputs.logits, outputs.past_key_values + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batches: List[VlmCausalLMBatch], is_warmup: bool = False + ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + # Results + generations: List[Generation] = [] + prev_batches = [] + requests_to_generate = [] + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. Collect next token ids of any previously started generations + for batch_id, batch in enumerate(batches): + if batch.logits is not None: + logits = batch.logits + past = batch.past + prefill = batch.past_key_values is None + if prefill: + # no right padding for prefill + token_idx_scalar = batch.attention_mask.shape[-1] - 1 + token_idx = torch.tensor(token_idx_scalar).to(self.device) + else: + token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding + token_idx = torch.tensor(token_idx_scalar).to(self.device) + + # Select next token + input_length = batch.input_length + if logits.shape[-2] > 1: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate + ) + else: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits.squeeze(-2), self.speculate + ) + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) + + prev_batches.append({ + 'next_token_ids': next_token_ids, + 'next_token_logprobs': next_token_logprobs, + }) + + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append({ + 'req': req, + 'prev_req_idx': req.idx, + 'batch_id': batch_id, + 'seed': batch.next_token_chooser.seeds[req_idx], + 'do_sample': batch.next_token_chooser.do_sample[req_idx], + 'top_n_tokens': batch.top_n_tokens[req_idx], + 'top_token_ids': batch_top_token_ids[req_idx], + 'top_token_logprobs': batch_top_token_logprobs[req_idx], + 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], + }) + + htorch.core.mark_step() + + # Add new token into input_ids + batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask.index_fill_(1, token_idx, 1) + + # Adjust lengths + batch.input_length += 1 + + # Update position_ids + if prefill: + batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 + else: + batch.position_ids += 1 + # Update past key values + if prefill: + batch.past_key_values = past + + htorch.core.mark_step() + + # Stage 2. Prepare new batch for speculative scheduling + if len(batches) > 1: + batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup) + else: + batch = batches[0] + + prefill = batch.past_key_values is None + + # Check if we need to do any bookkeeping first + if not prefill: + batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup) + + scenario = 'PREFILL' if prefill else 'GENERATE' + if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs: + self.model.clear_cache() + self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) + dbg_trace( + scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') + #assert batch.right_padding > 0, 'No more room for next token!' + + # Execute batch + if prefill: + # no right padding for prefill + token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) + batch.logits, batch.past = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + batch.pixel_values, + batch.image_sizes, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + ) + elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): + # Don't schedule next forward if max_new_tokens for all requests equals 1 + # - we've already generated the first and only needed token in the prefill phase + pass + else: + token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) + batch.logits = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + ) + + htorch.core.mark_step() + + start_decode = time.time_ns() + + # Stage 3. Finish and return previous generations + stopped = len(requests_to_generate) > 0 + for prev_batch in prev_batches: + prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() + prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() + htorch.core.mark_step() + + for req_data in requests_to_generate: + req = req_data['req'] + i = req_data['prev_req_idx'] + prev_batch_id = req_data['batch_id'] + assert len(prev_batches) > prev_batch_id + next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] + next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] + + request = req.data + input_length = req.input_length + prefix_offset = req.prefix_offset + read_offset = req.read_offset + do_sample = req_data['do_sample'] + seed = req_data['seed'] + stopping_criteria = req.stopping_criteria + all_input_ids = req.all_input_ids + next_token_id = next_token_ids_cpu[i] + next_token_logprob = next_token_logprobs[i] + top_n_tokens = req_data['top_n_tokens'] + top_token_ids = req_data['top_token_ids'] + top_token_logprobs = req_data['top_token_logprobs'] + grammar_state = req_data['grammar_state'] + + # Append next token to all tokens + all_input_ids[input_length] = next_token_id + new_input_length = input_length + 1 + + # Generated token + if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: + next_token_text = '' + else: + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[0:new_input_length, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + if is_tokenizer_transparent(self.tokenizer): + output_text = None + else: + output_text = self.decode( + all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + next_token_logprobs + prefill_token_ids = all_input_ids[0: new_input_length - 1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id], + [next_token_logprob], + [next_token_text], + [next_token_id in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single_with_past_state( + req.idx, next_token_id, grammar_state + ) + ) + + req.all_input_ids = all_input_ids + req.input_length = new_input_length + req.prefix_offset = prefix_offset + req.read_offset = read_offset + + htorch.core.mark_step() + self.step = self.step + 1 + if self.hb_profiler is not None: + if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: + self.hb_profiler.stop() + else: + self.hb_profiler.step() + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch if not stopped else None, (forward_ns, decode_ns) + + def batch_from_pb(self, batch, is_warmup): + return VlmCausalLMBatch.from_pb_processor( + batch, + self.tokenizer, + self.processor, + self.model.config, + self.dtype, + self.device, + is_warmup + ) + + def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): + batch = copy.deepcopy(request.batch) + for req in batch.requests: + req.truncate = seq_len + + for i in range(len(batch.requests) - batch_size): + batch.requests.pop() + + return self.batch_from_pb(batch, is_warmup) + + def warmup(self, request) -> None: + is_warmup = True + batch = self.batch_from_pb(request.batch, is_warmup) + + try: + # max prefill batch size warmup + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + except: + raise RuntimeError( + f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"You need to decrease `--max-batch-prefill-tokens`" + ) + + global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST + max_input_length = batch.input_ids.shape[1] + max_prefill_batch_size = batch.input_ids.shape[0] + PREFILL_WARMUP_BATCH_SIZE_LIST = [] + batch_size = 1 + while batch_size <= max_prefill_batch_size: + PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) + batch_size = batch_size * 2 + if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size : + PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) + + seq_len = BASE_IMAGE_TOKENS + PREFILL_WARMUP_SEQLEN_LIST = [] + i = 0 + while seq_len <= max_input_length: + PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) + seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i) + i += 1 + if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: + PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) + + #Prefill and decode warmup + DECODE_WARMUP_BATCH_SIZE_LIST = [] + prefill_batch = None + decode_batch = None + try: + for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST : + for seq_len in PREFILL_WARMUP_SEQLEN_LIST : + batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup) + + DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) + + except: + raise RuntimeError( + f"Not enough memory to handle following prefill and decode warmup." + f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" + f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" + f"You need to decrease `--max-batch-prefill-tokens`" + ) + + mem_stats = get_hpu_memory_stats(self.device) + logger.info( + f"\nFollowing prefill and decode warmup successfully.\n" + f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" + f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" + f"Memory stats: {mem_stats} " + ) + + max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) + batch_size = max_prefill_batch_size * 2 + # Decode warmup with bigger batch_size + try: + if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size: + batches = [] + for i in range(int(batch_size/max_prefill_batch_size)) : + batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + batches.append(prefill_batch) + while batch_size <= max_decode_batch_size: + _, decode_batch, _ = self.generate_token(batches, is_warmup) + DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) + batch_size = batch_size * 2 + batches.clear() + + for i in range(int(batch_size/max_prefill_batch_size)) : + batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + batches.append(prefill_batch) + + batches.clear() + if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: + max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2 + batch_size = max_decode_batch_size + for i in range(int(max_decode_batch_size / 2)) : + batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + batches.append(prefill_batch) + _, decode_batch, _ = self.generate_token(batches, is_warmup) + DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) + max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS + MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens + except : + raise RuntimeError( + f"Not enough memory to handle batch_size({batch_size}) decode warmup." + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" + f"max_decode_batch_size is {max_decode_batch_size}" + f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'" + ) + + mem_stats = get_hpu_memory_stats(self.device) + logger.info( + f"\nFollowing decode warmup successfully.\n" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" + f"Memory stats: {mem_stats}" + ) + + return MAX_BATCH_TOTAL_TOKENS \ No newline at end of file diff --git a/backends/gaudi/server/text_generation_server/pb/.gitignore b/backends/gaudi/server/text_generation_server/pb/.gitignore new file mode 100644 index 000000000..5a68d6313 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/pb/.gitignore @@ -0,0 +1,3 @@ +*.py +*.pyi +*.py-e diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py new file mode 100644 index 000000000..159e6af1f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/server.py @@ -0,0 +1,278 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import asyncio +import os +import sys +import torch +import time +import signal + +from grpc import aio +from loguru import logger + +from grpc_reflection.v1alpha import reflection +from pathlib import Path +from typing import List, Optional + +from text_generation_server.cache import Cache +from text_generation_server.interceptor import ExceptionInterceptor +from text_generation_server.models import Model, get_model_with_lora_adapters +from text_generation_server.pb import generate_pb2_grpc, generate_pb2 +from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor +from text_generation_server.models.globals import set_model_id +from text_generation_server.models.globals import set_adapter_to_index +from text_generation_server.utils.adapter import AdapterInfo +from text_generation_server.utils.tokens import make_tokenizer_optional + +try: + from text_generation_server.models.pali_gemma import PaliGemmaBatch + from text_generation_server.models.vlm_causal_lm import ( + VlmCausalLMBatch, + ) + from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch + + VLM_BATCH_TYPES = { + PaliGemmaBatch, + VlmCausalLMBatch, + IdeficsCausalLMBatch, + } +except (ImportError, NotImplementedError): + # These imports can fail on CPU/Non flash. + VLM_BATCH_TYPES = set() +from text_generation_server.utils.version import is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION + + +class SignalHandler: + KEEP_PROCESSING = True + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, signum, frame): + print(f"Exiting gracefully: Signal {signum}") + self.KEEP_PROCESSING = False + + +class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): + def __init__( + self, + model: Model, + cache: Cache, + server_urls: List[str], + ): + self.cache = cache + self.model = model + # Quantize is resolved during model loading + self.quantize = model.quantize + self.server_urls = server_urls + # For some reason, inference_mode does not work well with GLOO which we use on CPU + # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul + # op not optimized issue. Will investigate further. + # if model.device.type == "hpu": + # Force inference mode for the lifetime of TextGenerationService + # self._inference_mode_raii_guard = torch._C._InferenceMode(True) + + + async def Info(self, request, context): + return self.model.info + + async def Health(self, request, context): + if self.model.device.type == "hpu": + torch.zeros((2, 2)).to("hpu") + return generate_pb2.HealthResponse() + + async def ServiceDiscovery(self, request, context): + return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) + + async def ClearCache(self, request, context): + if request.HasField("id"): + self.cache.delete(request.id) + else: + self.cache.clear() + return generate_pb2.ClearCacheResponse() + + async def FilterBatch(self, request, context): + batch = self.cache.pop(request.batch_id) + if batch is None: + raise ValueError(f"Batch ID {request.batch_id} not found in cache.") + filtered_batch = batch.filter(request.request_ids) + self.cache.set(filtered_batch) + + return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + + async def Warmup(self, request, context): + max_supported_total_tokens = self.model.warmup(request) + + # W/A for the skip tokenizer path + # We need to call make_tokenizer_optional after the warmup, + # because router is not aware of that feature + make_tokenizer_optional(self.model.tokenizer) + + return generate_pb2.WarmupResponse( + max_supported_total_tokens=max_supported_total_tokens + ) + + async def Prefill(self, request, context): + start = time.time_ns() + if ( + self.model.batch_type in VLM_BATCH_TYPES + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb_processor( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + + generations, next_batch, timings = self.model.generate_token([batch]) + self.cache.set(next_batch) + + return generate_pb2.PrefillResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + forward_ns=timings[0], + decode_ns=timings[1], + total_ns=time.time_ns() - start, + ) + + async def Decode(self, request, context): + start = time.time_ns() + if len(request.batches) == 0: + raise ValueError("Must provide at least one batch") + + batches = [] + for batch_pb in request.batches: + batch = self.cache.pop(batch_pb.id) + if batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + batches.append(batch) + + if len(batches) == 0: + raise ValueError("All batches are empty") + + generations, next_batch, timings = self.model.generate_token(batches) + self.cache.set(next_batch) + + return generate_pb2.DecodeResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + concat_ns=None, + forward_ns=timings[0], + decode_ns=timings[1], + total_ns=time.time_ns() - start, + ) + + +def serve( + model_id: str, + lora_adapters: Optional[List[AdapterInfo]], + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + speculate: Optional[int], + dtype: Optional[str], + trust_remote_code: bool, + uds_path: Path, + max_input_tokens: int, +): + async def serve_inner( + model_id: str, + lora_adapters: Optional[List[AdapterInfo]], + revision: Optional[str], + sharded: bool = False, + quantize: Optional[str] = None, + speculate: Optional[int] = None, + dtype: Optional[str] = None, + trust_remote_code: bool = False, + ): + if not is_driver_compatible(): + logger.warning(f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures") + + unix_socket_template = "unix://{}-{}" + adapter_to_index = {} + logger.info("Server:server_inner: sharded ={}".format(sharded)) + + if sharded: + rank = int(os.environ["RANK"]) + logger.info("Server:server_inner: rank ={}".format(rank)) + server_urls = [ + unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"])) + ] + local_url = server_urls[int(os.environ["RANK"])] + else: + local_url = unix_socket_template.format(uds_path, 0) + server_urls = [local_url] + + logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url)) + if dtype == "bfloat16" or None: + data_type = torch.bfloat16 + else: + data_type = torch.float + if revision == "None": + revision = None + try: + model = get_model_with_lora_adapters( + model_id, + lora_adapters, + revision, + sharded, + quantize, + speculate, + data_type, + trust_remote_code, + max_input_tokens, + adapter_to_index, + ) + + except Exception: + logger.exception("Error when initializing model") + raise + + set_adapter_to_index(adapter_to_index) + server = aio.server( + interceptors=[ + ExceptionInterceptor(), + UDSOpenTelemetryAioServerInterceptor(), + ], + options=[ + # Set the maximum possible message length: i32::MAX + ("grpc.max_receive_message_length", (1 << 31) - 1) + ], + ) + generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( + TextGenerationService(model, Cache(), server_urls), server + ) + SERVICE_NAMES = ( + generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(SERVICE_NAMES, server) + server.add_insecure_port(local_url) + + await server.start() + + logger.info("Server started at {}".format(local_url)) + signal_handler = SignalHandler() + while signal_handler.KEEP_PROCESSING: + await asyncio.sleep(0.5) + + set_model_id(model_id) + asyncio.run( + serve_inner( + model_id, + lora_adapters, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + ) + ) diff --git a/backends/gaudi/server/text_generation_server/tgi_service.py b/backends/gaudi/server/text_generation_server/tgi_service.py new file mode 100644 index 000000000..f0f131268 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/tgi_service.py @@ -0,0 +1,45 @@ +import os +from pathlib import Path +from loguru import logger +import sys +from text_generation_server import server +import argparse +from typing import List +from text_generation_server.utils.adapter import parse_lora_adapters + + +def main(args): + logger.info("TGIService: starting tgi service .... ") + logger.info( + "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format( + args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path + ) + ) + lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) + server.serve( + model_id=args.model_id, + lora_adapters=lora_adapters, + revision=args.revision, + sharded=args.sharded, + quantize=args.quantize, + speculate=args.speculate, + dtype=args.dtype, + trust_remote_code=args.trust_remote_code, + uds_path=args.uds_path, + max_input_tokens=args.max_input_tokens + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str) + parser.add_argument("--revision", type=str) + parser.add_argument("--sharded", type=bool) + parser.add_argument("--speculate", type=int, default=None) + parser.add_argument("--dtype", type=str) + parser.add_argument("--trust_remote_code", type=bool) + parser.add_argument("--uds_path", type=Path) + parser.add_argument("--quantize", type=str) + parser.add_argument("--max_input_tokens", type=int) + args = parser.parse_args() + main(args) diff --git a/backends/gaudi/server/text_generation_server/tracing.py b/backends/gaudi/server/text_generation_server/tracing.py new file mode 100644 index 000000000..bc7a04ee7 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/tracing.py @@ -0,0 +1,63 @@ +import grpc + +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.instrumentation.grpc._aio_server import ( + OpenTelemetryAioServerInterceptor, +) +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, +) + + +class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor): + def __init__(self): + super().__init__(trace.get_tracer(__name__)) + + def _start_span(self, handler_call_details, context, set_status_on_exception=False): + """ + Rewrite _start_span method to support Unix Domain Socket gRPC contexts + """ + + # standard attributes + attributes = { + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0], + } + + # if we have details about the call, split into service and method + if handler_call_details.method: + service, method = handler_call_details.method.lstrip("/").split("/", 1) + attributes.update( + { + SpanAttributes.RPC_METHOD: method, + SpanAttributes.RPC_SERVICE: service, + } + ) + + # add some attributes from the metadata + metadata = dict(context.invocation_metadata()) + if "user-agent" in metadata: + attributes["rpc.user_agent"] = metadata["user-agent"] + + # We use gRPC over a UNIX socket + attributes.update({SpanAttributes.NET_TRANSPORT: "unix"}) + + return self._tracer.start_as_current_span( + name=handler_call_details.method, + kind=trace.SpanKind.SERVER, + attributes=attributes, + set_status_on_exception=set_status_on_exception, + ) + + +def setup_tracing(otlp_service_name: str, otlp_endpoint: str): + resource = Resource.create(attributes={"service.name": otlp_service_name}) + span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) + span_processor = BatchSpanProcessor(span_exporter) + + trace.set_tracer_provider(TracerProvider(resource=resource)) + trace.get_tracer_provider().add_span_processor(span_processor) diff --git a/backends/gaudi/server/text_generation_server/utils/__init__.py b/backends/gaudi/server/text_generation_server/utils/__init__.py new file mode 100644 index 000000000..565a7c3ca --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/__init__.py @@ -0,0 +1,48 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import text_generation_server.habana_quantization_env +from text_generation_server.utils.convert import convert_file, convert_files +from text_generation_server.utils.dist import initialize_torch_distributed +from text_generation_server.utils.weights import Weights +from text_generation_server.utils.peft import download_and_unload_peft +from text_generation_server.utils.hub import ( + weight_files, + weight_hub_files, + download_weights, + EntryNotFoundError, + LocalEntryNotFoundError, + RevisionNotFoundError, +) +from text_generation_server.utils.tokens import ( + NextTokenChooser, + HeterogeneousNextTokenChooser, + StoppingCriteria, + StopSequenceCriteria, + FinishReason, + Sampling, + Greedy, + make_tokenizer_optional, + is_tokenizer_transparent, + pad_next_token_chooser_parameters, +) + +__all__ = [ + "convert_file", + "convert_files", + "initialize_torch_distributed", + "weight_files", + "weight_hub_files", + "download_weights", + "download_and_unload_peft", + "EntryNotFoundError", + "HeterogeneousNextTokenChooser", + "LocalEntryNotFoundError", + "RevisionNotFoundError", + "Greedy", + "NextTokenChooser", + "Sampling", + "StoppingCriteria", + "StopSequenceCriteria", + "FinishReason", + "Weights", +] diff --git a/backends/gaudi/server/text_generation_server/utils/adapter.py b/backends/gaudi/server/text_generation_server/utils/adapter.py new file mode 100644 index 000000000..2b61f9bb4 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/adapter.py @@ -0,0 +1,320 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/adapter.py +# License: Apache License Version 2.0, January 2004 + +import warnings +import re +from dataclasses import dataclass +from functools import lru_cache +from typing import TYPE_CHECKING, Set, Tuple, Optional, List + +from safetensors.torch import load_file +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer + +from text_generation_server.utils.merges.strategies import merge_adapters + +from text_generation_server.utils import hub +from text_generation_server.adapters.lora import LoraConfig + + +if TYPE_CHECKING: + from text_generation_server.adapters.config import AdapterConfig, ModuleMap + + +BASE_MODEL_ADAPTER_ID = "__base_model__" + + +@dataclass +class AdapterInfo: + id: str + path: Optional[str] + revision: Optional[str] = None + + +@dataclass +class AdapterParameters: + adapter_info: Tuple[AdapterInfo] + weights: Tuple[float] + merge_strategy: NotImplemented + density: float + majority_sign_method: NotImplemented + + +@dataclass +class AdapterSource: + adapter_id: str + model_id: str + revision: str + + +def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]: + if not lora_adapters: + return [] + + adapter_list = [] + for adapter in lora_adapters.split(","): + adapter = adapter.strip() + if adapter.count("=") > 1 or adapter.count("@") > 1: + raise ValueError(f"Invalid LoRA adapter format: {adapter}") + match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter) + + if match: + adapter_id, path, revision = match.groups() + adapter_list.append( + AdapterInfo(id=adapter_id, path=path, revision=revision) + ) + else: + raise ValueError(f"Invalid LoRA adapter format: {adapter}") + return adapter_list + + +def load_and_merge_adapters( + model_id: str, + adapter_parameters: AdapterParameters, + adapter_index: int, + weight_names: Tuple[str], + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + if len(adapter_parameters.adapter_info) == 1: + adapter = next(iter(adapter_parameters.adapter_info)) + return load_module_map( + model_id, + adapter.revision, + adapter.id, + adapter.path, + weight_names, + trust_remote_code, + ) + + adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) + return _load_and_merge( + model_id, + adapter_params, + weight_names, + trust_remote_code, + ) + + +@dataclass +class AdapterParametersContainer: + adapter_parameters: AdapterParameters + adapter_index: int + + def __hash__(self) -> int: + return self.adapter_index + + +@lru_cache(maxsize=32) +def _load_and_merge( + model_id: str, + adapter_params: AdapterParametersContainer, + weight_names: Tuple[str], + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + params = adapter_params.adapter_parameters + + adapters_to_merge = [] + merged_weight_names = set() + tokenizer = None + for adapter in params.adapter_info: + if adapter.id == BASE_MODEL_ADAPTER_ID: + raise ValueError("Base model adapter cannot be merged.") + + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( + load_module_map( + model_id, + adapter.revision, + adapter.id, + adapter.path, + weight_names, + trust_remote_code, + ) + ) + + adapters_to_merge.append((module_map, adapter_config)) + merged_weight_names = merged_weight_names.union(adapter_weight_names) + if tokenizer is None: + tokenizer = adapter_tokenizer + + if len(adapters_to_merge) == 0: + raise ValueError("No adapters to merge.") + + module_map, adapter_config = merge_adapters(adapters_to_merge, params) + return module_map, adapter_config, merged_weight_names, tokenizer + + +def check_architectures( + model_id: str, + adapter_id: str, + adapter_config: "AdapterConfig", + trust_remote_code: bool = False, +): + try: + if not adapter_config.base_model_name_or_path: + # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None) + return + + expected_config = AutoConfig.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + model_config = AutoConfig.from_pretrained( + adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code + ) + except Exception as e: + warnings.warn( + f"Unable to check architecture compatibility for adapter '{adapter_id}' " + f"against model '{model_id}'. Assuming they are compatible. Error: {e}" + ) + return + + if model_config.architectures == expected_config.architectures: + warnings.warn( + f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. " + f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + else: + # TODO(travis): revisit this when we support clasification heads which will not use CausalLM + raise ValueError( + f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " + f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " + f"Use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + + +@lru_cache(maxsize=128) +def load_module_map( + model_id: str, + revision: str, + adapter_id: str, + adapter_path: Optional[str], + weight_names: Tuple[str], + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + adapter_config = LoraConfig.load(adapter_path or adapter_id, None) + + if not adapter_path and adapter_config.base_model_name_or_path != model_id: + check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) + + adapter_filenames = ( + hub._weight_files_from_dir(adapter_path, extension=".safetensors") + if adapter_path + else hub._cached_weight_files( + adapter_id, revision=revision, extension=".safetensors" + ) + ) + + # throw an error if no adapter weights are found + if not adapter_filenames: + raise FileNotFoundError( + f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'." + ) + + try: + adapter_tokenizer = AutoTokenizer.from_pretrained( + adapter_config.config_path, + trust_remote_code=trust_remote_code, + ) + except Exception: + # Adapter does not have a tokenizer, so fallback to base model tokenizer + adapter_tokenizer = None + + # load adapter weights from all shards (should have relatively small memory footprint) + adapter_weights = {} + for filename in adapter_filenames: + adapter_weights.update(load_file(filename)) + + # map the model weights to the relevant adapter weights (LoRA A and B matrices) + module_map, adapter_weight_names = adapter_config.map_weights_for_model( + adapter_weights, weight_names + ) + return module_map, adapter_config, adapter_weight_names, adapter_tokenizer + + +def get_attn_weights(i, layer): + qkv = layer.self_attn.query_key_value + weights = {} + + for k in ["q", "k", "v"]: + key = (i, f"{k}_proj") + value = (f"model.layers.{i}.self_attn.{k}_proj", qkv) + weights[key] = value + + # also add the qkv_proj weight for the adapter + weights[(i, "qkv_proj")] = ( + f"model.layers.{i}.self_attn.qkv_proj", + qkv, + ) + + weights[(i, "o_proj")] = ( + f"model.layers.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + return weights + + +def get_mlp_weights(i, layer): + weights = {} + if hasattr(layer, "mlp"): + mlp = layer.mlp + if hasattr(mlp, "gate_up_proj"): + # handle combined gate_up_proj (e.g., for some LLaMA variants) + weights.update( + { + (i, "gate_proj"): ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_up_proj, + ), + (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj), + } + ) + else: + # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma) + if hasattr(mlp, "gate_proj"): + weights[(i, "gate_proj")] = ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_proj, + ) + if hasattr(mlp, "up_proj"): + weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) + + if hasattr(mlp, "down_proj"): + weights[(i, "down_proj")] = ( + f"model.layers.{i}.mlp.down_proj", + mlp.down_proj, + ) + + return weights + + +# build_layer_weight_lookup creates a mapping of model layers to their corresponding +# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples +# containing the weight tensor path and the actual layer object. This mapping is needed +# for the lora adapter to know which weights to update when applying the adapter. +def build_layer_weight_lookup(model): + if hasattr(model, "language_model"): + m = model.language_model.model + elif hasattr(model, "text_model"): + m = model.text_model.model + else: + m = model.model + + layer_weights = {} + + for i, layer in enumerate(m.layers): + attn_weights = get_attn_weights(i, layer) + mlp_weights = get_mlp_weights(i, layer) + + layer_weights.update(attn_weights) + layer_weights.update(mlp_weights) + + lm_head = None + if hasattr(m, "lm_head"): + lm_head = m.lm_head + elif hasattr(model, "lm_head"): + lm_head = model.lm_head + + if lm_head: + layer_weights[(0, "lm_head")] = ("lm_head", lm_head) + + return layer_weights diff --git a/backends/gaudi/server/text_generation_server/utils/chunks.py b/backends/gaudi/server/text_generation_server/utils/chunks.py new file mode 100644 index 000000000..73962ea39 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/chunks.py @@ -0,0 +1,27 @@ +from typing import Iterable + +from loguru import logger + +from text_generation_server.pb import generate_pb2 + + +def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str: + """ + Concatenate text in text chunks. Non-text chunks are dropped. + """ + text = None + for chunk in chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + if text is None: + text = chunk.text + else: + raise NotImplementedError("Request contained more than one text chunk") + else: + # We cannot reject this, e.g. warmup sends an image chunk. + logger.debug(f"Encountered non-text chunk type {chunk_type}") + + if text is None: + raise NotImplementedError("Request without a text chunk") + + return text diff --git a/backends/gaudi/server/text_generation_server/utils/convert.py b/backends/gaudi/server/text_generation_server/utils/convert.py new file mode 100644 index 000000000..d9c3276bc --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/convert.py @@ -0,0 +1,114 @@ +import datetime +import torch +import os + +from loguru import logger +from pathlib import Path +from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete +from typing import List, Dict +from collections import defaultdict + + +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set( + [name for name in shared if _is_complete(state_dict[name])] + ) + if not complete_names: + if len(shared) == 1: + # Force contiguous + name = list(shared)[0] + state_dict[name] = state_dict[name].clone() + complete_names = {name} + else: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mecanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]): + """ + Convert a pytorch file to a safetensors file + This will remove duplicate tensors from the file. + + Unfortunately, this might not respect *transformers* convention. + Forcing us to check for potentially different keys during load when looking + for specific tensors (making tensor sharing explicit). + """ + loaded = torch.load(pt_file, map_location="cpu", weights_only=True) + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) + + metadata = {"format": "pt"} + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del loaded[to_remove] + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_file) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_file, metadata=metadata) + reloaded = load_file(sf_file) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]): + assert len(pt_files) == len(sf_files) + + N = len(pt_files) + # We do this instead of using tqdm because we want to parse the logs with the launcher + + for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)): + # Skip blacklisted files + if ( + "arguments" in pt_file.name + or "args" in pt_file.name + or "training" in pt_file.name + ): + continue + + start = datetime.datetime.now() + convert_file(pt_file, sf_file, discard_names) + elapsed = datetime.datetime.now() - start + logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}") diff --git a/backends/gaudi/server/text_generation_server/utils/debug.py b/backends/gaudi/server/text_generation_server/utils/debug.py new file mode 100644 index 000000000..ef8d437b7 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/debug.py @@ -0,0 +1,31 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import os +import glob +import time + +from optimum.habana.utils import to_gb_rounded +import habana_frameworks.torch as htorch + +START_TS = None +DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME') +if 'GRAPH_VISUALIZATION' in os.environ: + for f in glob.glob('.graph_dumps/*'): + os.remove(f) + + +def count_hpu_graphs(): + return len(glob.glob('.graph_dumps/*PreGraph*')) + + +def dbg_trace(tag, txt): + global START_TS + if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0: + if START_TS is None: + START_TS = time.perf_counter() + time_offset = time.perf_counter() - START_TS + mem_stats = htorch.hpu.memory.memory_stats() + mem_used = to_gb_rounded(mem_stats['InUse']) + max_mem_used = to_gb_rounded(mem_stats['MaxInUse']) + print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB ' + f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a')) diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py new file mode 100644 index 000000000..d370a3d5c --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/dist.py @@ -0,0 +1,99 @@ +import os +import torch + +from datetime import timedelta +from loguru import logger + +# Tensor Parallelism settings +RANK = int(os.getenv("RANK", "0")) +WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + +# CUDA memory fraction +MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) + + +class FakeBarrier: + def wait(self): + pass + + +class FakeGroup: + def __init__(self, rank, size): + self._rank = rank + self._size = size + + def allreduce(self, *args, **kwargs): + return FakeBarrier() + + def allgather(self, inputs, local_tensor, **kwargs): + assert ( + len(inputs[0]) == len(local_tensor) == 1 + ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" + for input_ in inputs: + input_[0].data = local_tensor[0].data + return FakeBarrier() + + def barrier(self, *args, **kwargs): + return FakeBarrier() + + def size(self): + return self._size + + def rank(self): + return self._rank + + +def initialize_torch_distributed(): + import habana_frameworks.torch.core as htcore + + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + options = None + if torch.cuda.is_available(): + from torch.distributed import ProcessGroupNCCL + + # Set the device id. + assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" + device = RANK % torch.cuda.device_count() + torch.cuda.set_device(device) + torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) + backend = "nccl" + options = ProcessGroupNCCL.Options() + options.is_high_priority_stream = True + options._timeout = timedelta(seconds=60) + elif torch.hpu.is_available(): + backend = "hccl" + n_hpus = torch.hpu.device_count() + if world_size > n_hpus: + raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).") + else: + try: + import oneccl_bindings_for_pytorch + + backend = "ccl" + if os.getenv("CCL_WORKER_COUNT", None) is None: + os.environ["CCL_WORKER_COUNT"] = str(1) + except ImportError: + backend = "gloo" + options = None + + if WORLD_SIZE == 1: + return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE + else: + if os.getenv("DEBUG", None) == "1": + return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE + + if not torch.distributed.is_initialized(): + # Call the init process. + torch.distributed.init_process_group( + backend=backend, + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + else: + logger.warning("torch.distributed is already initialized.") + + return torch.distributed.group.WORLD, RANK, WORLD_SIZE diff --git a/backends/gaudi/server/text_generation_server/utils/hub.py b/backends/gaudi/server/text_generation_server/utils/hub.py new file mode 100644 index 000000000..f9c476ac3 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/hub.py @@ -0,0 +1,234 @@ +import time +import os + +from datetime import timedelta +from loguru import logger +from pathlib import Path +from typing import Optional, List + +from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import ( + LocalEntryNotFoundError, + EntryNotFoundError, + RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib +) + +WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) +HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] + + +def _cached_weight_files( + model_id: str, revision: Optional[str], extension: str +) -> List[str]: + """Guess weight files from the cached revision snapshot directory""" + d = _get_cached_revision_directory(model_id, revision) + if not d: + return [] + filenames = _weight_files_from_dir(d, extension) + return filenames + + +def _weight_hub_files_from_model_info( + info: hf_api.ModelInfo, extension: str +) -> List[str]: + return [ + s.rfilename + for s in info.siblings + if s.rfilename.endswith(extension) + and len(s.rfilename.split("/")) == 1 + and "arguments" not in s.rfilename + and "args" not in s.rfilename + and "training" not in s.rfilename + ] + + +def _weight_files_from_dir(d: Path, extension: str) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_hub_files_from_model_info, that's also what is + # done there with the len(s.rfilename.split("/")) == 1 condition + root, _, files = next(os.walk(str(d))) + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "training" not in f + ] + return filenames + + +def _get_cached_revision_directory( + model_id: str, revision: Optional[str] +) -> Optional[Path]: + if revision is None: + revision = "main" + + repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path( + file_download.repo_folder_name(repo_id=model_id, repo_type="model") + ) + + if not repo_cache.is_dir(): + # No cache for this model + return None + + refs_dir = repo_cache / "refs" + snapshots_dir = repo_cache / "snapshots" + + # Resolve refs (for instance to convert main to the associated commit sha) + if refs_dir.is_dir(): + revision_file = refs_dir / revision + if revision_file.exists(): + with revision_file.open() as f: + revision = f.read() + + # Check if revision folder exists + if not snapshots_dir.exists(): + return None + cached_shas = os.listdir(snapshots_dir) + if revision not in cached_shas: + # No cache for this revision and we won't try to return a random revision + return None + + return snapshots_dir / revision + + +def weight_hub_files( + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" +) -> List[str]: + """Get the weights filenames on the hub""" + api = HfApi() + + if HF_HUB_OFFLINE: + filenames = _cached_weight_files(model_id, revision, extension) + else: + # Online case, fetch model info from the Hub + info = api.model_info(model_id, revision=revision) + filenames = _weight_hub_files_from_model_info(info, extension) + + if not filenames: + raise EntryNotFoundError( + f"No {extension} weights found for model {model_id} and revision {revision}.", + None, + ) + + return filenames + + +def try_to_load_from_cache( + model_id: str, revision: Optional[str], filename: str +) -> Optional[Path]: + """Try to load a file from the Hugging Face cache""" + + d = _get_cached_revision_directory(model_id, revision) + if not d: + return None + + # Check if file exists in cache + cached_file = d / filename + return cached_file if cached_file.is_file() else None + + +def weight_files( + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" +) -> List[Path]: + """Get the local files""" + # Local model + d = Path(model_id) + if d.exists() and d.is_dir(): + local_files = _weight_files_from_dir(d, extension) + if not local_files: + raise FileNotFoundError( + f"No local weights found in {model_id} with extension {extension}" + ) + return [Path(f) for f in local_files] + + try: + filenames = weight_hub_files(model_id, revision, extension) + except EntryNotFoundError as e: + if extension != ".safetensors": + raise e + # Try to see if there are pytorch weights + pt_filenames = weight_hub_files(model_id, revision, extension=".bin") + # Change pytorch extension to safetensors extension + # It is possible that we have safetensors weights locally even though they are not on the + # hub if we converted weights locally without pushing them + filenames = [ + f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames + ] + + if WEIGHTS_CACHE_OVERRIDE is not None: + files = [] + for filename in filenames: + p = Path(WEIGHTS_CACHE_OVERRIDE) / filename + if not p.exists(): + raise FileNotFoundError( + f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}." + ) + files.append(p) + return files + + files = [] + for filename in filenames: + cache_file = try_to_load_from_cache( + model_id, revision=revision, filename=filename + ) + if cache_file is None: + raise LocalEntryNotFoundError( + f"File {filename} of model {model_id} not found in " + f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. " + f"Please run `text-generation-server download-weights {model_id}` first." + ) + files.append(cache_file) + + return files + + +def download_weights( + filenames: List[str], model_id: str, revision: Optional[str] = None +) -> List[Path]: + """Download the safetensors files from the hub""" + + def download_file(fname, tries=5, backoff: int = 5): + local_file = try_to_load_from_cache(model_id, revision, fname) + if local_file is not None: + logger.info(f"File {fname} already present in cache.") + return Path(local_file) + + for idx in range(tries): + try: + logger.info(f"Download file: {fname}") + stime = time.time() + local_file = hf_hub_download( + filename=fname, + repo_id=model_id, + revision=revision, + local_files_only=HF_HUB_OFFLINE, + ) + logger.info( + f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}." + ) + return Path(local_file) + except Exception as e: + if idx + 1 == tries: + raise e + logger.error(e) + logger.info(f"Retrying in {backoff} seconds") + time.sleep(backoff) + logger.info(f"Retry {idx + 1}/{tries - 1}") + + # We do this instead of using tqdm because we want to parse the logs with the launcher + start_time = time.time() + files = [] + for i, filename in enumerate(filenames): + file = download_file(filename) + + elapsed = timedelta(seconds=int(time.time() - start_time)) + remaining = len(filenames) - (i + 1) + eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0 + + logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}") + files.append(file) + + return files diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py new file mode 100644 index 000000000..782b4f15b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py @@ -0,0 +1,75 @@ +import torch +from loguru import logger +import os + + +import importlib.util + + +def is_ipex_available(): + return importlib.util.find_spec("intel_extension_for_pytorch") is not None + + +def get_cuda_free_memory(device, memory_fraction): + total_free_memory, _ = torch.cuda.mem_get_info(device) + total_gpu_memory = torch.cuda.get_device_properties(device).total_memory + free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) + return free_memory + + +def get_xpu_free_memory(device, memory_fraction): + total_memory = torch.xpu.get_device_properties(device).total_memory + device_id = device.index + memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) + free_memory = max( + 0, + int( + total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id) + ), + ) + return free_memory + + +def get_cpu_free_memory(device, memory_fraction): + import psutil + from text_generation_server.utils.dist import WORLD_SIZE + + mem = psutil.virtual_memory() + free_memory = int(mem.available * 0.95 / WORLD_SIZE) + return free_memory + + +def noop(*args, **kwargs): + pass + + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif torch.version.cuda is not None and torch.cuda.is_available(): + SYSTEM = "cuda" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif is_ipex_available(): + SYSTEM = "ipex" + import intel_extension_for_pytorch # noqa: F401 + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + empty_cache = torch.xpu.empty_cache + synchronize = torch.xpu.synchronize + get_free_memory = get_xpu_free_memory + else: + empty_cache = noop + synchronize = noop + get_free_memory = get_cpu_free_memory +else: + SYSTEM = "cpu" + + empty_cache = noop + synchronize = noop + get_free_memory = get_cpu_free_memory +logger.info(f"Detected system {SYSTEM}") diff --git a/backends/gaudi/server/text_generation_server/utils/log.py b/backends/gaudi/server/text_generation_server/utils/log.py new file mode 100644 index 000000000..4385c71ee --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/log.py @@ -0,0 +1,15 @@ +from functools import lru_cache +from text_generation_server.utils.dist import RANK + + +@lru_cache(10) +def log_once(log, msg: str, master=True): + if master: + log_master(log, msg) + else: + log(msg) + + +def log_master(log, msg: str): + if RANK == 0: + log(msg) diff --git a/backends/gaudi/server/text_generation_server/utils/logits_process.py b/backends/gaudi/server/text_generation_server/utils/logits_process.py new file mode 100644 index 000000000..104fc2f09 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/logits_process.py @@ -0,0 +1,583 @@ +import math +import torch +import habana_frameworks.torch.core as htcore + +from loguru import logger +from typing import Dict, Union +from text_generation_server.pb.generate_pb2 import GrammarType + +from outlines.fsm.fsm import RegexFSM +from outlines.fsm.json_schema import build_regex_from_schema +from functools import lru_cache +from typing import List, Optional, DefaultDict +import time + +from transformers import ( + LogitsWarper, + LogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, +) + +mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + + +class StaticWarper: + def __init__( + self, + temperature=1.0, + top_k=None, + top_p=None, + typical_p=None, + ): + self.warpers = [] + + if temperature is not None and temperature != 1.0: + temperature = float(temperature) + self.warpers.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + self.warpers.append(TopKLogitsWarper(top_k=top_k)) + if top_p is not None and top_p < 1.0: + self.warpers.append(TopPLogitsWarper(top_p=top_p)) + if typical_p is not None and typical_p < 1.0: + self.warpers.append(TypicalLogitsWarper(mass=typical_p)) + + self.hpu_graph = None + self.static_scores = None + self.static_warped_scores = None + self.static_next_logprob = None + + def __call__(self, scores): + if self.hpu_graph is None: + self.static_scores = scores.clone().contiguous() + self.static_warped_scores = scores.clone().contiguous() + self.static_next_logprob = scores.clone().contiguous() + self.hpu_graph = htcore.hpu.HPUGraph() + + with htcore.hpu.graph(self.hpu_graph): + local_scores = self.static_scores + for warper in self.warpers: + local_scores = warper(None, local_scores) + + self.static_warped_scores.copy_(local_scores) + # Compute logprobs + self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1)) + + self.static_scores.copy_(scores) + self.hpu_graph.replay() + + return self.static_warped_scores, self.static_next_logprob + + +@lru_cache(10) +def static_warper( + temperature: Optional[float], + top_k: Optional[int], + top_p: Optional[float], + typical_p: Optional[float], +) -> StaticWarper: + return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) + + +class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + repetition_penalty (`List[float]`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + """ + + def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): + self.penalty = penalty + self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + score = torch.gather(scores, 1, input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor) + + scores.scatter_(1, input_ids, score) + return scores + + def filter(self, indices): + self.penalty = [self.penalty[i] for i in indices] + if any([x != 1.0 for x in self.penalty]): + self.penalty_tensor = self.penalty_tensor[indices] + return self + return None + + +class FrequencyPenaltyLogitsProcessor(LogitsProcessor): + r""" + Frequency penalty as defined by OpenAI + + Args: + penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. + """ + + def __init__(self, penalty: float): + self.penalty = penalty + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + score = torch.gather(scores, 1, input_ids) + # if score < 0 then penalty has to be multiplied to reduce the previous token probability + score = -torch.where(score < 0, score * self.penalty, score / self.penalty) + # set score to 0 where input_ids is a padding token + score *= input_ids.ne(0) + + return scores.scatter_add_(1, input_ids, score) + + +class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): + r""" + Frequency penalty as defined by OpenAI in + https://platform.openai.com/docs/guides/text-generation/parameter-details + + Args: + frequency_penalty (`List[float]`): + The parameter for frequency penalty. 0.0 means no penalty. + """ + + def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): + self.penalty = penalty + self.penalty_tensor = torch.tensor( + penalty, dtype=dtype, device=device + ).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + batch_size, input_size = input_ids.size() + vocab_size = scores.size(1) + + # Calculate the frequency for each token so far + token_freq = torch.zeros( + batch_size, vocab_size, dtype=scores.dtype, device=scores.device + ) + token_freq.scatter_add_( + 1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device) + ) + token_freq /= input_size + + # Apply the frequency penalty to logits + scores -= token_freq * self.penalty_tensor + return scores + + def filter(self, indices): + self.penalty = [self.penalty[i] for i in indices] + if any([x != 0.0 for x in self.penalty]): + self.penalty_tensor = self.penalty_tensor[indices] + return self + return None + + +class HeterogeneousTemperatureLogitsWarper: + r""" + [`LogitsWarper`] for temperature (exponential scaling output probability distribution). + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + temperature (`float`): + The value used to module the logits distribution. + """ + + def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device): + self.temperature = temperature + self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + scores.div_(self.temperature_tensor) + return scores + + def filter(self, indices): + self.temperature = [self.temperature[i] for i in indices] + if any([x != 1.0 for x in self.temperature]): + self.temperature_tensor = self.temperature_tensor[indices] + return self + return None + + +class HeterogeneousTopPLogitsWarper(LogitsWarper): + """ + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__( + self, + top_p: List[float], + dtype: torch.dtype, + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.top_p = top_p + self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1) + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + sorted_logits, sorted_indices = torch.sort(scores, descending=False) + probs = sorted_logits.softmax(dim=-1) + # This is way faster for some reason + for i in range(probs.shape[0]): + probs[i] = probs[i].cumsum(dim=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = probs <= self.top_p_opposite + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) + + return warped_scores + + def filter(self, indices): + self.top_p = [self.top_p[i] for i in indices] + if any([x < 1.0 for x in self.top_p]): + self.top_p_opposite = self.top_p_opposite[indices] + return self + return None + + +class HeterogeneousTopKLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__( + self, + top_k: List[int], + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.top_k = top_k + self.max_top_k = max(top_k) + # value - 1 as we will use top_k to index and python uses 0 based numbering + self.top_k_tensor = torch.tensor( + [max(x - 1, min_tokens_to_keep - 1) for x in top_k], + dtype=torch.int64, + device=device, + ).unsqueeze(1) + + # 0 is a special value that disables top_k warping for this member of the batch + disabled = [x == 0 for x in top_k] + + if any(disabled): + self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1) + else: + self.top_k_disabled_mask = None + + self.filter_value = filter_value + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # If max_top_k is superior to the vocab, we need to clamp or the warper will fail + if scores.size(-1) < self.max_top_k: + max_top_k = scores.size(-1) + top_k = torch.clamp_max(self.top_k_tensor, max_top_k) + else: + max_top_k = self.max_top_k + top_k = self.top_k_tensor + + # Get the kth score for each member of the batch + kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) + + # Mask member of kth_scores that do not want to use top_k warping + if self.top_k_disabled_mask is not None: + kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value) + + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < kth_scores + scores.masked_fill_(indices_to_remove, self.filter_value) + return scores + + def filter(self, indices): + self.top_k = [self.top_k[i] for i in indices] + disabled = [x == 0 for x in self.top_k] + + if not all(disabled): + self.top_k_tensor = self.top_k_tensor[indices] + self.max_top_k = max(self.top_k) + + if self.top_k_disabled_mask is not None: + self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None + + return self + return None + + +class HeterogeneousTypicalLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language + Generation](https://arxiv.org/abs/2202.00666) for more information. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + mass (`float`): + Value of typical_p between 0 and 1 inclusive, defaults to 0.9. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__( + self, + mass: List[float], + dtype: torch.dtype, + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.mass = mass + self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1) + + # 1 is a special value that disables typical_p warping for this member of the batch + disabled = [x == 1.0 for x in mass] + + if any(disabled): + self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device) + else: + self.disabled_mask = None + + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + probs = sorted_logits.softmax(dim=-1) + # This is way faster for some reason + for i in range(probs.shape[0]): + probs[i] = probs[i].cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (probs < self.mass_tensor).sum(dim=1) + last_ind[last_ind < 0] = 0 + + if self.disabled_mask is not None: + last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) + + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) + + return warped_scores + + def filter(self, indices): + self.mass = [self.mass[i] for i in indices] + disabled = [x == 1.0 for x in self.mass] + + if not all(disabled): + self.mass_tensor = self.mass_tensor[indices] + + if self.disabled_mask is not None: + self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None + + return self + return None + + +class HeterogeneousProcessorWrapper(LogitsProcessor): + r""" + A wrapper for logit warpers or processors without heterogeneous parameter support. + Args: + processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): + A mapping of sample indices to logit warpers or processors, to be run sequentially. + """ + + def __init__( + self, + processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], + ): + self.processors = processors + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + for i, processor in self.processors.items(): + scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) + return scores + + def filter(self, indices): + new_processors = {} + for i, idx in enumerate(indices): + if idx in self.processors: + new_processors[i] = self.processors[idx] + + if new_processors: + self.processors = new_processors + return self + return None + + +class GrammarLogitProcessor(LogitsProcessor): + fsm_state: DefaultDict[int, int] + fsm: RegexFSM + + def __init__(self, tokenizer, device, grammar, grammar_type): + self.device = device + self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) + self.fsm = GrammarLogitProcessor._cached_compile_fsm( + grammar_type, grammar, self.tokenizer + ) + + def __call__( + self, + logits: torch.Tensor, + fsm_grammar_state: int, + ): + if fsm_grammar_state == -1 or self.fsm is None: + return logits + allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) + mask = torch.full_like(logits, -math.inf) + mask[:, allowed_tokens] = 0 + biased_scores = logits + mask + return biased_scores + + def advance(self, next_token_id, fsm_grammar_state): + return GrammarLogitProcessor._advance( + next_token_id, fsm_grammar_state, self.fsm + ) + + @staticmethod + def _advance(next_token_id, fsm_grammar_state, fsm): + if fsm_grammar_state == -1: + return fsm_grammar_state + return fsm.next_state(fsm_grammar_state, next_token_id) + + # TODO: move grammar compilation into the router + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_compile_fsm(grammar_type, schema, tokenizer): + start_time = time.time() + if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: + schema = build_regex_from_schema(schema) + elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: + pass # schema is already a regex just here for clarity + fsm = RegexFSM(schema, tokenizer) + logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") + return fsm + + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_adapt_tokenizer(tokenizer): + """Adapt tokenizer to work with the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces to + Llama's tokenizer to be able to compile FSMs for this model. + + """ + start_time = time.time() + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s") + return tokenizer + + +class HeterogeneousGrammarLogitProcessor(LogitsProcessor): + def __init__(self, tokenizer, device, grammars, grammar_types): + self.device = device + self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) + self.fsms = [] + for grammar, grammar_type in zip(grammars, grammar_types): + if len(grammar) == 0: + self.fsms.append(None) + continue + fsm = GrammarLogitProcessor._cached_compile_fsm( + grammar_type, grammar, self.tokenizer + ) + self.fsms.append(fsm) + + def __call__( + self, + logits: torch.Tensor, + fsm_grammar_states: List[int], + ): + mask = torch.full_like(logits, -math.inf) + for i in range(logits.shape[0]): + fsm = self.fsms[i] + if fsm is None: + continue + allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) + mask[i, allowed_tokens] = 0 + logits[i] += mask[i] + return logits + + def advance_batch(self, next_token_ids, fsm_grammar_states): + return [ + GrammarLogitProcessor._advance( + next_token_ids[i], fsm_grammar_states[i], self.fsms[i] + ) + for i in range(len(next_token_ids)) + ] + + def advance_at_index(self, next_token_id, fsm_grammar_state, index): + if self.fsms[index] is None: + return fsm_grammar_state + return GrammarLogitProcessor._advance( + next_token_id, fsm_grammar_state, self.fsms[index] + ) + + def filter(self, indices): + new_fsms = [] + for i in indices: + new_fsms.append(self.fsms[i]) + self.fsms = new_fsms + return self diff --git a/backends/gaudi/server/text_generation_server/utils/merges/strategies.py b/backends/gaudi/server/text_generation_server/utils/merges/strategies.py new file mode 100644 index 000000000..cb39cde1f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/merges/strategies.py @@ -0,0 +1,220 @@ +import copy +from abc import ABC +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union +from text_generation_server.utils.merges.utils import ( + calculate_majority_sign_mask, + disjoint_merge, + prune, +) +import torch + +if TYPE_CHECKING: + from text_generation_server.adapters.lora import LoraConfig + from text_generation_server.utils.adapter import ModuleMap + + +class AdapterParameters: + def __init__( + self, adapter_ids, weights, merge_strategy, density, majority_sign_method + ): + self.adapter_ids = adapter_ids + self.weights = weights + self.merge_strategy = merge_strategy + self.density = density + self.majority_sign_method = majority_sign_method + + +def _apply_weights( + tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor +) -> torch.Tensor: + if isinstance(tensors, torch.Tensor): + t = tensors + else: + t = torch.stack(tensors, dim=0) + + # element-wise weighting of each task tensor + # need to unsqueeze weights to match task tensor dimensions + # for multiplication to apply element-wise + while len(t.shape) > len(w.shape): + w = w.unsqueeze(-1) + return t * w + + +class MergeStrategy(ABC): + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + raise NotImplementedError() + + +class LinearMerge(MergeStrategy): + def __init__(self, **kwargs): + pass + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class TiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="magnitude") for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + return disjoint_merge(weighted_task_tensors, majority_sign_mask) + + +class DareLinearMerge(MergeStrategy): + def __init__(self, density: float, **kwargs): + self.density = density + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class DareTiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors + + +strategy_registry: Dict[str, Type[MergeStrategy]] = { + "linear": LinearMerge, + "ties": TiesMerge, + "dare_linear": DareLinearMerge, + "dare_ties": DareTiesMerge, +} + + +def merge_adapters( + adapters: List[Tuple["ModuleMap", "LoraConfig"]], + merge_params: AdapterParameters, +) -> Tuple["ModuleMap", "LoraConfig"]: + # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower() + strategy_name = "linear" + + weights = merge_params.weights + if not weights: + weights = torch.ones(len(adapters)) + else: + weights = torch.tensor(weights) + + merge_config = { + "density": merge_params.density, + # "majority_sign_method": MajoritySignMethodEnum.Name( + # merge_params.majority_sign_method + # ).lower(), + "majority_sign_method": "total", + } + merge_strategy = strategy_registry[strategy_name](**merge_config) + + module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) + lora_configs = [] + weight_name_to_adapter_idx = defaultdict(list) + + # input is list of (module_map, lora_config) tuples + # convert into dict[k][param_name] -> list of tensors + for idx, (module_map, lora_config) in enumerate(adapters): + for weight_name, data in module_map.items(): + weight_name_to_adapter_idx[weight_name].append(idx) + for k, (param_data, param_name) in data.items(): + module_maps[weight_name][k][param_name].append(param_data) + lora_configs.append(lora_config) + + # validate lora configs are compatible + _validate_lora_configs(lora_configs) + + # merge tensors for each module such that we have a single ModuleMap: + # dict[k] -> merged tensor + merged_module_map: "ModuleMap" = defaultdict(dict) + for weight_name, data in module_maps.items(): + indices = weight_name_to_adapter_idx[weight_name] + param_weights = weights[indices] + for k, param_data in data.items(): + for param_name, tensors in param_data.items(): + merged_tensor = merge_strategy.merge(tensors, param_weights) + merged_module_map[weight_name][k] = (merged_tensor, param_name) + + # merge lora configs + merged_lora_config = _merge_lora_configs(lora_configs) + + return merged_module_map, merged_lora_config + + +def _validate_lora_configs(lora_configs: List["LoraConfig"]): + # check that all configs have the same rank + ranks = set(lora_config.r for lora_config in lora_configs) + if len(ranks) > 1: + raise ValueError( + f"unable to merge adapters, lora configs have different ranks: {ranks}" + ) + + if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs): + raise ValueError( + "unable to merge adapters, lora configs have no target modules" + ) + + +def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig": + merged_lora_config = copy.copy(lora_configs[0]) + + # merge target modules as a union operation + merged_target_modules = sorted( + set( + module + for lora_config in lora_configs + for module in lora_config.target_modules + ) + ) + merged_lora_config.target_modules = merged_target_modules + + return merged_lora_config diff --git a/backends/gaudi/server/text_generation_server/utils/merges/utils.py b/backends/gaudi/server/text_generation_server/utils/merges/utils.py new file mode 100644 index 000000000..d9ad3278a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/merges/utils.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# From: https://github.com/huggingface/peft/pull/1364 +# Copyright 2024-present the HuggingFace Inc. team. +# Modifications by Predibase, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +import torch + + +def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + """ + mask = torch.zeros_like(tensor).reshape(-1) + k = int(density * tensor.reshape(-1).shape[0]) + top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True) + mask[top_k[1]] = 1 + return tensor * mask.reshape(tensor.shape) + + +def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density)) + pruned_tensor = tensor * mask + if rescale: + torch.div(input=pruned_tensor, other=density) + return pruned_tensor + + +def prune( + tensor: torch.Tensor, + density: float, + method: Literal["magnitude", "random"], + rescale: bool = False, +) -> torch.Tensor: + """ + Prune the values of task tensors based on the `method`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + method (`str`):The method to use to prune. Should be one of ["magnitude", "random"]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + if density >= 1: + return tensor + elif density < 0: + raise ValueError("Density should be >= 0, got {density}") + if method == "magnitude": + return magnitude_based_pruning(tensor, density) + elif method == "random": + return random_pruning(tensor, density, rescale=rescale) + else: + raise ValueError(f"Unknown method {method}") + + +def calculate_majority_sign_mask( + tensor: torch.Tensor, method: Literal["total", "frequency"] = "total" +): + """ + Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. + + Args: + tensor (`torch.Tensor`):The tensor to get the mask from. + method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"]. + """ + + sign = tensor.sign() + if method == "total": + sign_magnitude = (sign * tensor.abs()).sum(dim=0) + elif method == "frequency": + sign_magnitude = sign.sum(dim=0) + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + majority_sign = torch.where(sign_magnitude >= 0, 1, -1) + return sign == majority_sign + + +def disjoint_merge(task_tensors, majority_sign_mask): + mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0) + num_params_preserved = majority_sign_mask.sum(dim=0) + return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0) diff --git a/backends/gaudi/server/text_generation_server/utils/peft.py b/backends/gaudi/server/text_generation_server/utils/peft.py new file mode 100644 index 000000000..d49e73f00 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/peft.py @@ -0,0 +1,68 @@ +import os +from typing import Union +from loguru import logger +import torch + +from transformers import AutoTokenizer +from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM + + +def download_and_unload_peft(model_id, revision, trust_remote_code): + torch_dtype = torch.float16 + + logger.info("Trying to load a Peft model. It might take a while without feedback") + try: + model = AutoPeftModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + except Exception: + model = AutoPeftModelForSeq2SeqLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + logger.info("Peft model detected.") + logger.info("Merging the lora weights.") + + base_model_id = model.peft_config["default"].base_model_name_or_path + + model = model.merge_and_unload() + + os.makedirs(model_id, exist_ok=True) + cache_dir = model_id + logger.info(f"Saving the newly created merged model to {cache_dir}") + tokenizer = AutoTokenizer.from_pretrained( + base_model_id, trust_remote_code=trust_remote_code + ) + model.save_pretrained(cache_dir, safe_serialization=True) + model.config.save_pretrained(cache_dir) + tokenizer.save_pretrained(cache_dir) + + +def download_peft( + model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool +): + torch_dtype = torch.float16 + try: + _model = AutoPeftModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + except Exception: + _model = AutoPeftModelForSeq2SeqLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + logger.info("Peft model downloaded.") diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py new file mode 100644 index 000000000..ee561acc4 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -0,0 +1,202 @@ +import json +import os +from dataclasses import dataclass +from typing import Optional + +from huggingface_hub import hf_hub_download +from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + WeightsLoader, +) + + +# TODO: Split this config to have a single config type per quant method +@dataclass +class _QuantizerConfig: + bits: int + checkpoint_format: Optional[str] + desc_act: bool + groupsize: int + quant_method: str + sym: bool + + +@dataclass +class _FP8QuantizerConfig: + activation_scale_ub: float + + +# We should probably do this with Pytantic JSON deserialization, +# but for now we'll stay close to the old _set_gptq_params. +def _get_quantizer_config(model_id, revision): + bits = 4 + groupsize = -1 + quant_method = "gptq" + checkpoint_format = None + sym = False + desc_act = False + + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + data = json.load(f) + + # FP8 config + if data["quantization_config"]["quant_method"] == "fbgemm_fp8": + return _FP8QuantizerConfig( + activation_scale_ub=data["quantization_config"]["activation_scale_ub"] + ) + + if "zero_point" in data["quantization_config"]: + sym = not data["quantization_config"]["zero_point"] + quant_method = "awq" + elif "sym" in data["quantization_config"]: + sym = data["quantization_config"]["sym"] + + bits = data["quantization_config"]["bits"] + groupsize = data["quantization_config"]["group_size"] + # Order is important here, desc_act is missing on some real models + quant_method = data["quantization_config"]["quant_method"] + checkpoint_format = data["quantization_config"].get("checkpoint_format") + desc_act = data["quantization_config"]["desc_act"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["bits"] + groupsize = data["group_size"] + + if "zero_point" in data: + sym = not data["zero_point"] + quant_method = "awq" + elif "sym" in data: + sym = data["sym"] + + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + filename = "quant_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["w_bit"] + groupsize = data["q_group_size"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + pass + + return _QuantizerConfig( + bits=bits, + groupsize=groupsize, + quant_method=quant_method, + checkpoint_format=checkpoint_format, + sym=sym, + desc_act=desc_act, + ) + + +def get_loader( + quantize: Optional[str], model_id: str, revision: Optional[str] +) -> WeightsLoader: + quantizer_config = _get_quantizer_config(model_id, revision) + if quantize in {"awq", "gptq"}: + from text_generation_server.layers.gptq import GPTQWeightsLoader + + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + + if can_use_gptq_marlin( + bits=quantizer_config.bits, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ): + from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader + + return GPTQMarlinWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + else: + return GPTQWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + elif quantize == "bitsandbytes": + from text_generation_server.layers.bnb import BNBWeight + + return DefaultWeightsLoader(BNBWeight) + elif quantize == "bitsandbytes-fp4": + from text_generation_server.layers.bnb import BNBFP4Weight + + return DefaultWeightsLoader(BNBFP4Weight) + elif quantize == "bitsandbytes-nf4": + from text_generation_server.layers.bnb import BNBNF4Weight + + return DefaultWeightsLoader(BNBNF4Weight) + elif quantize == "eetq": + from text_generation_server.layers.eetq import EETQWeight + + return DefaultWeightsLoader(EETQWeight) + elif quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2WeightsLoader + + return Exl2WeightsLoader() + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeightsLoader + + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + + return MarlinWeightsLoader( + bits=quantizer_config.bits, + is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", + ) + elif quantize == "fp8" or quantize is None: + from text_generation_server.layers.fp8 import HybridFP8UnquantLoader + + # Since the default for the quantize config is _QuantizerConfig, + # we need to add this check to not get an attribute error + activation_scale_ub = None + if isinstance(quantizer_config, _FP8QuantizerConfig): + activation_scale_ub = quantizer_config.activation_scale_ub + + return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") + else: + raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/backends/gaudi/server/text_generation_server/utils/segments.py b/backends/gaudi/server/text_generation_server/utils/segments.py new file mode 100644 index 000000000..f59611021 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/segments.py @@ -0,0 +1,66 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/segments.py +# License: Apache License Version 2.0, January 2004 + +from typing import List, Tuple, Union + +import torch + + +def find_segments( + adapter_indices: Union[torch.Tensor, List[int]] +) -> Tuple[List[int], List[int]]: + segments = [0] + segment_indices = [] + + if isinstance(adapter_indices, torch.Tensor): + # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first + adapter_indices = adapter_indices.cpu().tolist() + + start_index = 0 + for i in range(1, len(adapter_indices)): + if adapter_indices[i] != adapter_indices[i - 1]: + segments.append(i) + segment_indices.append(adapter_indices[i - 1]) + start_index = i + + # Handle the last segment + if start_index < len(adapter_indices): + segments.append(len(adapter_indices)) + segment_indices.append(adapter_indices[-1]) + + return segments, segment_indices + + +class SegmentConcatBuilder: + def __init__(self): + self.adapter_segment_indices = [] + self.adapter_segment_tensors = [] + + def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): + # Update adapter segments + if self.adapter_segment_tensors: + # Because we have already processed at least one batch, remove the 0 start index + # from this batch denoting the beginning of the segment, then offset all segment + # positions by the value of the last segment in the previous batch to account for + # the concatenation. + adapter_segments = ( + adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] + ) + + if ( + self.adapter_segment_indices + and self.adapter_segment_indices[-1] == segment_indices[0] + ): + # If the last segment in the previous batch is the same as the first segment in this batch, + # then we merge them together into a single segment. In effect, this means removing it from + # the segment indices of this batch, and extending the segment span by removing the segment + # end index from the previous batch. + segment_indices = segment_indices[1:] + self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1] + + self.adapter_segment_indices.extend(segment_indices) + self.adapter_segment_tensors.append(adapter_segments) + + def build(self) -> Tuple[torch.Tensor, List[int]]: + return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices diff --git a/backends/gaudi/server/text_generation_server/utils/sgmv.py b/backends/gaudi/server/text_generation_server/utils/sgmv.py new file mode 100644 index 000000000..2d0a73a54 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/sgmv.py @@ -0,0 +1,252 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/sgmv.py +# License: Apache License Version 2.0, January 2004 + +import os +import warnings +from functools import lru_cache +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +try: + import punica_kernels as _kernels + + HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) +except ImportError: + warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") + _kernels = None + HAS_SGMV = False + + +MIN_SGMV_RANK = 8 +MIN_RANK_CUSTOM = 16 +MAX_RANK_CUSTOM = 128 +SGMV_BLOCK_SIZE = 16 +BGMV_MAX_RANK = 64 + + +def has_sgmv() -> bool: + return HAS_SGMV + + +def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: + """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" + if not has_sgmv(): + return t + + # tensor parallelism will result in effective rank being divided by world_size, + # so we need to scale the min rank to offset that effect + min_rank = MIN_SGMV_RANK * world_size + + # if we're at or below the min rank, pad up to the min rank + # otherwise, pad to the nearest multiple of the block size + current_rank = t.size(dim) + target_rank = ( + min_rank + if current_rank <= min_rank + else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE + ) + if current_rank == target_rank: + return t + + pad_size = target_rank - current_rank + + # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + pad = [0, 0] * t.dim() + pad[(t.dim() - dim - 1) * 2 + 1] = pad_size + pad = tuple(pad) + + return F.pad(t, pad, mode="constant", value=0.0) + + +def use_cutlass_shrink(lora_rank: int) -> bool: + return lora_rank < MIN_RANK_CUSTOM + + +def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: + if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: + return t.transpose(0, 1) + return t + + +# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py +def add_lora_sgmv_cutlass( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.Tensor, + s_end: torch.Tensor, + layer_idx: int, + lora_rank: int, +): + """ + Semantics: + y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H1]`. + wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H2]`. + s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. + s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. + layer_idx: Layer index of the weight matrices. + """ + if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: + # Custom SGMV shrink only supports rank 16, 32, 64, 128 + _add_lora_sgmv_cutlass_legacy( + y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank + ) + return + + tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) + tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) + + +def _add_lora_sgmv_cutlass_legacy( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +): + tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +@lru_cache(maxsize=1) +def get_tmp_tensor(device: torch.device) -> torch.Tensor: + return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) + + +@lru_cache(maxsize=32) +def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: + tmp_size = _kernels.sgmv_cutlass_tmp_size(size) + return torch.empty((tmp_size,), dtype=torch.uint8, device=device) + + +def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor: + return torch.empty((size,), dtype=torch.uint8, device=device) + + +def get_tmp_expand_size(size: int) -> int: + return _kernels.sgmv_cutlass_tmp_size(size) + + +def get_tmp_tensors( + nsegments: int, lora_rank: int, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: + use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv() + has_sgmv_available = has_sgmv() + + if use_cutlass: + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp + elif has_sgmv_available: + return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device) + else: + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp + + +def lora_a_sgmv_cutlass( + x: torch.Tensor, + tmp: torch.Tensor, + wa_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +) -> torch.Tensor: + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + else: + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + return v + + +def lora_b_sgmv_cutlass( + y: torch.Tensor, + v: torch.Tensor, + tmp: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, +): + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +""" +Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + +Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + v: Shape: `[B, R]`. Temporary vector. + x: Shape: `[B, H1]`. Input vectors. + wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. + wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. +""" + + +def add_lora_a_bgmv( + v: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) + + +def add_lora_b_bgmv( + y: torch.Tensor, + v: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) + + +def segmented_matmul( + y: torch.Tensor, + x: torch.Tensor, + w: List[torch.Tensor], + b: List[torch.Tensor], + s_start: torch.IntTensor, + s_end: torch.IntTensor, +): + for i in range(len(w)): + if s_end[i] - s_start[i] <= 0: + continue + + xi = x[s_start[i] : s_end[i]] + wi = w[i] + bi = b[i] + y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) diff --git a/backends/gaudi/server/text_generation_server/utils/speculate.py b/backends/gaudi/server/text_generation_server/utils/speculate.py new file mode 100644 index 000000000..a1b37a344 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/speculate.py @@ -0,0 +1,11 @@ +SPECULATE = None + + +def get_speculate() -> int: + global SPECULATE + return SPECULATE + + +def set_speculate(speculate: int): + global SPECULATE + SPECULATE = speculate diff --git a/backends/gaudi/server/text_generation_server/utils/tokens.py b/backends/gaudi/server/text_generation_server/utils/tokens.py new file mode 100644 index 000000000..aa4d1fdb9 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/tokens.py @@ -0,0 +1,739 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import os +import re +from typing import List, Optional, Tuple, Set, Union + +import torch +from text_generation_server.pb import generate_pb2 +from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType +from text_generation_server.utils.logits_process import ( + FrequencyPenaltyLogitsProcessor, + GrammarLogitProcessor, + HeterogeneousProcessorWrapper, + HeterogeneousRepetitionPenaltyLogitsProcessor, + HeterogeneousFrequencyPenaltyLogitsProcessor, + HeterogeneousTemperatureLogitsWarper, + HeterogeneousTopKLogitsWarper, + HeterogeneousTopPLogitsWarper, + HeterogeneousTypicalLogitsWarper, + HeterogeneousGrammarLogitProcessor, + static_warper, +) +from text_generation_server.utils.watermark import WatermarkLogitsProcessor +from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor + + +class NextTokenChooser: + def __init__( + self, + watermark: bool = False, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + frequency_penalty: float = 0.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + do_sample: bool = False, + seed: int = 0, + device: str = "cpu", + tokenizer: Optional[PreTrainedTokenizerBase] = None, + grammar: str = "", + grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE, + fsm_grammar_state: int = 0, + ): + self.watermark_processor = ( + WatermarkLogitsProcessor(device=device) if watermark else None + ) + self.repetition_processor = ( + RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) + if repetition_penalty and repetition_penalty != 1.0 + else None + ) + self.frequency_processor = ( + FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty) + if frequency_penalty and frequency_penalty != 0.0 + else None + ) + self.grammar_processor = ( + GrammarLogitProcessor(tokenizer, device, grammar, grammar_type) + if grammar != "" + else None + ) + self.tokenizer = tokenizer + + has_warpers = ( + (temperature is not None and temperature != 1.0) + or (top_k is not None and top_k != 0) + or (top_p is not None and top_p < 1.0) + or (typical_p is not None and typical_p < 1.0) + ) + if has_warpers: + self.static_warper = static_warper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) + else: + self.static_warper = None + + sampling = do_sample or has_warpers + + self.choice = Sampling(seed, device) if sampling else Greedy() + self.fsm_grammar_state = fsm_grammar_state + self.grammar = grammar + + def __call__(self, input_ids, scores): + if self.watermark_processor is not None: + scores = self.watermark_processor(input_ids, scores) + if self.repetition_processor is not None: + scores = self.repetition_processor(input_ids, scores) + if self.frequency_processor is not None: + scores = self.frequency_processor(input_ids, scores) + if self.grammar_processor is not None: + scores = self.grammar_processor(scores, self.fsm_grammar_state) + + if self.static_warper is None: + next_logprob = torch.log_softmax(scores, -1) + else: + scores, next_logprob = self.static_warper(scores) + + next_id = self.choice(scores[-1]).view(1, 1) + + return next_id, next_logprob + + def advance_grammar(self, next_id: int): + if self.grammar_processor is not None: + self.fsm_grammar_state = self.grammar_processor.advance( + next_id, self.fsm_grammar_state + ) + return self + + @classmethod + def from_pb( + cls, + pb: generate_pb2.NextTokenChooserParameters, + device: torch.device, + tokenizer: PreTrainedTokenizerBase, + ) -> "NextTokenChooser": + return NextTokenChooser( + watermark=pb.watermark, + temperature=pb.temperature, + repetition_penalty=pb.repetition_penalty, + frequency_penalty=pb.frequency_penalty, + top_k=pb.top_k, + top_p=pb.top_p, + typical_p=pb.typical_p, + do_sample=pb.do_sample, + seed=pb.seed, + device=device, + tokenizer=tokenizer, + grammar=pb.grammar, + grammar_type=pb.grammar_type, + ) + + +class StopSequenceCriteria: + def __init__(self, stop_sequence: str): + stop_sequence = re.escape(stop_sequence) + self.regex = re.compile(f"{stop_sequence}$") + + def __call__(self, output: str) -> bool: + if self.regex.findall(output): + return True + return False + + +class StoppingCriteria: + def __init__( + self, + eos_token_ids: Optional[Union[Set[int], int]], + stop_sequence_criterias: List[StopSequenceCriteria], + max_new_tokens: int = 20, + ignore_eos_token: bool = False, + ): + if eos_token_ids is None: + eos_token_ids = set() + elif isinstance(eos_token_ids, int): + eos_token_ids = set([eos_token_ids]) + elif isinstance(eos_token_ids, set): + eos_token_ids = eos_token_ids + else: + raise RuntimeError( + f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]" + ) + self.eos_token_ids = eos_token_ids + self.stop_sequence_criterias = stop_sequence_criterias + self.max_new_tokens = max_new_tokens + self.current_tokens = 0 + self.current_output = "" + + if os.getenv("TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN", "false") == "true": + self.ignore_eos_token = True + else: + self.ignore_eos_token = ignore_eos_token + + def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: + self.current_tokens += 1 + if self.current_tokens >= self.max_new_tokens: + return True, FinishReason.FINISH_REASON_LENGTH + + if isinstance(last_token, torch.Tensor): + last_token = last_token.item() + + if not self.ignore_eos_token and last_token in self.eos_token_ids: + return True, FinishReason.FINISH_REASON_EOS_TOKEN + + if self.stop_sequence_criterias: + self.current_output += last_output + # There is no need to keep an output that is too long + if len(self.current_output) > 300: + # Slice to -200 to avoid doing it all the time + self.current_output = self.current_output[-200:] + for stop_sequence_criteria in self.stop_sequence_criterias: + if stop_sequence_criteria(self.current_output): + return True, FinishReason.FINISH_REASON_STOP_SEQUENCE + + return False, None + + @classmethod + def from_pb( + cls, + pb: generate_pb2.StoppingCriteriaParameters, + tokenizer: PreTrainedTokenizerBase, + ) -> "StoppingCriteria": + stop_sequence_criterias = [ + StopSequenceCriteria(sequence) for sequence in pb.stop_sequences + ] + # TODO Hack because eos_token_id cannot be what we want. + eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id) + return StoppingCriteria( + eos_token_id, + stop_sequence_criterias, + pb.max_new_tokens, + pb.ignore_eos_token, + ) + + +def create_n_gram_speculation( + input_ids: torch.Tensor, + next_ids: torch.Tensor, + accepted_ids: torch.Tensor, + speculate: int, + verbose: bool, +): + # Very trivial approach, find first match in the string. + # This is much less refined than actual n-gram but seems to work + # relatively OK in grounded mode and is by far much faster with + # much less worst case complexity as everything happens on device. + B = accepted_ids.shape[0] + device = input_ids.device + seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1] + indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1 + all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange( + speculate, device=device + ) + all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1) + + speculative_ids = input_ids.gather(dim=-1, index=all_indices) + return speculative_ids + + +class HeterogeneousNextTokenChooser: + def __init__( + self, + dtype: torch.dtype, + device: torch.device, + watermark: List[bool], + temperature: List[float], + repetition_penalty: List[float], + frequency_penalty: List[float], + top_k: List[int], + top_p: List[float], + typical_p: List[float], + do_sample: List[bool], + seeds: List[int], + tokenizer: PreTrainedTokenizerBase, + grammars: List[str], + grammar_types: List[int], + fsm_grammar_states:List[int], + quantization_enabled: bool, + ): + warpers = [] + + # TODO: enable watermark with FP8 quantization + self.watermark_processor = ( + HeterogeneousProcessorWrapper( + { + i: WatermarkLogitsProcessor(device=device) + for i, do_watermark in enumerate(watermark) + if do_watermark + } + ) + if any(watermark) and not quantization_enabled + else None + ) + + self.repetition_processor = ( + HeterogeneousRepetitionPenaltyLogitsProcessor( + repetition_penalty, dtype, device + ) + if any([x != 1.0 for x in repetition_penalty]) + else None + ) + + self.frequency_processor = ( + HeterogeneousFrequencyPenaltyLogitsProcessor( + frequency_penalty, dtype, device + ) + if any([x != 0.0 for x in frequency_penalty]) + else None + ) + + self.grammar_processor = ( + HeterogeneousGrammarLogitProcessor( + tokenizer, device, grammars, grammar_types + ) + if any([grammar != "" for grammar in grammars]) + else None + ) + + if any(x != 1.0 for x in temperature): + do_sample = [ + sample or x != 1.0 for x, sample in zip(temperature, do_sample) + ] + warpers.append( + HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) + ) + + if any(x != 0 for x in top_k): + do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] + warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) + + if any(x < 1.0 for x in top_p): + do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] + warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) + + if any(x < 1.0 for x in typical_p): + do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] + warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) + + self.warpers = warpers + + if any(do_sample): + self.choice = HeterogeneousSampling(do_sample, seeds, device) + else: + self.choice = Greedy() + + self.seeds = seeds + self.do_sample = do_sample + self.dtype = dtype + self.device = device + self.tokenizer = tokenizer + self.fsm_grammar_states = fsm_grammar_states + self.grammars = grammars + self.grammar_types = grammar_types + + def __call__( + self, + input_ids: torch.Tensor, + scores: torch.Tensor, + speculate: int, + speculated_ids: Optional[torch.Tensor] = None, + speculative_scores: Optional[torch.Tensor] = None, + verbose=False, + ): + if speculated_ids is not None: + B = scores.shape[0] // (speculated_ids.shape[1] + 1) + S = speculated_ids.shape[1] + 1 + scores = scores.view(B, S, -1) + else: + B = scores.shape[0] + S = 1 + scores = scores.view(B, S, -1) + + next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) + + for j in range(S): + _scores = scores[:, j] + if self.watermark_processor is not None: + _scores = self.watermark_processor(input_ids, _scores) + if self.repetition_processor is not None: + _scores = self.repetition_processor(input_ids, _scores) + if self.frequency_processor is not None: + _scores = self.frequency_processor(input_ids, _scores) + if self.grammar_processor is not None: + _scores = self.grammar_processor(_scores, self.fsm_grammar_states) + for warper in self.warpers: + _scores = warper(input_ids, _scores) + _next_ids = self.choice(_scores) + scores[:, j] = _scores + next_ids[:, j] = _next_ids + next_ids = next_ids.view(B * S) + allscores = scores.view(B * S, -1) + alllogprobs = torch.log_softmax(allscores, -1) + + if speculated_ids is not None: + accepted_ids = [] + B = next_ids.shape[0] // (speculated_ids.shape[1] + 1) + S = speculated_ids.shape[1] + 1 + indices = [] + for i in range(B): + _next_ids = next_ids[i * S : (i + 1) * S] + _speculated_ids = speculated_ids[i] + validate_speculative = _next_ids[:-1] == _speculated_ids + index = i * S + accepted = 1 + # First is always valid + indices.append(index) + for valid in validate_speculative.tolist(): + if valid: + index += 1 + accepted += 1 + indices.append(index) + else: + break + accepted_ids.append(accepted) + + accepted_ids = torch.tensor( + accepted_ids, device=input_ids.device, dtype=input_ids.dtype + ) + next_ids = next_ids[indices] + logprobs = alllogprobs[indices] + indices = torch.arange(B, device=input_ids.device) * S + if speculative_scores is not None: + speculative_scores = speculative_scores[indices + accepted_ids - 1] + else: + accepted_ids = torch.ones_like(next_ids) + logprobs = alllogprobs + + next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) + + if speculate > 0: + if speculative_scores is not None: + # Medusa provided some scores + speculative_ids = Greedy()(speculative_scores) + else: + # n-gram + speculative_ids = create_n_gram_speculation( + input_ids, next_ids, accepted_ids, speculate, verbose + ) + else: + speculative_ids = None + + return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids + + def advance_grammar(self, next_ids: List[int]): + if self.grammar_processor is not None: + other_new_states = self.grammar_processor.advance_batch( + next_ids, self.fsm_grammar_states + ) + self.fsm_grammar_states = other_new_states + return self + + def advance_grammar_single(self, grammar_state_index: int, next_id: int): + if self.grammar_processor is not None: + self.fsm_grammar_states[grammar_state_index] = ( + self.grammar_processor.advance_at_index( + next_id, + self.fsm_grammar_states[grammar_state_index], + grammar_state_index, + ) + ) + return self + + def advance_grammar_single_with_past_state( + self, grammar_state_index: int, next_id: torch.Tensor, past_state: int + ): + if self.grammar_processor is not None: + next_id = next_id.item() + self.fsm_grammar_states[grammar_state_index] = ( + self.grammar_processor.advance_at_index( + next_id, past_state, grammar_state_index, + ) + ) + return self + + def filter(self, indices): + if self.watermark_processor is not None: + self.watermark_processor = self.watermark_processor.filter(indices) + + if self.repetition_processor is not None: + self.repetition_processor = self.repetition_processor.filter(indices) + + if self.frequency_processor is not None: + self.frequency_processor = self.frequency_processor.filter(indices) + + if self.grammar_processor is not None: + self.grammar_processor = self.grammar_processor.filter(indices) + + filtered_warpers = [] + for warper in self.warpers: + filtered_warper = warper.filter(indices) + if filtered_warper is not None: + filtered_warpers.append(filtered_warper) + self.warpers = filtered_warpers + + self.seeds = [self.seeds[i] for i in indices] + self.do_sample = [self.do_sample[i] for i in indices] + + new_grammars = [] + new_fsm_grammar_states = [] + new_grammar_types = [] + for i in indices: + new_grammars.append(self.grammars[i]) + new_fsm_grammar_states.append(self.fsm_grammar_states[i]) + new_grammar_types.append(self.grammar_types[i]) + + self.grammars = new_grammars + self.fsm_grammar_states = new_fsm_grammar_states + self.grammar_types = new_grammar_types + + if any(self.do_sample): + self.choice.filter(indices) + else: + self.choice = Greedy() + + return self + + @classmethod + def from_pb( + cls, + pb: List[generate_pb2.NextTokenChooserParameters], + dtype: torch.dtype, + device: torch.device, + tokenizer: PreTrainedTokenizerBase, + fsm_grammar_states: Optional[List[int]] = None, + quantization_enabled: bool = False, + ) -> "HeterogeneousNextTokenChooser": + return HeterogeneousNextTokenChooser( + watermark=[pb_.watermark for pb_ in pb], + temperature=[pb_.temperature for pb_ in pb], + repetition_penalty=[pb_.repetition_penalty for pb_ in pb], + frequency_penalty=[pb_.frequency_penalty for pb_ in pb], + top_k=[pb_.top_k for pb_ in pb], + top_p=[pb_.top_p for pb_ in pb], + typical_p=[pb_.typical_p for pb_ in pb], + do_sample=[pb_.do_sample for pb_ in pb], + seeds=[pb_.seed for pb_ in pb], + device=device, + dtype=dtype, + tokenizer=tokenizer, + grammars=[pb_.grammar for pb_ in pb], + grammar_types=[pb_.grammar_type for pb_ in pb], + fsm_grammar_states=( + fsm_grammar_states if fsm_grammar_states else [0] * len(pb) + ), + quantization_enabled=quantization_enabled, + ) + + +def pad_next_token_chooser_parameters( + parameters: List[generate_pb2.NextTokenChooserParameters], + expected_size: int, +) -> List[generate_pb2.NextTokenChooserParameters]: + # disable all logits processors to minimize padding overhead + empty_parameters = generate_pb2.NextTokenChooserParameters( + temperature=1.0, + top_k=0, + top_p=1.0, + typical_p=1.0, + do_sample=False, + seed=0, + repetition_penalty=1.0, + frequency_penalty=0.0, + watermark=False, + grammar="", + grammar_type=0, + ) + parameters.extend( + [empty_parameters] * (expected_size - len(parameters)) + ) + return parameters + + +class Sampling: + def __init__(self, seed: int, device: str = "cpu"): + self.generator = torch.Generator("cpu") + self.generator.manual_seed(seed) + self.seed = seed + + def __call__(self, logits): + probs = torch.nn.functional.softmax(logits, -1) + # Avoid GPU<->CPU sync done by torch multinomial + # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 + q = torch.empty_like(probs).exponential_(1, generator=self.generator) + return probs.div_(q).argmax() + + +class Greedy: + def __call__(self, logits): + return logits.argmax(dim=-1) + + +class HeterogeneousSampling: + r""" + Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. + """ + + def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): + self.seeds = seeds + + self.greedy_indices = [] + self.sampling_mapping = {} + for i, (sample, seed) in enumerate(zip(do_sample, seeds)): + if sample: + self.sampling_mapping[i] = Sampling(seed, device) + else: + self.greedy_indices.append(i) + + self.greedy = Greedy() + + def __call__(self, logits): + out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device) + if self.greedy_indices: + # Computing for all indices is faster than slicing + torch.argmax(logits, -1, out=out) + + for i, sampling in self.sampling_mapping.items(): + out[i] = sampling(logits[i]) + return out + + def filter(self, indices): + new_greedy_indices = [] + new_sampling_mapping = {} + for i, idx in enumerate(indices): + if idx in self.sampling_mapping: + new_sampling_mapping[i] = self.sampling_mapping[idx] + else: + new_greedy_indices.append(i) + + self.greedy_indices = new_greedy_indices + self.sampling_mapping = new_sampling_mapping + return self + + +def batch_top_tokens( + top_n_tokens: List[int], + top_n_tokens_tensor: torch.Tensor, + logprobs: torch.Tensor, + accepted_ids: torch.Tensor, +) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: + """Find the top n most likely tokens for a batch of generations. + + When multiple tokens have equal probabilities and they don't all fit, the + remaining tokens are also returned. + """ + max_top_n = max(top_n_tokens) + # Early exit when top_n_tokens is not used + if max_top_n == 0: + return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens) + + batch_size = accepted_ids.shape[0] + speculate_size = logprobs.shape[0] // batch_size + top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size) + # Ensure top_n doesn't exceed vocab size + top_n_tokens = [ + min(tok, logprobs.size(-1)) + for tok in top_n_tokens + for _ in range(speculate_size) + ] + + # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 + # Sorted topk is faster than torch.sort() since we only need a small subset + sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values + + nth_highest = torch.gather( + sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1) + ) + nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min + + # Find the new "fuzzy" top n values + top_n_indices = (logprobs >= nth_highest).nonzero() + _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True) + + k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max() + # Take a new topk for these new max n values + top_k = torch.topk(logprobs, k=k, dim=1, sorted=True) + + top_n_ishes = top_n_ishes.tolist() + top_indices = top_k.indices.tolist() + top_values = top_k.values.tolist() + + batch_top_token_ids = [] + batch_top_token_logprobs = [] + accepted_ids_list = accepted_ids.tolist() + for i, n_accepted_ids in enumerate(accepted_ids_list): + start = speculate_size * i + stop = speculate_size * (i + 1) + _top_indices = top_indices[start:stop] + _top_values = top_values[start:stop] + _top_n_ishes = top_n_ishes[start:stop] + _top_n_tokens = top_n_tokens[start:stop] + + _top_indices = _top_indices[:n_accepted_ids] + _top_values = _top_values[:n_accepted_ids] + _top_n_ishes = _top_n_ishes[:n_accepted_ids] + _top_n_tokens = _top_n_tokens[:n_accepted_ids] + + row_top_token_ids = [] + row_top_token_logprobs = [] + + for idxs, vals, n, req_n in zip( + _top_indices, _top_values, _top_n_ishes, _top_n_tokens + ): + indices = idxs[:n] if req_n > 0 else [] + values = vals[:n] if req_n > 0 else [] + + row_top_token_ids.append(indices) + row_top_token_logprobs.append(values) + + batch_top_token_ids.append(row_top_token_ids) + batch_top_token_logprobs.append(row_top_token_logprobs) + + return batch_top_token_ids, batch_top_token_logprobs + + +def make_tokenizer_optional(tokenizer): + class _(type(tokenizer)): + def __call__( + self, + text, + return_tensors, + padding, + return_token_type_ids, + truncation, + max_length + ): + assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer" + assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer" + assert return_token_type_ids == False, "inccorrect input arguments when calling TransparentTokenizer" + assert truncation == True, "inccorrect input arguments when calling TransparentTokenizer" + + def str_token_to_int(i): + if i == '?': + return tokenizer.pad_token_id + else: + return int(i) + all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')] + for inner_text in text] + if padding == "longest": + max_length = max(len(tokens) for tokens in all_tokens) + return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]), + "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])} + + def decode( + self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + return ','.join(str(i) for i in to_py_obj(token_ids)) + + import os + if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true": + tokenizer.__class__ = _ + tokenizer.is_transparent = True + + +def is_tokenizer_transparent(tokenizer): + return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True diff --git a/backends/gaudi/server/text_generation_server/utils/version.py b/backends/gaudi/server/text_generation_server/utils/version.py new file mode 100644 index 000000000..a72a9ea7b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/version.py @@ -0,0 +1,12 @@ +from optimum.habana.utils import get_driver_version +from packaging.version import Version + +MIN_TGI_GAUDI_SYNAPSE_VERSION=Version("1.16.0") + + +def is_driver_compatible(): + driver_version = get_driver_version() + if driver_version is not None: + if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION: + return False + return True \ No newline at end of file diff --git a/backends/gaudi/server/text_generation_server/utils/watermark.py b/backends/gaudi/server/text_generation_server/utils/watermark.py new file mode 100644 index 000000000..5092b076c --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/watermark.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright 2023 Authors of "A Watermark for Large Language Models" +# available at https://arxiv.org/abs/2301.10226 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch +from transformers import LogitsProcessor +from typing import List, Union + +GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5)) +DELTA = float(os.getenv("WATERMARK_DELTA", 2.0)) + + +class WatermarkLogitsProcessor(LogitsProcessor): + def __init__( + self, + gamma: float = GAMMA, + delta: float = DELTA, + hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width + device: str = "cpu", + ): + # watermarking parameters + self.gamma = gamma + self.delta = delta + self.rng = torch.Generator(device="cpu") + self.hash_key = hash_key + + def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): + if isinstance(input_ids, list): + assert ( + len(input_ids) >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" + prev_token = input_ids[-1] + else: + assert len(input_ids) == 1 + input_ids = input_ids[0] + assert ( + input_ids.shape[-1] >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" + prev_token = input_ids[-1].item() + self.rng.manual_seed(self.hash_key * prev_token) + + def _get_greenlist_ids( + self, + input_ids: Union[List[int], torch.LongTensor], + max_value: int, + device: torch.device, + ) -> List[int]: + # seed the rng using the previous tokens/prefix + self._seed_rng(input_ids) + + greenlist_size = int(max_value * self.gamma) + vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng) + greenlist_ids = vocab_permutation[:greenlist_size] + return greenlist_ids + + @staticmethod + def _calc_greenlist_mask( + scores: torch.FloatTensor, greenlist_token_ids + ) -> torch.BoolTensor: + green_tokens_mask = torch.zeros_like(scores) + green_tokens_mask[-1, greenlist_token_ids] = 1 + final_mask = green_tokens_mask.bool() + return final_mask + + @staticmethod + def _bias_greenlist_logits( + scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float + ) -> torch.Tensor: + scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias + return scores + + def __call__( + self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor + ) -> torch.FloatTensor: + greenlist_ids = self._get_greenlist_ids( + input_ids, scores.shape[-1], scores.device + ) + green_tokens_mask = self._calc_greenlist_mask( + scores=scores, greenlist_token_ids=greenlist_ids + ) + + scores = self._bias_greenlist_logits( + scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta + ) + return scores diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py new file mode 100644 index 000000000..75e01f7ce --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/weights.py @@ -0,0 +1,437 @@ +import torch + +from abc import ABC, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, List, Optional, Union, Type +from safetensors import safe_open +from dataclasses import dataclass + +from text_generation_server.utils.import_utils import SYSTEM + + +class WeightsLoader(ABC): + """ + Instances of this type implement higher-level weight loading. + + At a low-level, every weight is stored in the Safetensors format. + The interpretation of weights may be different however, for instance + could be packed, quantized weights. Loaders are responsible for + interpreting the raw tensors, sharding tensors in a manner compatible + with the format, etc. + """ + + @abstractmethod + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + ... + + @abstractmethod + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + """ + Get the packed weights at the given prefix with column-splitting for + tensor parallelism. This method should be used when multiple different + weights are packed into a tensor, for instance, query/key/value + weights or a gate/up projection. + + The `block_sizes` determines the proportions of the packed tensors. + The columns are split in equally sized blocks when `block_sizes` is an + `int`, or in blocks proportional given to the sizes. For instance + `[2, 1, 1]` will divide an input with dimensionality `1024` in + `[512, 256, 256]`. + """ + ... + + def get_weights_col(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply column-splitting for tensor + paralllism. + """ + return weights.get_multi_weights_col([prefix], 0) + + @abstractmethod + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + """ + Get the weights at the given prefixes, column-split them for tensor + parallelim, and then concatenate the weights along the given dimension. + """ + ... + + @abstractmethod + def get_weights_row(self, weights: "Weights", prefix: str): + """ + Get the weights at the given prefix and apply row-splitting for tensor + parallism. + """ + ... + + +class Weight(ABC): + """Instances of this type implement unquantized/quantized/to-be + quantized weights.""" + + @abstractmethod + def get_linear(self, bias: torch.Tensor): + """Create a linear layer from this weight.""" + ... + + +@dataclass +class UnquantizedWeight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.linear import FastLinear, FastLinearROCm + + if SYSTEM == "rocm": + return FastLinearROCm(self.weight, bias) + else: + return FastLinear(self.weight, bias) + + +class DefaultWeightsLoader(WeightsLoader): + """Weight loader that loads (unquantized) Torch tensors.""" + + def __init__(self, weight_class: Type[UnquantizedWeight]): + """Create a loader. Weights will be wrapped using the given `weights_class`, + normally this will be `UnquantizedWeight`, but a quantizer-specific class + such as `Fp8Weight` can be used to quantize the weights during loading. + """ + self.weight_class = weight_class + + """ + Loader that uses tensors as-is with the exception of applying sharding + and/or concatenation. + """ + + def get_weights(self, weights: "Weights", prefix: str): + return weights.get_tensor(f"{prefix}.weight") + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + return self.weight_class( + weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ), + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + return self.weight_class(torch.cat(w, dim=dim)) + + def get_weights_row(self, weights: "Weights", prefix: str): + return self.weight_class( + weights.get_sharded(f"{prefix}.weight", dim=1), + ) + + +class Weights: + def __init__( + self, + filenames: List[Path], + device, + dtype, + process_group, + weights_loader: WeightsLoader, + aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None, + ): + routing = {} + for filename in filenames: + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self.prefix = prefix + self.weights_loader = weights_loader + self._handles = {} + + def _get_handle(self, filename): + if filename not in self._handles: + f = safe_open(filename, framework="pytorch") + self._handles[filename] = f + + return self._handles[filename] + + def get_filename(self, tensor_name: str) -> (str, str): + names = [tensor_name] + if self.prefix is not None: + prefixed = f"{self.prefix}.{tensor_name}" + names.append(prefixed) + for name in names: + filename = self.routing.get(name, None) + if filename is not None: + return str(filename), name + + aliases = self.aliases.get(name, []) + for alias in aliases: + filename = self.routing.get(alias, None) + if filename is not None: + return str(filename), alias + raise RuntimeError(f"weight {tensor_name} does not exist") + + def _get_slice(self, tensor_name: str): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + return slice_ + + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + + def get_shape(self, tensor_name: str): + return self._get_slice(tensor_name).get_shape() + + def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + tensor = f.get_tensor(tensor_name) + # Special case for gptq which shouldn't convert + # u4 which are disguised as int32. Exl2 uses int16 + # as well. FP8 uses torch.float8_e4m3fn + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): + tensor = tensor.to(dtype=self.dtype) + if to_device: + tensor = tensor.to(device=self.device) + return tensor + + def get_partial_sharded( + self, tensor_name: str, dim: int, to_device=True, to_dtype=True + ): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + rank = self.process_group.rank() + + size = slice_.get_shape()[dim] + block_size = (size + world_size - 1) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + if dim == 0: + tensor = slice_[start:stop] + elif dim == 1: + tensor = slice_[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + # Special case for gptq which shouldn't convert + # u4 which are disguised as int32. exl2 uses int16. + # FP8 uses torch.float8_e4m3fn. + if ( + tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + and to_dtype + ): + tensor = tensor.to(dtype=self.dtype) + if to_device: + tensor = tensor.to(device=self.device) + return tensor + + def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + size = slice_.get_shape()[dim] + assert ( + size % world_size == 0 + ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" + return self.get_partial_sharded( + tensor_name, dim, to_device=to_device, to_dtype=to_dtype + ) + + def get_packed_sharded( + self, + tensor_name: str, + dim: int, + block_sizes: Union[int, List[int]], + to_dtype=True, + ) -> torch.Tensor: + """ + Get a shard from a tensor that packs multiple tensors. + + When a tensor packs multiple tensors (such as QKV or an up + projection + gate projection), sharding with `get_sharded` is not + safe since it would not split the packed tensors across shards. + + This method shards a tensor, such that the packed tensors are + split across shards. + + The columns are split in equally sized blocks when blocks is an `int`, or + in blocks proportional given to the sizes. For instance `[2, 1, 1]` will + divide an input with dimensionality `1024` in `[512, 256, 256]`. This is + convenient for e.g. splitting QKV without knowing the storage details of + quantized weights. + """ + slice_ = self._get_slice(tensor_name) + total_size = slice_.get_shape()[dim] + block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes) + + world_size = self.process_group.size() + rank = self.process_group.rank() + + tensors = [] + block_offset = 0 + for block_size in block_sizes: + assert ( + block_size % world_size == 0 + ), f"Prepacked tensor cannot be sharded across {world_size} shards" + shard_block_size = block_size // world_size + start = rank * shard_block_size + stop = (rank + 1) * shard_block_size + if dim == 0: + tensor = slice_[block_offset + start : block_offset + stop] + elif dim == 1: + tensor = slice_[:, block_offset + start : block_offset + stop] + else: + raise NotImplementedError("Currently only dim=0 or dim=1 is supported") + tensors.append(tensor) + block_offset += block_size + tensor = torch.cat(tensors, dim=dim) + tensor = tensor.to(device=self.device) + + # Avoid casting quantizer dtypes. + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): + tensor = tensor.to(dtype=self.dtype) + + return tensor + + def get_weights(self, prefix: str): + return self.weights_loader.get_weights(self, prefix) + + def get_weights_col_packed_qkv( + self, + prefix: str, + num_heads: int, + num_key_value_heads: int, + ): + return self.get_weights_col_packed( + prefix, [num_heads, num_key_value_heads, num_key_value_heads] + ) + + def get_weights_col_packed_gate_up(self, prefix: str): + return self.get_weights_col_packed(prefix, 2) + + def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]): + """ + The columns are split in equally sized blocks when blocks is an `int`, or + in blocks proportional given to the sizes. For instance `[2, 1, 1]` will + divide an input with dimensionality `1024` in `[512, 256, 256]`. This is + convenient for e.g. splitting QKV without knowing the storage details of + quantized weights. + """ + return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes) + + def get_weights_col(self, prefix: str): + return self.weights_loader.get_weights_col(self, prefix) + + def get_multi_weights_col(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights_col(self, prefixes, dim) + + def get_tensor_shard(self, var, dim): + world_size = self.process_group.size() + rank = self.process_group.rank() + block_size = var.size()[dim] // world_size + start = rank * block_size + stop = (rank + 1) * block_size + if dim == 0: + tensor = var[start:stop] + elif dim == 1: + tensor = var[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor + + def get_weights_row(self, prefix: str): + return self.weights_loader.get_weights_row(self, prefix) + + @contextmanager + def use_loader(self, weights_loader: WeightsLoader): + """ + This method is a context manager that can be used to use `Weights` with + a different loader for the duration of the context. + """ + + old_loader = self.weights_loader + self.weights_loader = weights_loader + try: + yield + finally: + self.weights_loader = old_loader + + @property + def loader(self): + return self.weights_loader + + +def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: + """ + Convert block count or proportions to block sizes. + + This function accepts + + - The number of blocks (int), in which case the block size is + total_size//blocks; or + - A list of block sizes (List[int]). + + In the latter case, if sum(blocks) < total_size, the ratios between + the block sizes will be preserved. For instance, if blocks is + [2, 1, 1] and total_size is 1024, the returned block sizes are + [512, 256, 256]. + """ + if isinstance(blocks, list): + total_blocks = sum(blocks) + assert ( + total_size % total_blocks == 0 + ), f"Cannot split {total_size} in proportional blocks: {blocks}" + part_size = total_size // total_blocks + return [part_size * block for block in blocks] + else: + assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}" + single_size = total_size // blocks + return [single_size] * blocks