Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

Merge branch 'main' into ci_amd3

Commit: 227f78f3fe
.github/workflows/build.yaml (vendored, 4 changes)

@@ -204,6 +204,8 @@ jobs:
    needs: [build-and-push, prepare_integration_tests]
    runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
+   env:
+     PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
@@ -250,4 +252,4 @@ jobs:
          echo "DOCKER_VOLUME:"
          echo $DOCKER_VOLUME

-         pytest -s -vvvvv integration-tests
+         pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}

.github/workflows/client-tests.yaml (vendored, 2 changes)

@@ -22,5 +22,5 @@ jobs:
      - name: Run tests
        run: |
          pip install pytest pytest-asyncio
-         export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+         export HF_TOKEN=${{ secrets.HF_TOKEN }}
          make python-client-tests

.github/workflows/integration_tests.yaml (vendored, 2 changes)

@@ -37,5 +37,5 @@ jobs:
          export DOCKER_VOLUME=/mnt/cache
          export DOCKER_IMAGE=${{ inputs.docker_image }}
          export DOCKER_DEVICES=${{ inputs.docker_devices }}
-         export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+         export HF_TOKEN=${{ secrets.HF_TOKEN }}
          pytest -s -vv integration-tests

.github/workflows/load_test.yaml (vendored, 2 changes)

@@ -28,7 +28,7 @@ jobs:

      - name: Start starcoder
        run: |
-         docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
+         docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
          sleep 10
          wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health

.github/workflows/tests.yaml (vendored, 2 changes)

@@ -72,7 +72,7 @@ jobs:
      - name: Run server tests
        run: |
          pip install pytest
-         export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+         export HF_TOKEN=${{ secrets.HF_TOKEN }}
          pytest -s -vv server/tests
      - name: Pre-commit checks
        run: |

Dockerfile (10 changes)

@@ -145,6 +145,13 @@ COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

+# Build Lorax Punica kernels
+FROM kernel-builder as lorax-punica-builder
+WORKDIR /usr/src
+COPY server/Makefile-lorax-punica Makefile
+# Build specific version of transformers
+RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
+
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
@@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+# CMD ["--json-output"]

@@ -1,3 +1,5 @@
+ARG PLATFORM=xpu
+
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src

@@ -37,7 +39,8 @@ RUN cargo build --profile release-opt


# Text Generation Inference base image for Intel
-FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base
+FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu
+
USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
@@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \

WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed

# Install server
COPY proto proto
@@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

-# Final image
-FROM base
+
+# Text Generation Inference base image for Intel-cpu
+FROM ubuntu:22.04 as cpu
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    ca-certificates \
+    make \
+    g++ \
+    git \
+    wget \
+    cmake
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    PORT=80
+
+ARG MAMBA_VERSION=23.1.0-1
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+    *) MAMBA_ARCH=x86_64 ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    rm ~/mambaforge.sh
+
+RUN conda install -c conda-forge gperftools mkl
+
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+
+WORKDIR /usr/src
+
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
+
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
+
+RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
+
+RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
+
+ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
+ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
+ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
+ENV KMP_BLOCKTIME=1
+ENV KMP_TPAUSE=0
+ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
+ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
+ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
+
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_intel.txt && \
+    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+
+# Install benchmarker
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
+
+FROM ${PLATFORM} as final
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

@@ -153,7 +153,7 @@ this will impact performance.
### Distributed Tracing

`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
overridden with the `--otlp-service-name` argument

### Architecture

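As a rough illustration of the two flags this section documents (the collector address, service name, and model id below are placeholders, not taken from the diff):

```bash
# Hypothetical launch that exports traces to a local OTLP collector under a custom service name
text-generation-launcher --model-id bigcode/starcoder \
    --otlp-endpoint http://localhost:4317 \
    --otlp-service-name tgi-staging
```
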
@@ -157,6 +157,7 @@ async fn prefill(
                top_n_tokens: top_n_tokens.unwrap_or(0),
                blocks: vec![],
                slots: vec![],
+               adapter_id: None,
            })
            .collect();

@@ -147,7 +147,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing::info!("Downloading tokenizer");

    // Parse Huggingface hub token
-   let auth_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
+   let auth_token = std::env::var("HF_TOKEN")
+       .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+       .ok();

    // Download and instantiate tokenizer
    // We need to download it outside of the Tokio runtime

@@ -455,6 +455,6 @@ class DeployedModel(BaseModel):
    # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
    # with model_ prefixes, since this disables guardrails for colliding fields:
    # https://github.com/pydantic/pydantic/issues/9177
    model_config = ConfigDict(protected_namespaces=())
    model_id: str
    sha: str

@@ -60,6 +60,9 @@
  - local: conceptual/speculation
    title: Speculation (Medusa, ngram)
  - local: conceptual/guidance
-   title: How Guidance Works (via outlines)
+   title: How Guidance Works (via outlines
+ - local: conceptual/lora
+   title: LoRA (Low-Rank Adaptation)
+
  title: Conceptual Guides

@@ -416,6 +416,14 @@ Options:
          [env: MAX_CLIENT_BATCH_SIZE=]
          [default: 4]

+```
+## LORA_ADAPTERS
+```shell
+      --lora-adapters <LORA_ADAPTERS>
+          Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request
+
+          [env: LORA_ADAPTERS=]
+
```
## HELP
```shell

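For context, a minimal sketch of the new flag at startup; the base model id is an illustrative assumption, and the adapter ids are the ones used in the LoRA guide added below:

```bash
# Load two LoRA adapters alongside the base model (flag form of the LORA_ADAPTERS env var)
text-generation-launcher --model-id mistralai/Mistral-7B-v0.1 \
    --lora-adapters predibase/customer_support,predibase/dbpedia
```
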
docs/source/conceptual/lora.md (new file, 65 lines)

@@ -0,0 +1,65 @@
# LoRA (Low-Rank Adaptation)

## What is LoRA?

LoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.

LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.

## How is it used?

LoRA can be used in many ways and the community is always finding new ways to use it. Here are some examples of how you can use LoRA:

Technically, LoRA can be used to fine-tune a large language model on a small dataset. However, these use cases can span a wide range of applications, such as:

- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels

## Optimizing Inference with LoRA

LoRAs can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but due to awesome work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with multiple LoRA models.

## Serving multiple LoRA adapters with TGI

Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.

In practice it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.

Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.

### Specifying LoRA models

To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:

```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```

In the server logs, you will see the following message:

```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```

## Generate text

You can then use these models in generation requests by specifying the `adapter_id` parameter in the request payload. For example:

```json
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
  "inputs": "Hello who are you?",
  "parameters": {
    "max_new_tokens": 40,
    "adapter_id": "predibase/customer_support"
  }
}'
```

> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.

An updated tutorial with detailed examples will be published soon. Stay tuned!

@@ -43,6 +43,26 @@ if SYSTEM is None:
    )


+def pytest_addoption(parser):
+    parser.addoption(
+        "--release", action="store_true", default=False, help="run release tests"
+    )
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "release: mark test as a release-only test")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--release"):
+        # --release given in cli: do not skip release tests
+        return
+    skip_release = pytest.mark.skip(reason="need --release option to run")
+    for item in items:
+        if "release" in item.keywords:
+            item.add_marker(skip_release)
+
+
class ResponseComparator(JSONSnapshotExtension):
    rtol = 0.2
    ignore_logprob = False

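With these hooks, tests marked `@pytest.mark.release` are skipped unless the new flag is passed, which is what the `PYTEST_FLAGS` value in `build.yaml` provides on `main`. A quick sketch of the two invocations:

```bash
pytest -s -vv integration-tests            # release-marked tests are skipped
pytest -s -vv integration-tests --release  # release-marked tests run (CI on main)
```
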
@@ -15,6 +15,7 @@ async def bloom_560(bloom_560_handle):
    return bloom_560_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m(bloom_560, response_snapshot):
@@ -31,6 +32,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_all_params(bloom_560, response_snapshot):
@@ -55,6 +57,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):

@@ -15,6 +15,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
    return bloom_560m_sharded_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
@@ -31,6 +32,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_sharded_load(
    bloom_560m_sharded, generate_load, response_snapshot

@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly.


+@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
    assert response == response_snapshot


+@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
    assert response == response_snapshot


+@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):

@@ -27,6 +27,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
    return flash_llama_awq_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
    response = await flash_llama_awq.generate(
@@ -41,6 +42,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
    response = await flash_llama_awq.generate(
@@ -62,6 +64,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
    responses = await generate_load(

@@ -26,6 +26,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):


@is_flaky_async(max_attempts=5)
+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda", "rocm")
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
@@ -47,6 +48,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho


@require_backend_async("cuda")
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load_sharded(
    flash_llama_awq_sharded, generate_load, response_snapshot
@@ -57,6 +59,7 @@ async def test_flash_llama_awq_load_sharded(
        flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4
    )

+    assert len(responses) == 4
    assert all(
        [
            r.generated_text

@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
    return flash_falcon_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):

@@ -17,6 +17,7 @@ async def flash_gemma(flash_gemma_handle):
    return flash_gemma_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -29,6 +30,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -53,6 +55,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")

@@ -17,6 +17,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
    return flash_gemma_gptq_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -31,6 +32,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
    assert response == ignore_logprob_response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -57,6 +59,7 @@ async def test_flash_gemma_gptq_all_params(
    assert response == ignore_logprob_response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")

@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
    return flash_gpt2_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
    response = await flash_gpt2.generate(
@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
    responses = await generate_load(

@@ -23,6 +23,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):


@require_backend_async("cuda")
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -35,6 +36,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh


@require_backend_async("cuda")
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
@@ -62,6 +64,7 @@ async def test_flash_llama_exl2_all_params(


@require_backend_async("cuda")
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(

@@ -15,6 +15,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
    return flash_llama_gptq_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@is_flaky_async(max_attempts=5)
@@ -31,6 +32,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda")
@@ -55,6 +57,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda")

@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
    return flash_llama_gptq_marlin_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(

@@ -20,6 +20,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
    return flash_llama_marlin_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
@@ -31,6 +32,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
@@ -53,6 +55,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(

@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
    return flash_neox_handle.client


+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
    return flash_neox_sharded_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox(flash_neox_sharded, response_snapshot):
    response = await flash_neox_sharded.generate(
@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
    responses = await generate_load(

@@ -38,6 +38,7 @@ def get_cow_beach():
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -50,6 +51,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):

@@ -17,6 +17,7 @@ async def flash_phi(flash_phi_handle):
    return flash_phi_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi(flash_phi, response_snapshot):
@@ -29,6 +30,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi_all_params(flash_phi, response_snapshot):
@@ -53,6 +55,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
    return flash_qwen2_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2(flash_qwen2, response_snapshot):
    response = await flash_qwen2.generate(
@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
    response = await flash_qwen2.generate(
@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
    responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)

@@ -15,6 +15,7 @@ async def flash_santacoder(flash_santacoder_handle):
    return flash_santacoder_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda", "xpu")
async def test_flash_santacoder(flash_santacoder, response_snapshot):
@@ -27,6 +28,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_santacoder_load(
    flash_santacoder, generate_load, response_snapshot

@@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle):
    return flash_starcoder_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder(flash_starcoder, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle):
    return flash_starcoder2_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(

@@ -15,6 +15,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
    return flash_starcoder_gptq_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@is_flaky_async(max_attempts=10)
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
@@ -33,6 +34,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
    assert response == generous_response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@is_flaky_async(max_attempts=10)
async def test_flash_starcoder_gptq_default_params(
@@ -55,6 +57,7 @@ async def test_flash_starcoder_gptq_default_params(
    assert response == generous_response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_starcoder_gptq_load(

@@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
    return non_flash_llama_grammar_handle.client


+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):

@@ -22,6 +22,7 @@ async def llama_grammar(llama_grammar_handle):
    return llama_grammar_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):

@@ -62,6 +63,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
    assert chat_completion == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,

@@ -45,6 +45,7 @@ async def test_idefics(idefics, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
@@ -60,6 +61,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot):
    chicken = get_chicken()

@@ -28,6 +28,7 @@ async def flash_llava_next(flash_llava_next_handle):
    return flash_llava_next_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
@@ -43,6 +44,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
@@ -66,6 +68,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(

@@ -15,6 +15,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
    return fused_kernel_mamba_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba(fused_kernel_mamba, response_snapshot):
@@ -27,6 +28,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@@ -54,6 +56,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba_load(

@@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle):
    return mpt_sharded_handle.client


+@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
    response = await mpt_sharded.generate(
@@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot):
    assert response == response_snapshot


+@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
    responses = await generate_load(

@ -14,6 +14,7 @@ async def mt0_base(mt0_base_handle):
|
|||||||
return mt0_base_handle.client
|
return mt0_base_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mt0_base(mt0_base, response_snapshot):
|
async def test_mt0_base(mt0_base, response_snapshot):
|
||||||
response = await mt0_base.generate(
|
response = await mt0_base.generate(
|
||||||
@ -28,6 +29,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
||||||
response = await mt0_base.generate(
|
response = await mt0_base.generate(
|
||||||
@ -50,6 +52,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
|
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
|
@ -15,6 +15,7 @@ async def neox(neox_handle):
|
|||||||
return neox_handle.client
|
return neox_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox(neox, response_snapshot):
|
async def test_neox(neox, response_snapshot):
|
||||||
@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox_load(neox, generate_load, response_snapshot):
|
async def test_neox_load(neox, generate_load, response_snapshot):
|
||||||
|
@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle):
|
|||||||
return neox_sharded_handle.client
|
return neox_sharded_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox(neox_sharded, response_snapshot):
|
async def test_neox(neox_sharded, response_snapshot):
|
||||||
@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
|
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
|
||||||
|
@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle):
|
|||||||
return t5_sharded_handle.client
|
return t5_sharded_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_t5_sharded(t5_sharded, response_snapshot):
|
async def test_t5_sharded(t5_sharded, response_snapshot):
|
||||||
response = await t5_sharded.generate(
|
response = await t5_sharded.generate(
|
||||||
@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
|
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
|
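Note: the `@pytest.mark.release` marker added across these integration tests pairs with the `--release` flag that CI forwards through `PYTEST_FLAGS`. The gating hook itself is not part of this diff; the Python sketch below is only an assumed illustration of how such a marker is commonly wired up in a `conftest.py`, not the repository's actual conftest.

# Hypothetical conftest.py sketch: skip tests marked `release` unless --release is passed.
import pytest


def pytest_addoption(parser):
    # Assumed option name, mirroring the `--release` flag passed via PYTEST_FLAGS in CI.
    parser.addoption(
        "--release", action="store_true", default=False, help="run release-only tests"
    )


def pytest_collection_modifyitems(config, items):
    if config.getoption("--release"):
        return
    skip_release = pytest.mark.skip(reason="needs --release to run")
    for item in items:
        if "release" in item.keywords:
            item.add_marker(skip_release)

With a hook like this, a plain `pytest integration-tests` run skips the release-only tests, while `pytest integration-tests --release` runs them.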
@@ -452,6 +452,11 @@ struct Args {
     /// Control the maximum number of inputs that a client can send in a single request
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
+
+    /// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during
+    /// startup that will be available to callers via the `adapter_id` field in a request.
+    #[clap(long, env)]
+    lora_adapters: Option<String>,
 }

 #[derive(Debug)]
@@ -485,6 +490,7 @@ fn shard_manager(
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
     max_input_tokens: usize,
+    lora_adapters: Option<String>,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
     log_level: LevelFilter,
@@ -620,6 +626,11 @@ fn shard_manager(
         envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
     }

+    // Lora Adapters
+    if let Some(lora_adapters) = lora_adapters {
+        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -762,7 +773,7 @@ fn num_cuda_devices() -> Option<usize> {
         Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
             Ok(devices) => devices,
             Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
-        }
+        },
     };
     let n_devices = devices.split(',').count();
     Some(n_devices)
@@ -1060,6 +1071,7 @@ fn spawn_shards(
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
     let max_batch_size = args.max_batch_size;
+    let lora_adapters = args.lora_adapters.clone();
     thread::spawn(move || {
         shard_manager(
             model_id,
@@ -1085,6 +1097,7 @@ fn spawn_shards(
             max_total_tokens,
             max_batch_size,
             max_input_tokens,
+            lora_adapters,
             otlp_endpoint,
             otlp_service_name,
             max_log_level,
@@ -1225,7 +1238,6 @@ fn spawn_webserver(
     router_args.push("--otlp-service-name".to_string());
     router_args.push(otlp_service_name);

-
     // CORS origins
     for origin in args.cors_allow_origin.into_iter() {
         router_args.push("--cors-allow-origin".to_string());
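For context: the launcher changes above add a `--lora-adapters` flag (forwarded to the shards as the `LORA_ADAPTERS` environment variable) that takes a comma-separated list of adapter ids such as `repo/adapter1,repo/adapter2`. A minimal launch sketch follows; the binary name, model id and adapter ids are assumed placeholders for illustration, not values taken from this diff.

# Hypothetical sketch: start the launcher with two LoRA adapters preloaded at startup.
import subprocess

cmd = [
    "text-generation-launcher",                       # assumed launcher binary on PATH
    "--model-id", "mistralai/Mistral-7B-v0.1",        # placeholder base model
    "--lora-adapters", "repo/adapter1,repo/adapter2", # format from the new CLI doc comment
]
subprocess.run(cmd, check=True)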
@@ -134,6 +134,8 @@ message Request {
     repeated uint32 blocks = 9;
     /// Paged attention slots
     repeated uint32 slots = 10;
+    /// LORA adapter index
+    optional string adapter_id = 11;
 }

 message Batch {
@@ -177,6 +177,7 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
+                adapter_id: None,
             });
             n_tokens += max_input_length;

@@ -244,6 +244,7 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
+            adapter_id: None,
         };
         let batch = Batch {
             id: u64::MAX,

@@ -429,6 +429,7 @@ mod tests {
                     stop_sequences: vec![],
                 },
                 top_n_tokens: 0,
+                adapter_id: None,
             },
             response_tx,
             span: info_span!("entry"),

@@ -351,6 +351,7 @@ impl State {
                 top_n_tokens: entry.request.top_n_tokens,
                 blocks,
                 slots,
+                adapter_id: entry.request.adapter_id.clone(),
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -491,6 +492,7 @@ mod tests {
                     stop_sequences: vec![],
                 },
                 top_n_tokens: 0,
+                adapter_id: None,
             },
             response_tx,
             span: info_span!("entry"),
@@ -1,15 +1,15 @@
+use crate::infer::Infer;
 use crate::{
     default_parameters,
     server::{generate_internal, ComputeType},
-    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
+    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Serialize, ToSchema,
 };
 use axum::extract::{Extension, Path};
-use axum::response::{IntoResponse, Response};
+use axum::http::{HeaderMap, StatusCode};
+use axum::response::IntoResponse;
 use axum::Json;
 use futures::stream::FuturesUnordered;
 use futures::TryStreamExt;
-use reqwest::header::HeaderMap;
-use reqwest::StatusCode;

 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct OutputChunk {
@@ -64,8 +64,6 @@ pub struct MetadataServerResponse {
     pub extensions: Vec<String>,
 }

-// Routes
-
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
@@ -76,13 +74,13 @@ pub struct MetadataServerResponse {
         example = json!({"error": "No response"}))
     )
 )]
-pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+pub async fn kserve_health_live() -> Json<LiveResponse> {
     let data = LiveResponse { live: true };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
 }

 #[utoipa::path(
-    post,
+    get,
     tag = "Text Generation Inference",
     path = "/v2/health/ready",
     responses(
@@ -91,9 +89,9 @@ pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorRes
         example = json!({"error": "No response"}))
     )
 )]
-pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+pub async fn kserve_health_ready() -> Json<ReadyResponse> {
     let data = ReadyResponse { live: true };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
 }

 #[utoipa::path(
@@ -106,7 +104,7 @@ pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorRe
         example = json!({"error": "No response"}))
     )
 )]
-pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+pub async fn kerve_server_metadata() -> Json<MetadataServerResponse> {
     let data = MetadataServerResponse {
         name: "text-generation-inference".to_string(),
         version: env!("CARGO_PKG_VERSION").to_string(),
@@ -116,7 +114,7 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<Error
             "metrics".to_string(),
         ],
     };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
 }

 #[utoipa::path(
@@ -131,13 +129,30 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<Error
 )]
 pub async fn kserve_model_metadata(
     Path((model_name, model_version)): Path<(String, String)>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Json<MetadataServerResponse> {
     let data = MetadataServerResponse {
         name: model_name,
         version: model_version,
         extensions: vec!["infer".to_string(), "ready".to_string()],
     };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/ready",
+    responses(
+        (status = 200, description = "Model version is ready", body = ReadyResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_metadata_ready(
+    Path((_model_name, _model_version)): Path<(String, String)>,
+) -> Json<ReadyResponse> {
+    let data = ReadyResponse { live: true };
+    Json(data)
 }

 #[utoipa::path(
@@ -155,7 +170,7 @@ pub async fn kserve_model_infer(
     infer: Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Json(payload): Json<InferenceRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let id = payload.id.clone();
     let str_inputs = payload
         .inputs
@@ -226,22 +241,5 @@ pub async fn kserve_model_infer(
         outputs: output_chunks,
     };

-    Ok((HeaderMap::new(), Json(inference_output)).into_response())
-}
-
-#[utoipa::path(
-    get,
-    tag = "Text Generation Inference",
-    path = "/v2/models/{model_name}/versions/{model_version}/ready",
-    responses(
-        (status = 200, description = "Model version is ready", body = ReadyResponse),
-        (status = 404, description = "Model or version not found", body = ErrorResponse,
-        example = json!({"error": "No response"}))
-    )
-)]
-pub async fn kserve_model_metadata_ready(
-    Path((_model_name, _model_version)): Path<(String, String)>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let data = ReadyResponse { live: true };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Ok((HeaderMap::new(), Json(inference_output)))
 }
@@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub grammar: Option<GrammarType>,
+
+    /// Lora adapter id
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub adapter_id: Option<String>,
 }

 fn default_max_new_tokens() -> Option<u32> {
@@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters {
         seed: None,
         top_n_tokens: None,
         grammar: None,
+        adapter_id: None,
     }
 }

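With `adapter_id` now part of `GenerateParameters`, a caller can select one of the preloaded LoRA adapters per request. The sketch below assumes a TGI instance listening on `localhost:8080` and an adapter id that was passed to `--lora-adapters`; both are illustrative placeholders, not values from this diff.

# Hypothetical sketch: select a LoRA adapter per request via the new `adapter_id` parameter.
import requests

resp = requests.post(
    "http://localhost:8080/generate",  # assumed local TGI endpoint
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 32, "adapter_id": "repo/adapter1"},
    },
    timeout=60,
)
print(resp.json())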
@@ -159,7 +159,9 @@ async fn main() -> Result<(), RouterError> {
     });

     // Parse Huggingface hub token
-    let authorization_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
+    let authorization_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();

     // Tokenizer instance
     // This will only be used to validate payloads

@@ -673,6 +673,7 @@ async fn completions(
                     seed,
                     top_n_tokens: None,
                     grammar: None,
+                    ..Default::default()
                 },
             })
             .collect();
@@ -1115,6 +1116,7 @@ async fn chat_completions(
             seed,
             top_n_tokens: req.top_logprobs,
             grammar,
+            ..Default::default()
         },
     };

@@ -1764,12 +1766,12 @@ pub async fn run(
         #[derive(OpenApi)]
         #[openapi(
             paths(
-                kserve_model_infer,
                 kserve_health_live,
                 kserve_health_ready,
                 kerve_server_metadata,
                 kserve_model_metadata,
                 kserve_model_metadata_ready,
+                kserve_model_infer,
             ),
             components(schemas(
                 InferenceOutput,
@@ -202,6 +202,7 @@ impl Validation {
             decoder_input_details,
             top_n_tokens,
             grammar,
+            adapter_id,
             ..
         } = request.parameters;

@@ -383,6 +384,7 @@ impl Validation {
             parameters,
             stopping_parameters,
             top_n_tokens,
+            adapter_id,
         })
     }

@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
     pub parameters: ValidParameters,
     pub stopping_parameters: ValidStoppingParameters,
     pub top_n_tokens: u32,
+    pub adapter_id: Option<String>,
 }

 #[derive(Error, Debug)]
@@ -4,6 +4,7 @@ include Makefile-vllm
 include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
+include Makefile-lorax-punica

 unit-tests:
 	pytest -s -vv -m "not private" tests

server/Makefile-lorax-punica (Normal file, 12 lines)
@@ -0,0 +1,12 @@
+lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
+
+build-lorax-punica:
+	if [ ! -d 'lorax-punica' ]; then \
+		git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
+	fi
+	cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
+	cd lorax-punica && git submodule update --init --recursive
+	cd lorax-punica/server/punica_kernels && python setup.py build
+
+install-lorax-punica: build-lorax-punica
+	cd lorax-punica/server/punica_kernels && python setup.py install
@@ -19,6 +19,23 @@ def gptq_marlin_gemm(
     """
     ...

+def gptq_marlin_24_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_meta: torch.Tensor,
+    b_scales: torch.Tensor,
+    workspace: torch.Tensor,
+    num_bits: int,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+) -> torch.Tensor:
+    """
+    Matrix multiplication using Marlin kernels. This is an extension of
+    `marlin_gemm` that supports 2:4 sparsity.
+    """
+    ...
+
 def gptq_marlin_repack(
     b_q_weight: torch.Tensor,
     perm: torch.Tensor,

@@ -5,6 +5,7 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
         "Marlin gemm with GPTQ compatibility");
+  m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
   m.def("gptq_marlin_repack", &gptq_marlin_repack,
         "Repack GPTQ parameters for Marlin");
   m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");

@@ -12,6 +12,13 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                int64_t num_bits, int64_t size_m, int64_t size_n,
                                int64_t size_k, bool is_k_full);

+torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
+                                  torch::Tensor &b_meta,
+                                  torch::Tensor &b_scales,
+                                  torch::Tensor &workspace, int64_t num_bits,
+                                  int64_t size_m, int64_t size_n,
+                                  int64_t size_k);
+
 torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
server/marlin/marlin_kernels/sparse/common/base.h (Normal file, 51 lines)
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
+ * Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace marlin_24 {
+
+constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed
+// for instance as inputs to tensor core operations. Consequently, all
+// corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee
+// this.
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) { return elems[i]; }
+};
+
+template <int M_, int N_, int K_>
+struct ShapeBase {
+  static constexpr int M = M_, N = N_, K = K_;
+};
+
+using I4 = Vec<int, 4>;
+
+// Matrix fragments for tensor core instructions; their precise layout is
+// documented here:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+using FragA = Vec<half2, 4>;
+using FragB = Vec<half2, 2>;
+using FragM = Vec<uint, 1>;
+using FragC = Vec<float, 4>;
+using FragS = Vec<half2, 1>;  // quantization scales
+
+} // namespace marlin_24

server/marlin/marlin_kernels/sparse/common/mem.h (Normal file, 136 lines)
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
+ * Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "base.h"
+
+namespace marlin_24 {
+// Predicated asynchronous global->shared copy; used for inputs A where we apply
+// predication to handle batchsizes that are not multiples of 16.
+__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
+                                            const void* glob_ptr,
+                                            bool pred = true,
+                                            const bool zfill = false) {
+  const int BYTES = 16;
+  int src_in_bytes = (zfill ? 0 : BYTES);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
+}
+
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
+                                      bool pred = true) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES));
+}
+
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem),
+      "l"(glob_ptr), "n"(BYTES));
+}
+
+// Async copy fence.
+__device__ inline void cp_async_fence() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+// Wait until at most `n` async copy stages are still pending.
+template <int n>
+__device__ inline void cp_async_wait() {
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared
+// memory, directly in tensor core layout.
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+               : "r"(smem));
+}
+
+__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
+               : "=r"(a[0]), "=r"(a[1])
+               : "r"(smem));
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared
+// memory, directly in tensor core layout.
+__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+      : "r"(smem));
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(state)
+                   : "l"(lock));
+    while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
+    asm volatile("fence.acq_rel.gpu;\n");
+    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+                 :
+                 : "l"(lock), "r"(val));
+  }
+}
+} // namespace marlin_24

server/marlin/marlin_kernels/sparse/common/mma.h (Normal file, 191 lines)
@@ -0,0 +1,191 @@
+/*
+ * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
+ * Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "base.h"
+#include <cudaTypedefs.h>
+
+namespace marlin_24 {
+
+// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
+// is not supported. On later versions of CUDA the version without ordered
+// metadata results in the following warning:
+// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
+// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
+// | reduced performance on some future architectures
+#if defined CUDA_VERSION && CUDA_VERSION >= 12050
+  #define MMA_SP_INST \
+    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+#else
+  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+#endif
+
+// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
+// output/accumulation.
+__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
+                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
+                              const int psel) {
+  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
+  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
+
+  float* c = reinterpret_cast<float*>(&frag_c);
+  if (psel == 0) {
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x0;\n"
+                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
+                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
+                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x0;\n"
+                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
+                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
+                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
+  } else {
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x1;\n"
+                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
+                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
+                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x1;\n"
+                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
+                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
+                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
+  }
+}
+
+// Lookup-table based 3-input logical operation; explicitly used for
+// dequantization as the compiler does not seem to automatically recognize it in
+// all cases.
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
+                                          float c3) {
+  uint2 r;
+  asm("{\n\t"
+      ".reg .f16 a, b, c, d; \n\t"
+      "cvt.rn.f16.f32 a, %2; \n\t"
+      "cvt.rn.f16.f32 b, %3; \n\t"
+      "cvt.rn.f16.f32 c, %4; \n\t"
+      "cvt.rn.f16.f32 d, %5; \n\t"
+      "mov.b32 %0, {a, b}; \n\t"
+      "mov.b32 %1, {c, d}; \n\t"
+      "}"
+      : "=r"(r.x), "=r"(r.y)
+      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
+  return r;
+}
+
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
+// values. We mostly follow the strategy in the link below, with some small
+// changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+__device__ inline FragB dequant_4bit(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+  // directly into `SUB` and `ADD`.
+  const int SUB = 0x64086408;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd480d480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
+// values. We mostly follow the strategy in the link below, with some small
+// changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+__device__ inline FragB dequant_8bit(int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
+                                    FragS& s0, float* c4, float* c5, float* c6,
+                                    float* c7, FragS& s1) {
+  *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
+  *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
+  *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
+  *c3 = __fmul_rn(*c3, __half2float(s0[1].y));
+
+  *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
+  *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
+  *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
+  *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
+}
+
+} // namespace marlin_24
server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu (Normal file, 1125 lines)
File diff suppressed because it is too large. Load Diff.

@@ -12,6 +12,7 @@ setup(
         "marlin_kernels/gptq_marlin.cu",
         "marlin_kernels/gptq_marlin_repack.cu",
         "marlin_kernels/marlin_cuda_kernel.cu",
+        "marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
         "marlin_kernels/ext.cpp",
     ],
     extra_compile_args=extra_compile_args,

@@ -17,7 +17,12 @@ def get_test_model():
     tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

     model = TestModel(
-        torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
+        "test_model_id",
+        torch.nn.Linear(1, 1),
+        tokenizer,
+        False,
+        torch.float32,
+        torch.device("cpu"),
     )
     return model

server/tests/utils/test_weights.py (Normal file, 1152 lines)
File diff suppressed because it is too large. Load Diff.

server/text_generation_server/adapters/__init__.py (Normal file, 13 lines)
@@ -0,0 +1,13 @@
+# Origin:   https://github.com/predibase/lorax
+# Path:     lorax/server/lorax_server/adapters/__init__.py
+# License:  Apache License Version 2.0, January 2004
+
+from text_generation_server.adapters.weights import (
+    AdapterBatchData,
+    AdapterBatchMetadata,
+)
+
+__all__ = [
+    "AdapterBatchData",
+    "AdapterBatchMetadata",
+]

server/text_generation_server/adapters/config.py (Normal file, 44 lines)
@@ -0,0 +1,44 @@
+# Origin:   https://github.com/predibase/lorax
+# Path:     lorax/server/lorax_server/adapters/config.py
+# License:  Apache License Version 2.0, January 2004
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+
+import torch
+
+from text_generation_server.adapters.weights import AdapterWeights
+
+if TYPE_CHECKING:
+    from text_generation_server.models.model import Model
+
+
+@dataclass
+class ModuleMap:
+    module_name: str
+    module_weights: Dict[str, Tuple[torch.Tensor, str]]
+
+
+@dataclass
+class AdapterConfig(ABC):
+    base_model_name_or_path: str
+
+    @abstractmethod
+    def map_weights_for_model(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        weight_names: Tuple[str],
+    ) -> Tuple[ModuleMap, Set[str]]:
+        pass
+
+    @abstractmethod
+    def load_batched_adapter_weights(
+        self,
+        model: "Model",
+        module_map: ModuleMap,
+        layer_type: str,
+        unused_weight_names: Set[str],
+        dynamic: bool,
+    ) -> Optional[AdapterWeights]:
+        pass

server/text_generation_server/adapters/lora.py (Normal file, 482 lines)
@@ -0,0 +1,482 @@
+# Origin:   https://github.com/predibase/lorax
+# Path:     lorax/server/lorax_server/adapters/lora.py
+# License:  Apache License Version 2.0, January 2004
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
+
+import torch
+from peft import LoraConfig as _LoraConfig
+from torch.distributed import ProcessGroup
+
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+
+from text_generation_server.adapters.weights import (
+    AdapterBatchMetadata,
+    AdapterWeights,
+    BatchAdapterWeights,
+)
+from text_generation_server.utils.sgmv import (
+    BGMV_MAX_RANK,
+    MAX_RANK_CUSTOM,
+    get_tmp_tensors,
+    orient_for_rank,
+    pad_rank,
+    use_cutlass_shrink,
+)
+
+if TYPE_CHECKING:
+    from text_generation_server.models.model import Model
+
+
+def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
+    block_size = size // world_size
+    start = offset + rank * block_size
+    stop = offset + (rank + 1) * block_size
+    return start, stop
+
+
+def shard_on_dim(
+    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
+):
+    world_size = process_group.size()
+    rank = process_group.rank()
+
+    size = t.shape[dim]
+    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
+
+    if dim == 0:
+        tensor = t[start:stop]
+    elif dim == 1:
+        tensor = t[:, start:stop]
+    else:
+        raise NotImplementedError("Let's make that generic when needed")
+
+    return tensor
+
+
+def shard_lora_weights(
+    weights_a: List[torch.Tensor],
+    weights_b: List[torch.Tensor],
+    split_dim: int,
+    process_group: ProcessGroup,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    # [hidden_size, r]
+    weights_a = [
+        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
+    ]
+
+    # [r, hidden_size]
+    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
+
+    return weights_a, weights_b
+
+
+@dataclass
+class LoraConfig(AdapterConfig):
+    r: int
+    target_modules: Optional[Union[List[str], str]]
+    fan_in_fan_out: bool
+    lora_alpha: int
+    use_rslora: bool
+
+    def map_weights_for_model(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        weight_names: Tuple[str],
+    ) -> Tuple[ModuleMap, Set[str]]:
+        adapter_weight_names = set()
+        module_map = {}
+        for weight_name in weight_names:
+            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
+            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
+            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
+                continue
+
+            module_map[weight_name] = {
+                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
+                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
+            }
+            adapter_weight_names.add(lora_a_name)
+            adapter_weight_names.add(lora_b_name)
+        return module_map, adapter_weight_names
+
+    def load_batched_adapter_weights(
+        self,
+        model: "Model",
+        module_map: Dict[str, Dict],
+        layer_type: str,
+        unused_weight_names: Set[str],
+        dynamic: bool,
+    ) -> Optional[AdapterWeights]:
+        return LoraWeights.load(
+            self,
+            model,
+            module_map,
+            layer_type,
+            unused_weight_names,
+        )
+
+    @classmethod
+    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
+        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
+        return cls(
+            base_model_name_or_path=hf_config.base_model_name_or_path,
+            r=hf_config.r,
+            target_modules=hf_config.target_modules,
+            fan_in_fan_out=hf_config.fan_in_fan_out,
+            lora_alpha=hf_config.lora_alpha,
+            use_rslora=(
+                hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
+            ),
+        )
+
+
+class LoraWeights(AdapterWeights):
+    """LoRA weights for a single adapter merged across all layers."""
+
+    def __init__(
+        self,
+        weights_a: List[torch.Tensor],
+        weights_b: List[torch.Tensor],
+        adapter_config: LoraConfig,
+    ):
+        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
+        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
+
+        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
+        self._is_transposed = False
+
+        # [num_layers, hidden_size, r]
+        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+        self._weights_a = torch.stack(weights_a)
+
+        # [num_layers, r, hidden_size]
+        self._weights_b = torch.stack(weights_b)
+
+        self.adapter_config = adapter_config
+
+    @property
+    def weights_a(self) -> torch.Tensor:
+        if self._is_transposed:
+            self._transpose_weights()
+        return self._weights_a
+
+    @property
+    def weights_b(self) -> torch.Tensor:
+        if self._is_transposed:
+            self._transpose_weights()
+        return self._weights_b
+
+    @property
+    def weights_a_t(self) -> torch.Tensor:
+        if not self._is_transposed:
+            self._transpose_weights()
+        return self._weights_a
+
+    @property
+    def weights_b_t(self) -> torch.Tensor:
+        if not self._is_transposed:
+            self._transpose_weights()
+        return self._weights_b
+
+    def _transpose_weights(self):
+        if self._use_cutlass_shrink:
+            # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
+            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
+        self._weights_b = self._weights_b.transpose(1, 2).contiguous()
+        self._is_transposed = not self._is_transposed
+
+    @classmethod
+    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
+        return [BatchLoraWeights]
+
+    @classmethod
+    def load(
+        cls,
+        config: LoraConfig,
+        model: "Model",
+        module_map: Dict[str, Dict],
+        layer_type: str,
+        unused_weight_names: Set[str],
+    ) -> Optional[AdapterWeights]:
+        nlayers = model.get_num_layers_for_type(layer_type)
+        lora_a_list = [None] * nlayers
+        lora_b_list = [None] * nlayers
+
+        for layer_id in range(nlayers):
+            key = (layer_id, layer_type)
+            weight_name, layer = model.target_to_layer[key]
+            base_weight = layer.base_layer.linear.weight
+            base_device = base_weight.device
+
+            if weight_name not in module_map:
+                # There is no LoRA weight for this layer type in the adapter
+                return None
+
+            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
+            lora_a = lora_a.to(base_device, model.dtype)
+
+            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
+            lora_b = lora_b.to(base_device, model.dtype)
+
+            scale = get_scaling_factor(
+                config.lora_alpha,
+                config.r,
+                uses_rslora=config.use_rslora,
+            )
+
+            unused_weight_names.discard(lora_a_name)
+            unused_weight_names.discard(lora_b_name)
+
+            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
+            # (A * B) * C = A * (B * C)
+            lora_a_list[layer_id] = lora_a.transpose(0, 1)
+            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
+
+        # pad lora ranks to be compatible with sgmv
+        lora_a_list = [
+            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
+        ]
+        lora_b_list = [
+            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
+        ]
+
+        if lora_a_list:
+            # update rank if it was padded
+            padded_rank = lora_a_list[0].size(1)
+            config.r = padded_rank
+
+        return LoraWeights(
+            *shard_lora_weights(
+                weights_a=lora_a_list,
+                weights_b=lora_b_list,
+                split_dim=0 if model.is_row_parallel(layer_type) else 1,
+                process_group=model.process_group,
+            ),
+            config,
+        )
+
+
+@dataclass
+class RankSegments:
+    rank: int
+
+    lora_a_ptr: torch.Tensor
+    lora_b_ptr: torch.Tensor
+
+    # prefill (sgmv)
+    tmp_shrink: torch.Tensor
+    tmp_expand: torch.Tensor
+    segment_starts: torch.Tensor
+    segment_ends: torch.Tensor
+
+    # decode (bgmv)
+    indices: torch.Tensor
+
+
+@dataclass
+class BatchLoraWeights(BatchAdapterWeights):
+    lora_a: Dict[int, torch.Tensor]
+    lora_b: Dict[int, torch.Tensor]
+    adapter_index_configs: Dict[int, LoraConfig]
+    rank_data: Dict[int, RankSegments]
+    use_sgmv: bool
+
+    def has_adapter(self, adapter_index: int) -> bool:
+        return adapter_index in self.adapter_index_configs
+
+    def can_vectorize(self, pg: ProcessGroup) -> bool:
+        return all(
+            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
+            for rank_data in self.rank_data.values()
+        )
+
+    @classmethod
+    def key(cls) -> str:
+        return "lora"
+
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+
+        max_rank = max(
+            (
+                adapter_weights[idx].lora_a_r
+                for idx in segment_indices
+                if idx in adapter_weights
+            ),
+            default=0,
+        )
+
+        if prefill or max_rank > BGMV_MAX_RANK:
+            use_sgmv = True
+            lora_a_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_a.data_ptr()
+                        if idx in adapter_weights
+                        else 0
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+            lora_b_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_b.data_ptr()
+                        if idx in adapter_weights
+                        else 0
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+        else:
+            use_sgmv = False
+            lora_a_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_a_t.data_ptr()
+                        if idx in adapter_weights
+                        else 0
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+            lora_b_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_b_t.data_ptr()
+                        if idx in adapter_weights
+                        else 0
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
|
||||||
|
if adapter_idx not in adapter_weights:
|
||||||
|
continue
|
||||||
|
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
|
||||||
|
|
||||||
|
if prefill_head_indices is not None:
|
||||||
|
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
|
||||||
|
for head_index in prefill_head_indices:
|
||||||
|
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
|
||||||
|
if head_index < meta.adapter_segments[j]:
|
||||||
|
prefill_head_segment_ends[-1] += 1
|
||||||
|
else:
|
||||||
|
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
|
||||||
|
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
rank_data = {}
|
||||||
|
for rank, indices in rank_indices.items():
|
||||||
|
tmp_shrink = None
|
||||||
|
tmp_expand = None
|
||||||
|
segment_starts = None
|
||||||
|
segment_ends = None
|
||||||
|
batch_indices = None
|
||||||
|
|
||||||
|
if use_sgmv:
|
||||||
|
lora_a_ptr_indices = lora_a_ptr[indices]
|
||||||
|
tmp_shrink, tmp_expand = get_tmp_tensors(
|
||||||
|
lora_a_ptr_indices.size(0), rank, device
|
||||||
|
)
|
||||||
|
segment_starts = meta.adapter_segments[indices]
|
||||||
|
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
|
||||||
|
if prefill_head_indices is not None:
|
||||||
|
for i, segment_index in enumerate(indices):
|
||||||
|
segment_starts[i] = prefill_head_segment_starts[segment_index]
|
||||||
|
segment_ends[i] = prefill_head_segment_ends[segment_index]
|
||||||
|
else:
|
||||||
|
rank_indices = set(indices)
|
||||||
|
batch_indices = [
|
||||||
|
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
|
||||||
|
]
|
||||||
|
batch_indices = [
|
||||||
|
idx if idx in rank_indices else -1 for idx in batch_indices
|
||||||
|
]
|
||||||
|
batch_indices = torch.tensor(
|
||||||
|
batch_indices, dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
rank_data[rank] = RankSegments(
|
||||||
|
rank=rank,
|
||||||
|
tmp_shrink=tmp_shrink,
|
||||||
|
tmp_expand=tmp_expand,
|
||||||
|
lora_a_ptr=lora_a_ptr[indices],
|
||||||
|
lora_b_ptr=lora_b_ptr[indices],
|
||||||
|
segment_starts=segment_starts,
|
||||||
|
segment_ends=segment_ends,
|
||||||
|
indices=batch_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BatchLoraWeights(
|
||||||
|
lora_a=lora_a,
|
||||||
|
lora_b=lora_b,
|
||||||
|
adapter_index_configs=adapter_index_configs,
|
||||||
|
rank_data=rank_data,
|
||||||
|
use_sgmv=use_sgmv,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_scaling_factor(
|
||||||
|
lora_alpha: int,
|
||||||
|
r: int,
|
||||||
|
uses_rslora: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Computes the scaling factor for the lora weights."""
|
||||||
|
if uses_rslora:
|
||||||
|
return lora_alpha / (r**0.5)
|
||||||
|
return lora_alpha / r
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
|
||||||
|
if hasattr(v, "lora_weights"):
|
||||||
|
return v.lora_weights
|
||||||
|
return v
|
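A minimal sketch, not taken from the diff, of how get_scaling_factor behaves and why the scale can be folded into lora_b at load time: a scalar commutes with matrix multiplication, so (x @ A) @ (B * s) equals ((x @ A) @ B) * s. All shapes and values below are illustrative assumptions.

import torch

def scaling_factor(lora_alpha: int, r: int, uses_rslora: bool = False) -> float:
    # mirrors get_scaling_factor above
    return lora_alpha / (r**0.5) if uses_rslora else lora_alpha / r

scale = scaling_factor(16, 8)           # classic LoRA: 16 / 8 = 2.0
rs_scale = scaling_factor(16, 8, True)  # rsLoRA: 16 / sqrt(8) ~= 5.66

x = torch.randn(4, 32)       # [tokens, hidden]
lora_a = torch.randn(32, 8)  # [hidden, rank]
lora_b = torch.randn(8, 32)  # [rank, hidden]

# scaling B once at load time gives the same result as scaling the final product
assert torch.allclose(
    (x @ lora_a) @ (lora_b * scale), ((x @ lora_a) @ lora_b) * scale, atol=1e-4
)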
158 server/text_generation_server/adapters/weights.py Normal file
@ -0,0 +1,158 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004

from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type

import torch


@dataclass
class AdapterBatchMetadata:
    # [batch_size]
    adapter_indices: torch.Tensor

    # [num_adapters]
    adapter_set: Set[int]

    # [num_segments + 1]
    adapter_segments: torch.Tensor

    # [num_segments]
    # maps from segment index to adapter index, i.e.:
    # segment_indices[s] == adapter_indices[i]
    segment_indices: List[int]


class AdapterWeights(ABC):
    @abstractclassmethod
    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
        pass

    @property
    def speculative_tokens(self) -> int:
        return 0


class BatchAdapterWeights(ABC):
    @abstractclassmethod
    def has_adapter(self, adapter_index: int) -> bool:
        pass

    @abstractclassmethod
    def key(cls) -> str:
        pass

    @abstractclassmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: torch.Tensor,
    ) -> Optional["BatchAdapterWeights"]:
        pass


class LayerAdapterWeights:
    """Adapter weights that apply to a particular layer."""

    def __init__(self):
        self.adapter_weights: Dict[int, AdapterWeights] = {}

    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
        self.adapter_weights[adapter_idx] = weights

    def remove_adapter(self, adapter_idx: int):
        if adapter_idx not in self.adapter_weights:
            return
        del self.adapter_weights[adapter_idx]

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            adapter_weights.speculative_tokens
            for adapter_weights in self.adapter_weights.values()
        )

    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

    def get_data(
        self,
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Dict[str, BatchAdapterWeights]:
        # bucket adapters by batch class
        adapter_batch_types: Dict[
            Type[BatchAdapterWeights], Dict[int, AdapterWeights]
        ] = defaultdict(dict)
        for adapter_index, adapter_weights in self.adapter_weights.items():
            for batch_type in adapter_weights.get_batch_types():
                adapter_batch_types[batch_type][adapter_index] = adapter_weights

        batch_data = {}
        for batch_type, adapter_weights in adapter_batch_types.items():
            batched_weights = batch_type.load(
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                batch_data[batch_type.key()] = batched_weights
        return batch_data


@dataclass
class AdapterBatchData:
    meta: AdapterBatchMetadata

    # layer type -> adapter type -> batch weight data
    data: Dict[str, Dict[str, BatchAdapterWeights]]

    prefill: bool

    @staticmethod
    def from_meta(
        meta: AdapterBatchMetadata,
        weights: Dict[str, LayerAdapterWeights],
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> "AdapterBatchData":
        data = {}
        for k, v in weights.items():
            if v.is_empty():
                continue
            data[k] = v.get_data(
                meta, prefill, prefill_head_indices if k == "lm_head" else None
            )
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)

    def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
            lora_data = layer_data.get("lora")
            if lora_data is None:
                continue

            for rank_data in lora_data.rank_data.values():
                ranks.add(rank_data.rank)

        return ranks

    def layer_names(self) -> Set[str]:
        return set(self.data.keys())

    def adapter_keys(self) -> Set[str]:
        adapter_keys = set()
        for layer_data in self.data.values():
            adapter_keys.update(layer_data.keys())
        return adapter_keys

    @property
    def max_rank(self) -> int:
        ranks = self.ranks()
        return max(ranks) if len(ranks) > 0 else 0
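A small self-contained example, with hypothetical values, of how the AdapterBatchMetadata fields relate to each other: adapter_segments marks the boundaries of runs of equal adapter_indices, and segment_indices names the adapter that owns each segment.

import torch

# hypothetical batch: tokens 0-1 use adapter 0, tokens 2-4 use adapter 1, token 5 uses adapter 0
adapter_indices = torch.tensor([0, 0, 1, 1, 1, 0])

# one boundary per segment change, plus both ends -> [num_segments + 1]
adapter_segments = torch.tensor([0, 2, 5, 6])

# adapter that owns each segment -> [num_segments]
segment_indices = [0, 1, 0]

# segment s covers tokens adapter_segments[s] : adapter_segments[s + 1]
for s, adapter in enumerate(segment_indices):
    start, end = adapter_segments[s].item(), adapter_segments[s + 1].item()
    assert (adapter_indices[start:end] == adapter).all()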
@ -79,6 +79,18 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

+    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+
+    # split on comma and strip whitespace
+    lora_adapter_ids = (
+        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
+    )
+
+    if len(lora_adapter_ids) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        )
+
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
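The same LORA_ADAPTERS parsing, shown in isolation; the adapter ids below are placeholders, not real repositories.

import os

# e.g. LORA_ADAPTERS="org/adapter-one, org/adapter-two" exported before launching the server
os.environ["LORA_ADAPTERS"] = "org/adapter-one, org/adapter-two"

lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
lora_adapter_ids = (
    [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
assert lora_adapter_ids == ["org/adapter-one", "org/adapter-two"]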
@ -93,6 +105,7 @@ def serve(
     )
     server.serve(
         model_id,
+        lora_adapter_ids,
         revision,
         sharded,
         quantize,
@ -113,6 +126,7 @@ def download_weights(
     logger_level: str = "INFO",
     json_output: bool = False,
     trust_remote_code: bool = False,
+    merge_lora: bool = False,
 ):
     # Remove default handler
     logger.remove()
@ -143,18 +157,28 @@ def download_weights(
     ) is not None

     if not is_local_model:
-        try:
-            adapter_config_filename = hf_hub_download(
-                model_id, revision=revision, filename="adapter_config.json"
-            )
-            utils.download_and_unload_peft(
-                model_id, revision, trust_remote_code=trust_remote_code
-            )
-            is_local_model = True
-            utils.weight_files(model_id, revision, extension)
-            return
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass
+        # TODO: maybe reverse the default value of merge_lora?
+        # currently by default we don't merge the weights with the base model
+        if merge_lora:
+            try:
+                adapter_config_filename = hf_hub_download(
+                    model_id, revision=revision, filename="adapter_config.json"
+                )
+                utils.download_and_unload_peft(
+                    model_id, revision, trust_remote_code=trust_remote_code
+                )
+                is_local_model = True
+                utils.weight_files(model_id, revision, extension)
+                return
+            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+                pass
+        else:
+            try:
+                utils.peft.download_peft(
+                    model_id, revision, trust_remote_code=trust_remote_code
+                )
+            except Exception:
+                pass

     try:
         import json
@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
+
+from text_generation_server.layers.lora import (
+    LoraLinear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
+)
@ -7,7 +7,7 @@ if SYSTEM == "cuda":
     from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
-elif SYSTEM == "xpu":
-    from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+elif SYSTEM == "ipex":
+    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@ -1,5 +1,6 @@
 import intel_extension_for_pytorch as ipex
 import torch
+from text_generation_server.models.flash_causal_lm import BLOCK_SIZE

 SUPPORTS_WINDOWING = False

@ -56,8 +57,6 @@ def paged_attention(
     input_lengths: torch.Tensor,
     max_s: int,
 ):
-    query = query.contiguous()
-    block_size = value_cache.shape[3]
     return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
@ -67,7 +66,7 @@ def paged_attention(
         softmax_scale,
         block_tables,
         input_lengths,
-        block_size,
+        BLOCK_SIZE,
         max_s,
         None,
     )
@ -1,6 +1,7 @@
 # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py

 import math
+from typing import Optional
 import torch
 import torch.nn as nn
 import awq_inference_engine  # with CUDA kernels
@ -17,7 +18,9 @@ import awq_inference_engine  # with CUDA kernels


 class WQLinear(nn.Module):
-    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
+    def __init__(
+        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
+    ):
         super().__init__()

         if w_bit not in [4]:
@ -35,10 +38,7 @@ class WQLinear(nn.Module):
         self.qweight = qweight
         self.qzeros = qzeros
         self.scales = scales
-        if bias:
-            self.bias = bias
-        else:
-            self.bias = None
+        self.bias = bias

     @torch.no_grad()
     def forward(self, x):
@ -82,18 +82,20 @@ elif SYSTEM == "rocm":

         return super().forward(hidden_states), residual

-elif SYSTEM == "xpu":
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex

     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            res_out = hidden_states
             out = ipex.llm.functional.add_layer_norm(
-                residual, hidden_states, self.weight, self.bias, self.eps, True
+                residual,
+                hidden_states,
+                self.weight,
+                self.bias,
+                self.eps,
+                residual is not None,
             )
-            if residual is not None:
-                res_out = residual
-            return out, res_out
+            return out, residual if residual is not None else hidden_states


 class FastRMSNorm(nn.Module):
@ -109,19 +111,16 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)

     def forward(self, hidden_states, residual=None):
-        if SYSTEM == "xpu":
-            residual_out = hidden_states
+        if SYSTEM == "ipex":
             out = ipex.llm.functional.add_rms_norm(
                 residual,
                 hidden_states,
                 self.weight,
                 None,
                 self.variance_epsilon,
-                True,
+                residual is not None,
             )
-            if residual is not None:
-                residual_out = residual
-            return out, residual_out
+            return out, residual if residual is not None else hidden_states
         elif hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
@ -1,7 +1,6 @@
 from typing import Optional
 import torch
 from torch.nn import functional as F
-from text_generation_server.layers.marlin import GPTQMarlinLinear
 from text_generation_server.utils.import_utils import SYSTEM

 if SYSTEM == "rocm":
@ -217,7 +216,7 @@ def get_linear(weight, bias, quantize):
                 qweight=weight.qweight,
                 qzeros=weight.qzeros,
                 scales=weight.scales,
-                bias=bias is not None,
+                bias=bias,
             )
         except ImportError:
             raise NotImplementedError(
@ -225,6 +224,9 @@ def get_linear(weight, bias, quantize):
             )
     elif quantize == "marlin":
         from text_generation_server.layers.marlin import (
+            GPTQMarlin24Linear,
+            GPTQMarlin24Weight,
+            GPTQMarlinLinear,
             GPTQMarlinWeight,
             MarlinLinear,
             MarlinWeight,
@ -235,6 +237,11 @@ def get_linear(weight, bias, quantize):
                 weight=weight,
                 bias=bias,
             )
+        elif isinstance(weight, GPTQMarlin24Weight):
+            linear = GPTQMarlin24Linear(
+                weight=weight,
+                bias=bias,
+            )
         elif isinstance(weight, MarlinWeight):
             linear = MarlinLinear(weight=weight, bias=bias)
         else:
286 server/text_generation_server/layers/lora.py Normal file
@ -0,0 +1,286 @@
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple, List

import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup

from text_generation_server.utils.sgmv import (
    add_lora_a_bgmv,
    add_lora_b_bgmv,
    has_sgmv,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    orient_for_rank,
)

if TYPE_CHECKING:
    from text_generation_server.adapters import AdapterBatchData
    from text_generation_server.adapters.lora import BatchLoraWeights


class LoraLinear(nn.Module):
    def __init__(
        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
    ):
        super().__init__()
        self.base_layer = base_layer
        self.layer_id = layer_id
        self.process_group = process_group

    def forward_layer_type(
        self,
        result: torch.Tensor,
        input: torch.Tensor,
        adapter_data: "AdapterBatchData",
        layer_type: str,
        start_idx: int,
        end_idx: int,
    ) -> torch.Tensor:
        if adapter_data is None:
            return result
        data = adapter_data.data.get(layer_type)
        data: Optional["BatchLoraWeights"] = (
            data.get("lora") if data is not None else None
        )

        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
            # The 'result' tensor represents the full output, which can vary in size based on
            # the layer type (e.g., attention vs. feed-forward layers). We define the current
            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
            # This approach ensures accurate LoRA application across various layer sizes and
            # configurations, adapting to different model architectures and parallelization strategies.
            #
            # Example scenarios where this is necessary:
            # 1. The adapter's size doesn't evenly divide across GPUs.
            # 2. We're processing the last segment which might be smaller.
            # 3. Different projection layers (q, k, v) have different sizes.
            if end_idx - start_idx != result.shape[1]:
                proj = torch.zeros_like(result[:, start_idx:end_idx])
            else:
                proj = result

            for r, rank_segments in data.rank_data.items():
                lora_a_ptr = rank_segments.lora_a_ptr
                lora_b_ptr = rank_segments.lora_b_ptr

                if lora_a_ptr is None or lora_b_ptr is None:
                    raise ValueError("LoRA data is missing")

                if data.use_sgmv:
                    # Use SGMV for prefill
                    v = lora_a_sgmv_cutlass(
                        input,
                        rank_segments.tmp_shrink,
                        lora_a_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                        r,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    lora_b_sgmv_cutlass(
                        proj,
                        v,
                        rank_segments.tmp_expand,
                        lora_b_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                    )
                else:
                    # Use BGMV for decode
                    v = torch.zeros(
                        (input.size(0), r), dtype=input.dtype, device=input.device
                    )
                    # TODO: error with [-1, 0], but not [0, -1]
                    add_lora_a_bgmv(
                        v,
                        input,
                        lora_a_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    add_lora_b_bgmv(
                        proj,
                        v,
                        lora_b_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

            if end_idx - start_idx != result.shape[1]:
                result[:, start_idx:end_idx] += proj
        else:
            for adapter_index in adapter_data.meta.adapter_set:
                if data is not None and data.has_adapter(adapter_index):
                    adapter_mask = (
                        (adapter_data.meta.adapter_indices == adapter_index)
                        .to(input.dtype)
                        .view(-1, 1)
                    )
                    layer_result = self.forward_lora(
                        input, data, adapter_index, adapter_mask
                    )
                    result[:, start_idx:end_idx] += layer_result

        return result

    def forward_lora(
        self,
        input: torch.Tensor,
        data: "BatchLoraWeights",
        adapter_index: int,
        adapter_mask: torch.Tensor,
    ) -> torch.Tensor:
        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

        lora_a = orient_for_rank(lora_a, lora_b.size(0))

        a_out = input @ lora_a
        if self.process_group.size() > 1:
            a_out = self.collect_lora_a(a_out)

        result = (a_out @ lora_b) * adapter_mask
        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Implemented in subclasses")


class TensorParallelMultiAdapterLinear(LoraLinear):
    def __init__(
        self,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_names = layer_names
        self.sizes = sizes

    @classmethod
    def load(
        cls,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        return TensorParallelMultiAdapterLinear(
            base_layer, layer_id, layer_names, sizes, process_group
        )

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        # noop if no layer names are provided (e.g. for models without adapters)
        if self.layer_names is None:
            return result

        # handle models like Bloom that have inputs of shape
        # (batch_size, sequence_length, hidden_size)
        # we need to reshape them to (batch_size * sequence_length, hidden_size)
        # for the LoRA computation, then reshape back
        prev_shape = result.shape
        is_3d = len(input.shape) >= 3
        if is_3d:
            input = input.reshape(-1, input.shape[-1])
            result = result.reshape(-1, result.shape[-1])

        offset = 0
        for i, layer_name in enumerate(self.layer_names):
            start_idx = offset // self.process_group.size()
            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
            # different projection sizes across layers and model architectures.
            if self.sizes is not None:
                offset += self.sizes[i]
                end_idx = offset // self.process_group.size()
            else:
                end_idx = result.shape[1]

            result = self.forward_layer_type(
                result, input, adapter_data, layer_name, start_idx, end_idx
            )

        if is_3d:
            result = result.reshape(prev_shape)

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        gathered_tensors = [
            torch.empty_like(a_out) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(gathered_tensors, a_out)
        return torch.cat(gathered_tensors, dim=1)


class TensorParallelAdapterRowLinear(LoraLinear):
    def __init__(self, base_layer, layer_id, layer_name, process_group):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_name = layer_name

    @classmethod
    def load(cls, base_layer, layer_id, layer_name, process_group):
        return cls(base_layer, layer_id, layer_name, process_group)

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        if self.layer_name is None:
            return result

        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
        stride = result.shape[-1] // self.process_group.size()
        start_idx = self.process_group.rank() * stride
        end_idx = (self.process_group.rank() + 1) * stride

        self.forward_layer_type(
            result, input, adapter_data, self.layer_name, start_idx, end_idx
        )

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        torch.distributed.all_reduce(a_out, group=self.process_group)
        return a_out
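When the SGMV/BGMV kernels are unavailable or the batch cannot be vectorized, forward_layer_type above falls back to applying each adapter densely behind a per-token mask. The sketch below reproduces only that fallback math, with made-up shapes, a single process group, and no base-layer sharding.

import torch

hidden, rank, tokens = 32, 8, 6
x = torch.randn(tokens, hidden)
base_out = torch.randn(tokens, hidden)  # stand-in for self.base_layer(input)

# two hypothetical adapters, already oriented as [hidden, rank] and [rank, hidden]
adapters = {
    0: (torch.randn(hidden, rank), torch.randn(rank, hidden)),
    1: (torch.randn(hidden, rank), torch.randn(rank, hidden)),
}
adapter_indices = torch.tensor([0, 0, 1, 1, 1, 0])  # adapter assigned to each token

result = base_out.clone()
for adapter_index, (lora_a, lora_b) in adapters.items():
    mask = (adapter_indices == adapter_index).to(x.dtype).view(-1, 1)
    # mirrors forward_lora without tensor parallelism: rows of other adapters are zeroed out
    result += ((x @ lora_a) @ lora_b) * mask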
@ -1,9 +1,8 @@
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple

 import torch
 import torch.nn as nn

 from text_generation_server.utils.import_utils import SYSTEM

 try:
@ -177,12 +176,12 @@ class GPTQMarlinLinear(nn.Module):
         self.bits = weight.bits
         self.is_full_k = weight.is_full_k

-        self.register_buffer("qweight", weight.qweight)
-        self.register_buffer("scales", weight.scales)
-        self.register_buffer("g_idx", weight.g_idx)
-        self.register_buffer("perm", weight.perm)
+        self.qweight = weight.qweight
+        self.scales = weight.scales
+        self.g_idx = weight.g_idx
+        self.perm = weight.perm
         if bias is not None:
-            self.register_buffer("bias", bias)
+            self.bias = bias
         else:
             self.bias = None

@ -215,6 +214,116 @@ class GPTQMarlinLinear(nn.Module):
         return C


+GPTQ_MARLIN_24_MIN_THREAD_N = 128
+GPTQ_MARLIN_24_MIN_THREAD_K = 128
+GPTQ_MARLIN_24_MAX_PARALLEL = 64
+GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+
+@dataclass
+class GPTQMarlin24Weight:
+    """
+    GPTQ-Marlin 2:4 weights.
+
+    Attributes:
+        B (torch.Tensor): int4-quantized weights packed into int32.
+        B_meta (torch.Tensor): metadata for 2:4 sparsity.
+        s (torch.Tensor): float16 scales.
+        bits: quantized weight size.
+    """
+
+    B: torch.Tensor
+    B_meta: torch.Tensor
+    s: torch.Tensor
+    bits: int
+
+    def __post_init__(self):
+        assert self.B.dtype == torch.int32
+        assert self.B_meta.dtype == torch.int16
+        assert self.s.dtype == torch.float16
+
+
+class GPTQMarlin24Linear(nn.Module):
+    def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
+        super().__init__()
+
+        _check_marlin_kernels()
+        assert marlin_kernels is not None
+
+        if weight.bits not in GPTQ_MARLIN_BITS:
+            supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
+            raise RuntimeError(
+                f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
+            )
+
+        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
+        out_features = weight.s.shape[1]
+        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
+
+        if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
+            supported_sizes = ", ".join(
+                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
+            )
+            raise RuntimeError(
+                f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
+            )
+
+        self.bits = weight.bits
+        weights_per_int32 = 32 // self.bits
+
+        assert (
+            out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
+        ), f"Number of output features ({out_features}) not divisible by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
+        assert (
+            out_features % weights_per_int32 == 0
+        ), f"Number of output features ({out_features}) not divisible by weights per int32 ({weights_per_int32})"
+
+        assert (
+            in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
+        ), f"Number of input features ({in_features}) not divisible by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
+        if groupsize != -1 and in_features % groupsize != 0:
+            raise ValueError(
+                f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
+            )
+
+        self.B = weight.B
+        self.B_meta = weight.B_meta
+        self.s = weight.s
+        if bias is not None:
+            self.bias = bias
+        else:
+            self.bias = None
+
+        self.workspace = torch.zeros(
+            (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
+            dtype=torch.int,
+            device=weight.B.device,
+        )
+
+    def forward(self, A: torch.Tensor) -> torch.Tensor:
+        assert marlin_kernels is not None
+
+        C = marlin_kernels.gptq_marlin_24_gemm(
+            A.view(-1, A.shape[-1]),
+            self.B,
+            self.B_meta,
+            self.s,
+            self.workspace,
+            self.bits,
+            A.shape[0],
+            self.s.shape[1],
+            A.shape[1],
+        )
+
+        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
+
+        if self.bias is not None:
+            C += self.bias
+
+        return C
+
+
 @dataclass
 class MarlinWeight:
     """
@ -262,10 +371,10 @@ class MarlinLinear(nn.Module):
             128,
         }, f"Group size must be -1 or 128, was {groupsize}"

-        self.register_buffer("B", weight.B)
-        self.register_buffer("s", weight.s)
+        self.B = weight.B
+        self.s = weight.s
         if bias is not None:
-            self.register_buffer("bias", bias)
+            self.bias = bias
         else:
             self.bias = None

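A small sketch of the shape bookkeeping done in GPTQMarlin24Linear.__init__, using made-up dimensions; MARLIN_TILE_SIZE is assumed here to be 16, matching the constant used elsewhere in marlin.py.

MARLIN_TILE_SIZE = 16  # assumed value for this sketch

# hypothetical packed 2:4 weight shapes for a 4096 x 4096 layer
B_rows, s_rows, s_cols = 128, 32, 4096

in_features = B_rows * MARLIN_TILE_SIZE * 2                  # 128 * 16 * 2 = 4096
out_features = s_cols                                        # 4096
groupsize = -1 if s_rows == 1 else in_features // s_rows     # 4096 // 32 = 128

assert groupsize in (-1, 128)   # the only group sizes the 2:4 kernel accepts
assert out_features % 128 == 0  # GPTQ_MARLIN_24_MIN_THREAD_N
assert in_features % 128 == 0   # GPTQ_MARLIN_24_MIN_THREAD_K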
@ -9,7 +9,7 @@ if SYSTEM == "cuda":
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
-elif SYSTEM == "xpu":
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex

@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):

             # Inplace operation, updating query and key.
             ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif SYSTEM == "xpu":
+        elif SYSTEM == "ipex":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
@ -3,6 +3,10 @@ from torch.nn import functional as F
 from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex


 class LayerConcat(torch.nn.Module):
@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer):
             local_out = gather_input.T

         torch.mm(input, self.linear.weight.T, out=local_out)
-        torch.distributed.all_gather_into_tensor(
-            world_out, gather_input, group=self.process_group
-        )
+        if SYSTEM == "ipex":
+            ipex.distributed.all_gather_into_tensor(
+                world_out, gather_input, group=self.process_group
+            )
+        else:
+            torch.distributed.all_gather_into_tensor(
+                world_out, gather_input, group=self.process_group
+            )

         if input.shape[0] == 1:
             return world_out
@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        if SYSTEM == "ipex":
+            ipex.distributed.all_gather(world_output, output, group=self.process_group)
+        else:
+            torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output

@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if SYSTEM == "ipex":
+                ipex.distributed.all_reduce(out, group=self.process_group)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out

@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if SYSTEM == "ipex":
+                ipex.distributed.all_reduce(out, group=self.process_group)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional
+from typing import Optional, List
 from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate
@ -255,6 +255,7 @@ for data in ModelType:

 def get_model(
     model_id: str,
+    lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
@ -602,6 +603,7 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
         )
     elif sharded:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):

         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
@ -538,6 +538,7 @@ class CausalLM(Model):
             tokenizer.add_special_tokens({"pad_token": "[PAD]"})

         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
@ -22,7 +22,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from text_generation_server.utils.import_utils import SYSTEM

-if SYSTEM != "xpu":
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe

 from text_generation_server.layers.attention import (
@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         token_embeds = self.embed_tokens(input_ids)
         position_embeds = self.embed_positions(position_ids)
@ -19,6 +19,7 @@
 # limitations under the License.
+
 from typing import List, Optional, Tuple

 import torch
 import torch.distributed

@ -37,6 +38,8 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@ -50,43 +53,61 @@ if SYSTEM == "rocm":
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = None
+    prefixes = None

-    # if specific model type, load the correct attention
     if config.model_type == "phi3":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.qkv_proj"
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=prefix,
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.W_pack"
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=f"{prefix}.W_pack",
+            prefix=prefix,
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
+    else:
+        prefixes = ["q_proj", "k_proj", "v_proj"]
+        sizes = [
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ]
+        base_layer = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )

-    # otherwise, load the default attention based on the number of heads
-    return TensorParallelColumnLinear.load_multi(
-        config,
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        dim=0,
-        weights=weights,
-        bias=bias,
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
     )


 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@ -120,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
+        self.index = index

-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )

+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -144,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@ -188,11 +219,14 @@ class FlashLlamaAttention(torch.nn.Module):
             input_lengths,
             max_s,
         )
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


 class LlamaMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, index):
         super().__init__()
         self.hidden_act = config.hidden_act
         self.act = (
@ -207,29 +241,54 @@ class LlamaMLP(nn.Module):
             ),
         )
+        prefixes = None
+        sizes = None
+
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
                 bias=bias,
             )
         else:
-            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            prefixes = [f"gate_proj", f"up_proj"]
+            sizes = [
+                config.intermediate_size,
+                config.intermediate_size,
+            ]
+            gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
                 bias=bias,
             )
-        self.down_proj = TensorParallelRowLinear.load(
+
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=bias,
         )
+
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
+
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@ -237,7 +296,7 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
||||||
self.quantize = config.quantize
|
self.quantize = config.quantize
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, adapter_data):
|
||||||
if (
|
if (
|
||||||
SYSTEM == "rocm"
|
SYSTEM == "rocm"
|
||||||
and self.hidden_act == "silu"
|
and self.hidden_act == "silu"
|
||||||
@ -251,20 +310,27 @@ class LlamaMLP(nn.Module):
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||||
return self.down_proj(out)
|
return self.down_proj(out, adapter_data)
|
||||||
else:
|
else:
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashLlamaLayer(nn.Module):
|
class FlashLlamaLayer(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, index, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = FlashLlamaAttention(
|
self.self_attn = FlashLlamaAttention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
index=index,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.mlp = LlamaMLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
|
||||||
|
|
||||||
self.input_layernorm = FastRMSNorm.load(
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
@ -287,6 +353,7 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -301,6 +368,7 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -308,7 +376,7 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
attn_output, res
|
attn_output, res
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(normed_attn_res_output)
|
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||||
|
|
||||||
return mlp_output, attn_res
|
return mlp_output, attn_res
|
||||||
|
|
||||||
@ -323,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
FlashLlamaLayer(
|
FlashLlamaLayer(
|
||||||
|
index=layer_id,
|
||||||
prefix=(
|
prefix=(
|
||||||
f"model.layers.{layer_id}"
|
f"model.layers.{layer_id}"
|
||||||
if not prefix
|
if not prefix
|
||||||
@ -358,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
max_s: int,
|
max_s: int,
|
||||||
true_max_s: int,
|
true_max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
adapter_data,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
@ -380,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -421,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||||||
max_s: int,
|
max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
@ -434,9 +506,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||||||
max_s,
|
max_s,
|
||||||
true_max_s=max_s,
|
true_max_s=max_s,
|
||||||
prefill_cache_indices=prefill_cache_indices,
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
|
adapter_data=adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
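Editorial note on the hunks above: the Llama changes wrap every base projection (query_key_value, o_proj, gate_up_proj, down_proj) in an adapter-aware layer and thread an adapter_data argument through each forward call. The sketch below is only a minimal illustration of that wrapping pattern, assuming nothing beyond standard PyTorch; ToyMultiAdapterLinear, the adapter dictionary, and the string-keyed adapter_data are invented for illustration and are not the TensorParallelMultiAdapterLinear API used in the diff, which carries batched per-token adapter state instead.

# Minimal sketch of the "wrap a base layer, pass adapter_data through forward" pattern.
# ToyMultiAdapterLinear and the string-keyed adapter_data are hypothetical stand-ins.
import torch
import torch.nn as nn


class ToyMultiAdapterLinear(nn.Module):
    def __init__(self, base_layer: nn.Module, adapters: dict):
        super().__init__()
        self.base_layer = base_layer
        self.adapters = nn.ModuleDict(adapters)

    def forward(self, hidden_states: torch.Tensor, adapter_data):
        # The base projection always runs.
        out = self.base_layer(hidden_states)
        # Add the low-rank adapter output when an adapter is selected.
        if adapter_data is not None:
            out = out + self.adapters[adapter_data](hidden_states)
        return out


base = nn.Linear(16, 32, bias=False)
lora = nn.Sequential(nn.Linear(16, 4, bias=False), nn.Linear(4, 32, bias=False))
layer = ToyMultiAdapterLinear(base, {"demo_adapter": lora})
x = torch.randn(2, 16)
print(layer(x, None).shape)            # base path only
print(layer(x, "demo_adapter").shape)  # base + adapter path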
@@ -38,6 +38,8 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
     get_linear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):


 class MistralAttention(torch.nn.Module):
-    def __init__(
-        self,
-        prefix: str,
-        config,
-        weights,
-    ):
+    def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
         self.max_past = (
             config.sliding_window if config.sliding_window is not None else -1
@@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        self.query_key_value = TensorParallelColumnLinear.load_multi(
+        query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
@@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
             bias=False,
         )

-        self.o_proj = TensorParallelRowLinear.load(
+        head_size = config.hidden_size // config.num_attention_heads
+        self.query_key_value = TensorParallelMultiAdapterLinear.load(
+            query_key_value,
+            layer_id,
+            ["q_proj", "k_proj", "v_proj"],
+            sizes=[
+                head_size * config.num_attention_heads,
+                head_size * config.num_key_value_heads,
+                head_size * config.num_key_value_heads,
+            ],
+            process_group=weights.process_group,
+        )
+
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            layer_id,
+            "o_proj",
+            process_group=weights.process_group,
+        )
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
         input_lengths,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
                 max_s,
             )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


 class MistralMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
         super().__init__()
         self.hidden_act = config.hidden_act
         self.act = (
@@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_id,
+            ["gate_proj", "up_proj"],
+            sizes=[
+                config.intermediate_size,
+                config.intermediate_size,
+            ],
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
+
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            layer_id,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
@@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
                 device="cuda",
             )
             _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
         else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
             gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )


 class MistralLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
         super().__init__()
         self.self_attn = MistralAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+            layer_id=layer_id,
+        )
+        self.mlp = MistralMLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
         )
-        self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
         input_lengths,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
             input_lengths,
             max_s,
             prefill_cache_indices,
+            adapter_data,
         )

         # faster post attention rms norm
@@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
             attn_output, res
         )

-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

         return mlp_output, attn_res

@@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
                     prefix=f"{prefix}.layers.{layer_id}",
                     config=config,
                     weights=weights,
+                    layer_id=layer_id,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data: Optional[torch.Tensor] = None,
     ):
         hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
@@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
                 input_lengths,
                 max_s,
                 prefill_cache_indices,
+                adapter_data,
             )

         hidden_states, _ = self.norm(hidden_states, residual)
@@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:
@@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             max_s,
             true_max_s,
             prefill_cache_indices,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
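The Mistral attention hunk above derives head_size from the hidden size and builds a sizes list so the adapter wrapper knows how the fused q/k/v output splits between the query heads and the smaller number of key/value heads (grouped-query attention). A worked example with hypothetical 7B-style config values, not taken from the diff; only the formula comes from the code above.

# Worked example of the sizes list built in the Llama/Mistral hunks.
# The config numbers here are hypothetical, typical GQA values.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

head_size = hidden_size // num_attention_heads  # 128
sizes = [
    head_size * num_attention_heads,   # q_proj output rows: 4096
    head_size * num_key_value_heads,   # k_proj output rows: 1024
    head_size * num_key_value_heads,   # v_proj output rows: 1024
]
print(head_size, sizes)  # 128 [4096, 1024, 1024]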
@@ -26,7 +26,7 @@ import numpy as np
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM

-if SYSTEM != "xpu":
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:
@@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.gpt_neox(
             input_ids,
@@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         # Unused here
         pixel_attention_mask: Optional[torch.BoolTensor] = None,
         image_sizes: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
@@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
@@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:
@@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
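The remaining hunks (Mixtral, GPT-NeoX, PaliGemma, Phi, Qwen2, RW) only append adapter_data as a defaulted keyword argument to forward, which keeps one calling convention across models without touching existing call sites. A minimal sketch of why a defaulted trailing parameter is backward compatible; TinyModel is hypothetical and only mirrors the signature pattern, it is not part of the diff.

# Minimal sketch: adding a defaulted trailing parameter does not break existing callers.
from typing import Optional

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(
        self,
        input_ids: torch.Tensor,
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,  # new, unused in this sketch
    ) -> torch.Tensor:
        hidden_states = self.proj(input_ids)
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        return hidden_states


model = TinyModel()
x = torch.randn(4, 8)
model(x)                     # pre-existing call sites keep working
model(x, adapter_data=None)  # new call sites can pass adapter state explicitly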
Some files were not shown because too many files have changed in this diff.