Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

Merge branch 'main' into ci_amd3
Commit: 227f78f3fe

.github/workflows/build.yaml (vendored, 4 changes)
@@ -204,6 +204,8 @@ jobs:
    needs: [build-and-push, prepare_integration_tests]
    runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
    env:
      PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
@@ -250,4 +252,4 @@ jobs:
        echo "DOCKER_VOLUME:"
        echo $DOCKER_VOLUME

        pytest -s -vvvvv integration-tests
        pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}
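The `PYTEST_FLAGS` expression above appends `--release` only for builds of `main`. A minimal sketch of how the test step effectively expands (the two invocations below are illustrative, derived from the workflow and the new conftest hooks later in this commit):

```bash
# On pushes to main, github.ref == 'refs/heads/main', so PYTEST_FLAGS='--release'
pytest -s -vvvvv integration-tests --release   # release-marked tests run as well

# On any other ref, PYTEST_FLAGS expands to an empty string
pytest -s -vvvvv integration-tests             # release-marked tests are skipped
```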
.github/workflows/client-tests.yaml (vendored, 2 changes)
@@ -22,5 +22,5 @@ jobs:
      - name: Run tests
        run: |
          pip install pytest pytest-asyncio
          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          make python-client-tests
.github/workflows/integration_tests.yaml (vendored, 2 changes)
@@ -37,5 +37,5 @@ jobs:
          export DOCKER_VOLUME=/mnt/cache
          export DOCKER_IMAGE=${{ inputs.docker_image }}
          export DOCKER_DEVICES=${{ inputs.docker_devices }}
          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          pytest -s -vv integration-tests
.github/workflows/load_test.yaml (vendored, 2 changes)
@@ -28,7 +28,7 @@ jobs:
      - name: Start starcoder
        run: |
          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
          sleep 10
          wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
.github/workflows/tests.yaml (vendored, 2 changes)
@@ -72,7 +72,7 @@ jobs:
      - name: Run server tests
        run: |
          pip install pytest
          export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          pytest -s -vv server/tests
      - name: Pre-commit checks
        run: |
Dockerfile (10 changes)
@@ -145,6 +145,13 @@ COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Lorax Punica kernels
FROM kernel-builder as lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica

# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
@@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
# CMD ["--json-output"]
@@ -1,3 +1,5 @@
ARG PLATFORM=xpu

FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src

@@ -37,7 +39,8 @@ RUN cargo build --profile release-opt

# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base

FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu

USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
@@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \

WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed

# Install server
COPY proto proto
@@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

# Final image
FROM base

# Text Generation Inference base image for Intel-cpu
FROM ubuntu:22.04 as cpu

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    curl \
    ca-certificates \
    make \
    g++ \
    git \
    wget \
    cmake

ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH

# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
    *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

RUN conda install -c conda-forge gperftools mkl

RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl

WORKDIR /usr/src

RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a

RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131

RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install

RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .

ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
ENV KMP_BLOCKTIME=1
ENV KMP_TPAUSE=0
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_intel.txt && \
    pip install ".[accelerate, peft, outlines]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

FROM ${PLATFORM} as final
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
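With `ARG PLATFORM=xpu` and the new `FROM ${PLATFORM} as final` stage, one Dockerfile can now produce either the Intel XPU or the CPU image. A rough sketch of the two builds, assuming this is the repository's Intel Dockerfile (the `Dockerfile_intel` file name and the image tags below are assumptions, not taken from this diff):

```bash
# Default: Intel XPU image (PLATFORM defaults to xpu)
docker build -f Dockerfile_intel -t tgi-intel-xpu .

# CPU-only image: override the PLATFORM build argument so the final stage is `cpu`
docker build -f Dockerfile_intel --build-arg PLATFORM=cpu -t tgi-intel-cpu .
```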
@@ -153,7 +153,7 @@ this will impact performance.
### Distributed Tracing

`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
overridden with the `--otlp-service-name` argument

### Architecture
@@ -157,6 +157,7 @@ async fn prefill(
    top_n_tokens: top_n_tokens.unwrap_or(0),
    blocks: vec![],
    slots: vec![],
    adapter_id: None,
})
.collect();
@@ -147,7 +147,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing::info!("Downloading tokenizer");

    // Parse Huggingface hub token
    let auth_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
    let auth_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Download and instantiate tokenizer
    // We need to download it outside of the Tokio runtime
@@ -455,6 +455,6 @@ class DeployedModel(BaseModel):
    # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
    # with model_ prefixes, since this disables guardrails for colliding fields:
    # https://github.com/pydantic/pydantic/issues/9177
    model_config = ConfigDict(protected_namespaces=())
    model_id: str
    sha: str
@@ -60,6 +60,9 @@
  - local: conceptual/speculation
    title: Speculation (Medusa, ngram)
  - local: conceptual/guidance
    title: How Guidance Works (via outlines)
  - local: conceptual/lora
    title: LoRA (Low-Rank Adaptation)

  title: Conceptual Guides
@@ -416,6 +416,14 @@ Options:
          [env: MAX_CLIENT_BATCH_SIZE=]
          [default: 4]

```
## LORA_ADAPTERS
```shell
      --lora-adapters <LORA_ADAPTERS>
          Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request

          [env: LORA_ADAPTERS=]

```
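The flag documented above can be set either on the command line or through the `LORA_ADAPTERS` environment variable. A minimal sketch of both forms; the base model id is a placeholder, not something specified in this diff:

```bash
# Load two adapters at startup via the CLI flag (hypothetical base model id)
text-generation-launcher --model-id <base-model> \
    --lora-adapters predibase/customer_support,predibase/dbpedia

# Equivalent form using the environment variable
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia \
    text-generation-launcher --model-id <base-model>
```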
## HELP
docs/source/conceptual/lora.md (new file, 65 lines)
@@ -0,0 +1,65 @@
# LoRA (Low-Rank Adaptation)

## What is LoRA?

LoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.

LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.

## How is it used?

LoRA can be used in many ways and the community is always finding new ways to use it. Here are some examples of how you can use LoRA:

Technically, LoRA can be used to fine-tune a large language model on a small dataset. However, these use cases can span a wide range of applications, such as:

- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels

## Optimizing Inference with LoRA

LoRAs can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but thanks to the excellent work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with multiple LoRA models.

## Serving multiple LoRA adapters with TGI

Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.

In practice it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.

Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.

### Specifying LoRA models

To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:

```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```

In the server logs, you will see the following message:

```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```

## Generate text

You can then use these models in generation requests by specifying the `adapter_id` parameter in the request payload. For example:

```bash
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
  "inputs": "Hello who are you?",
  "parameters": {
    "max_new_tokens": 40,
    "adapter_id": "predibase/customer_support"
  }
}'
```

> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.

An updated tutorial with detailed examples will be published soon. Stay tuned!
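The `LORA_ADAPTERS` variable shown above is all the server needs at startup, so the same setup can be wired into a Docker launch. A hedged sketch; the image tag, port mapping, and base model id are illustrative assumptions rather than values taken from this page:

```bash
# Start TGI with two adapters preloaded (hypothetical base model id and tag)
docker run --gpus all -p 3000:80 -v $PWD/data:/data \
    -e LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id <base-model>
```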
@@ -43,6 +43,26 @@ if SYSTEM is None:
    )


def pytest_addoption(parser):
    parser.addoption(
        "--release", action="store_true", default=False, help="run release tests"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "release: mark test as a release-only test")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--release"):
        # --release given in cli: do not skip release tests
        return
    skip_release = pytest.mark.skip(reason="need --release option to run")
    for item in items:
        if "release" in item.keywords:
            item.add_marker(skip_release)


class ResponseComparator(JSONSnapshotExtension):
    rtol = 0.2
    ignore_logprob = False
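With these hooks in place, any test carrying `@pytest.mark.release` is skipped unless the new option is passed. A short sketch of how the suite is invoked, grounded in the conftest above and the `PYTEST_FLAGS` wiring in `build.yaml`:

```bash
# Default run: tests marked @pytest.mark.release are skipped
pytest -s -vv integration-tests

# CI on main (or a local release check): the release marker is honoured
pytest -s -vvvvv integration-tests --release
```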
@@ -15,6 +15,7 @@ async def bloom_560(bloom_560_handle):
    return bloom_560_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m(bloom_560, response_snapshot):
@@ -31,6 +32,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_all_params(bloom_560, response_snapshot):
@@ -55,6 +57,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):

@@ -15,6 +15,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
    return bloom_560m_sharded_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
@@ -31,6 +32,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_sharded_load(
    bloom_560m_sharded, generate_load, response_snapshot

@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly.


@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
    assert response == response_snapshot


@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):

@@ -27,6 +27,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
    return flash_llama_awq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
    response = await flash_llama_awq.generate(
@@ -41,6 +42,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
    response = await flash_llama_awq.generate(
@@ -62,6 +64,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
    responses = await generate_load(

@@ -26,6 +26,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):

@is_flaky_async(max_attempts=5)
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda", "rocm")
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
@@ -47,6 +48,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho

@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load_sharded(
    flash_llama_awq_sharded, generate_load, response_snapshot
@@ -57,6 +59,7 @@ async def test_flash_llama_awq_load_sharded(
        flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all(
        [
            r.generated_text

@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
    return flash_falcon_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):

@@ -17,6 +17,7 @@ async def flash_gemma(flash_gemma_handle):
    return flash_gemma_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -29,6 +30,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -53,6 +55,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")

@@ -17,6 +17,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
    return flash_gemma_gptq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -31,6 +32,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
    assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -57,6 +59,7 @@ async def test_flash_gemma_gptq_all_params(
    assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")

@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
    return flash_gpt2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
    response = await flash_gpt2.generate(
@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
    responses = await generate_load(

@@ -23,6 +23,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):

@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -35,6 +36,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh

@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
@@ -62,6 +64,7 @@ async def test_flash_llama_exl2_all_params(

@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(

@@ -15,6 +15,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
    return flash_llama_gptq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@is_flaky_async(max_attempts=5)
@@ -31,6 +32,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda")
@@ -55,6 +57,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda")

@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
    return flash_llama_gptq_marlin_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(

@@ -20,6 +20,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
    return flash_llama_marlin_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
@@ -31,6 +32,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
@@ -53,6 +55,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(

@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
    return flash_neox_handle.client


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
    return flash_neox_sharded_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox(flash_neox_sharded, response_snapshot):
    response = await flash_neox_sharded.generate(
@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
    responses = await generate_load(

@@ -38,6 +38,7 @@ def get_cow_beach():
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@@ -50,6 +51,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):

@@ -17,6 +17,7 @@ async def flash_phi(flash_phi_handle):
    return flash_phi_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi(flash_phi, response_snapshot):
@@ -29,6 +30,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi_all_params(flash_phi, response_snapshot):
@@ -53,6 +55,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
    return flash_qwen2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2(flash_qwen2, response_snapshot):
    response = await flash_qwen2.generate(
@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
    response = await flash_qwen2.generate(
@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
    responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)

@@ -15,6 +15,7 @@ async def flash_santacoder(flash_santacoder_handle):
    return flash_santacoder_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda", "xpu")
async def test_flash_santacoder(flash_santacoder, response_snapshot):
@@ -27,6 +28,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_santacoder_load(
    flash_santacoder, generate_load, response_snapshot

@@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle):
    return flash_starcoder_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder(flash_starcoder, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle):
    return flash_starcoder2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(

@@ -15,6 +15,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
    return flash_starcoder_gptq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@is_flaky_async(max_attempts=10)
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
@@ -33,6 +34,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
    assert response == generous_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@is_flaky_async(max_attempts=10)
async def test_flash_starcoder_gptq_default_params(
@@ -55,6 +57,7 @@ async def test_flash_starcoder_gptq_default_params(
    assert response == generous_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_starcoder_gptq_load(

@@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
    return non_flash_llama_grammar_handle.client


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):

@@ -22,6 +22,7 @@ async def llama_grammar(llama_grammar_handle):
    return llama_grammar_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):

@@ -62,6 +63,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
    assert chat_completion == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,

@@ -45,6 +45,7 @@ async def test_idefics(idefics, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
@@ -60,6 +61,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot):
    chicken = get_chicken()

@@ -28,6 +28,7 @@ async def flash_llava_next(flash_llava_next_handle):
    return flash_llava_next_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
@@ -43,6 +44,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
@@ -66,6 +68,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(

@@ -15,6 +15,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
    return fused_kernel_mamba_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba(fused_kernel_mamba, response_snapshot):
@@ -27,6 +28,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@@ -54,6 +56,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba_load(

@@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle):
    return mpt_sharded_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
    response = await mpt_sharded.generate(
@@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
    responses = await generate_load(

@@ -14,6 +14,7 @@ async def mt0_base(mt0_base_handle):
    return mt0_base_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base(mt0_base, response_snapshot):
    response = await mt0_base.generate(
@@ -28,6 +29,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_all_params(mt0_base, response_snapshot):
    response = await mt0_base.generate(
@@ -50,6 +52,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
    responses = await generate_load(

@@ -15,6 +15,7 @@ async def neox(neox_handle):
    return neox_handle.client


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
@@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):

@@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle):
    return neox_sharded_handle.client


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
@@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle):
    return t5_sharded_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded(t5_sharded, response_snapshot):
    response = await t5_sharded.generate(
@@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
    responses = await generate_load(
@@ -452,6 +452,11 @@ struct Args {
    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,

    /// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during
    /// startup that will be available to callers via the `adapter_id` field in a request.
    #[clap(long, env)]
    lora_adapters: Option<String>,
}

#[derive(Debug)]
@@ -485,6 +490,7 @@ fn shard_manager(
    max_total_tokens: usize,
    max_batch_size: Option<usize>,
    max_input_tokens: usize,
    lora_adapters: Option<String>,
    otlp_endpoint: Option<String>,
    otlp_service_name: String,
    log_level: LevelFilter,
@@ -620,6 +626,11 @@ fn shard_manager(
        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
    }

    // Lora Adapters
    if let Some(lora_adapters) = lora_adapters {
        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -762,7 +773,7 @@ fn num_cuda_devices() -> Option<usize> {
        Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
            Ok(devices) => devices,
            Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
        }
        },
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
@@ -1060,6 +1071,7 @@ fn spawn_shards(
    let rope_scaling = args.rope_scaling;
    let rope_factor = args.rope_factor;
    let max_batch_size = args.max_batch_size;
    let lora_adapters = args.lora_adapters.clone();
    thread::spawn(move || {
        shard_manager(
            model_id,
@@ -1085,6 +1097,7 @@ fn spawn_shards(
            max_total_tokens,
            max_batch_size,
            max_input_tokens,
            lora_adapters,
            otlp_endpoint,
            otlp_service_name,
            max_log_level,
@@ -1225,7 +1238,6 @@ fn spawn_webserver(
    router_args.push("--otlp-service-name".to_string());
    router_args.push(otlp_service_name);

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
@@ -134,6 +134,8 @@ message Request {
    repeated uint32 blocks = 9;
    /// Paged attention slots
    repeated uint32 slots = 10;
    /// LORA adapter index
    optional string adapter_id = 11;
}

message Batch {
@@ -177,6 +177,7 @@ impl Client {
    }),
    prefill_logprobs: true,
    top_n_tokens: 20,
    adapter_id: None,
});
n_tokens += max_input_length;
@@ -244,6 +244,7 @@ impl Health for ShardedClient {
    // Block 0 is reserved for health checks
    blocks: vec![0],
    slots: (0..16).collect(),
    adapter_id: None,
};
let batch = Batch {
    id: u64::MAX,
@@ -429,6 +429,7 @@ mod tests {
        stop_sequences: vec![],
    },
    top_n_tokens: 0,
    adapter_id: None,
},
response_tx,
span: info_span!("entry"),
@@ -351,6 +351,7 @@ impl State {
    top_n_tokens: entry.request.top_n_tokens,
    blocks,
    slots,
    adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@@ -491,6 +492,7 @@ mod tests {
        stop_sequences: vec![],
    },
    top_n_tokens: 0,
    adapter_id: None,
},
response_tx,
span: info_span!("entry"),
@@ -1,15 +1,15 @@
use crate::infer::Infer;
use crate::{
    default_parameters,
    server::{generate_internal, ComputeType},
    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Serialize, ToSchema,
};
use axum::extract::{Extension, Path};
use axum::response::{IntoResponse, Response};
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use reqwest::header::HeaderMap;
use reqwest::StatusCode;

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
@@ -64,8 +64,6 @@ pub struct MetadataServerResponse {
    pub extensions: Vec<String>,
}

// Routes

#[utoipa::path(
    post,
    tag = "Text Generation Inference",
@@ -76,13 +74,13 @@ pub struct MetadataServerResponse {
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
pub async fn kserve_health_live() -> Json<LiveResponse> {
    let data = LiveResponse { live: true };
    Ok((HeaderMap::new(), Json(data)).into_response())
    Json(data)
}

#[utoipa::path(
    post,
    get,
    tag = "Text Generation Inference",
    path = "/v2/health/ready",
    responses(
@@ -91,9 +89,9 @@ pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorRes
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
pub async fn kserve_health_ready() -> Json<ReadyResponse> {
    let data = ReadyResponse { live: true };
    Ok((HeaderMap::new(), Json(data)).into_response())
    Json(data)
}

#[utoipa::path(
@@ -106,7 +104,7 @@ pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorRe
        example = json!({"error": "No response"}))
    )
)]
pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
pub async fn kerve_server_metadata() -> Json<MetadataServerResponse> {
    let data = MetadataServerResponse {
        name: "text-generation-inference".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
@@ -116,7 +114,7 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<Error
            "metrics".to_string(),
        ],
    };
    Ok((HeaderMap::new(), Json(data)).into_response())
    Json(data)
}

#[utoipa::path(
@@ -131,13 +129,30 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<Error
)]
pub async fn kserve_model_metadata(
    Path((model_name, model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
) -> Json<MetadataServerResponse> {
    let data = MetadataServerResponse {
        name: model_name,
        version: model_version,
        extensions: vec!["infer".to_string(), "ready".to_string()],
    };
    Ok((HeaderMap::new(), Json(data)).into_response())
    Json(data)
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}/ready",
    responses(
        (status = 200, description = "Model version is ready", body = ReadyResponse),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_metadata_ready(
    Path((_model_name, _model_version)): Path<(String, String)>,
) -> Json<ReadyResponse> {
    let data = ReadyResponse { live: true };
    Json(data)
}

#[utoipa::path(
@@ -155,7 +170,7 @@ pub async fn kserve_model_infer(
    infer: Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let id = payload.id.clone();
    let str_inputs = payload
        .inputs
@@ -226,22 +241,5 @@ pub async fn kserve_model_infer(
        outputs: output_chunks,
    };

    Ok((HeaderMap::new(), Json(inference_output)).into_response())
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}/ready",
    responses(
        (status = 200, description = "Model version is ready", body = ReadyResponse),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_metadata_ready(
    Path((_model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let data = ReadyResponse { live: true };
    Ok((HeaderMap::new(), Json(data)).into_response())
    Ok((HeaderMap::new(), Json(inference_output)))
}
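The handlers above serve the KServe-style v2 routes. A hedged sketch of probing them once a router is running; the host and port are assumptions, and only `/v2/health/ready` and the per-model ready path appear verbatim in this diff (the liveness path follows the same v2 convention):

```bash
# Liveness and readiness probes
curl -s http://localhost:3000/v2/health/live
curl -s http://localhost:3000/v2/health/ready

# Per-model readiness, as declared in the utoipa annotation above
curl -s http://localhost:3000/v2/models/<model_name>/versions/<model_version>/ready
```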
@@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub grammar: Option<GrammarType>,

    /// Lora adapter id
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub adapter_id: Option<String>,
}

fn default_max_new_tokens() -> Option<u32> {
@@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters {
        seed: None,
        top_n_tokens: None,
        grammar: None,
        adapter_id: None,
    }
}
@@ -159,7 +159,9 @@ async fn main() -> Result<(), RouterError> {
    });

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
    let authorization_token = std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok();

    // Tokenizer instance
    // This will only be used to validate payloads
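Both the benchmarker and the router now read `HF_TOKEN` first and fall back to the legacy `HUGGING_FACE_HUB_TOKEN`. A short sketch of the resulting precedence (the token value is a placeholder):

```bash
# Preferred going forward
export HF_TOKEN=hf_xxx

# Still honoured, but only when HF_TOKEN is unset
export HUGGING_FACE_HUB_TOKEN=hf_xxx
```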
@@ -673,6 +673,7 @@ async fn completions(
    seed,
    top_n_tokens: None,
    grammar: None,
    ..Default::default()
},
})
.collect();
@@ -1115,6 +1116,7 @@ async fn chat_completions(
    seed,
    top_n_tokens: req.top_logprobs,
    grammar,
    ..Default::default()
},
};

@@ -1764,12 +1766,12 @@ pub async fn run(
#[derive(OpenApi)]
#[openapi(
    paths(
        kserve_model_infer,
        kserve_health_live,
        kserve_health_ready,
        kerve_server_metadata,
        kserve_model_metadata,
        kserve_model_metadata_ready,
        kserve_model_infer,
    ),
    components(schemas(
        InferenceOutput,
@@ -202,6 +202,7 @@ impl Validation {
    decoder_input_details,
    top_n_tokens,
    grammar,
    adapter_id,
    ..
} = request.parameters;

@@ -383,6 +384,7 @@ impl Validation {
    parameters,
    stopping_parameters,
    top_n_tokens,
    adapter_id,
})
}

@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
    pub parameters: ValidParameters,
    pub stopping_parameters: ValidStoppingParameters,
    pub top_n_tokens: u32,
    pub adapter_id: Option<String>,
}

#[derive(Error, Debug)]
@@ -4,6 +4,7 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica

unit-tests:
	pytest -s -vv -m "not private" tests
server/Makefile-lorax-punica (new file, 12 lines)
@@ -0,0 +1,12 @@
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc

build-lorax-punica:
	if [ ! -d 'lorax-punica' ]; then \
		git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
	fi
	cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
	cd lorax-punica && git submodule update --init --recursive
	cd lorax-punica/server/punica_kernels && python setup.py build

install-lorax-punica: build-lorax-punica
	cd lorax-punica/server/punica_kernels && python setup.py install
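Because this Makefile is pulled in through the `include Makefile-lorax-punica` line added to `server/Makefile`, the kernels can be built or installed from the `server/` directory. A minimal sketch, assuming a working CUDA toolchain and the pinned lorax commit above:

```bash
cd server
make build-lorax-punica     # sparse-checkout predibase/lorax and build punica_kernels
make install-lorax-punica   # build if needed, then install into the active environment
```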
@@ -19,6 +19,23 @@ def gptq_marlin_gemm(
    """
    ...

def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels. This is an extension of
    `marlin_gemm` that supports 2:4 sparsity.
    """
    ...

def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
@@ -5,6 +5,7 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
        "Marlin gemm with GPTQ compatibility");
  m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
  m.def("gptq_marlin_repack", &gptq_marlin_repack,
        "Repack GPTQ parameters for Marlin");
  m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
@@ -12,6 +12,13 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                  torch::Tensor &b_meta,
                                  torch::Tensor &b_scales,
                                  torch::Tensor &workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);
51
server/marlin/marlin_kernels/sparse/common/base.h
Normal file
@@ -0,0 +1,51 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace marlin_24 {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
  static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales

}  // namespace marlin_24
136
server/marlin/marlin_kernels/sparse/common/mem.h
Normal file
@ -0,0 +1,136 @@
|
||||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "base.h"
|
||||
|
||||
namespace marlin_24 {
|
||||
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
||||
// predication to handle batchsizes that are not multiples of 16.
|
||||
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
|
||||
const void* glob_ptr,
|
||||
bool pred = true,
|
||||
const bool zfill = false) {
|
||||
const int BYTES = 16;
|
||||
int src_in_bytes = (zfill ? 0 : BYTES);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Asynchronous global->shared copy
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Async copy fence.
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
||||
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
|
||||
: "=r"(a[0]), "=r"(a[1])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
||||
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
__device__ inline void barrier_acquire(int* lock, int count) {
|
||||
if (threadIdx.x == 0) {
|
||||
int state = -1;
|
||||
do
|
||||
// Guarantee that subsequent writes by this threadblock will be visible
|
||||
// globally.
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
||||
: "=r"(state)
|
||||
: "l"(lock));
|
||||
while (state != count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Release barrier and increment visitation count.
|
||||
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
if (reset) {
|
||||
lock[0] = 0;
|
||||
return;
|
||||
}
|
||||
int val = 1;
|
||||
// Make sure that all writes since acquiring this barrier are visible
|
||||
// globally, while releasing the barrier.
|
||||
asm volatile("fence.acq_rel.gpu;\n");
|
||||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
||||
:
|
||||
: "l"(lock), "r"(val));
|
||||
}
|
||||
}
|
||||
} // namespace marlin_24
|
191
server/marlin/marlin_kernels/sparse/common/mma.h
Normal file
@ -0,0 +1,191 @@
|
||||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "base.h"
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
namespace marlin_24 {
|
||||
|
||||
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
|
||||
// is not supported. On later versions of CUDA the version without ordered
|
||||
// metadata results in the following warning:
|
||||
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
|
||||
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
|
||||
// | reduced performance on some future architectures
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
||||
#define MMA_SP_INST \
|
||||
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
#else
|
||||
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
#endif
|
||||
|
||||
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
|
||||
const FragA& frag_b, FragC& frag_c, FragM& frag_m,
|
||||
const int psel) {
|
||||
const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
|
||||
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
|
||||
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
if (psel == 0) {
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
} else {
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
|
||||
float c3) {
|
||||
uint2 r;
|
||||
asm("{\n\t"
|
||||
".reg .f16 a, b, c, d; \n\t"
|
||||
"cvt.rn.f16.f32 a, %2; \n\t"
|
||||
"cvt.rn.f16.f32 b, %3; \n\t"
|
||||
"cvt.rn.f16.f32 c, %4; \n\t"
|
||||
"cvt.rn.f16.f32 d, %5; \n\t"
|
||||
"mov.b32 %0, {a, b}; \n\t"
|
||||
"mov.b32 %1, {c, d}; \n\t"
|
||||
"}"
|
||||
: "=r"(r.x), "=r"(r.y)
|
||||
: "f"(c0), "f"(c1), "f"(c2), "f"(c3));
|
||||
return r;
|
||||
}
|
||||
|
||||
// Constructs destination register by taking bytes from 2 sources (based on
|
||||
// mask)
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
||||
// values. We mostly follow the strategy in the link below, with some small
|
||||
// changes:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__device__ inline FragB dequant_4bit(int q) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd480d480;
|
||||
|
||||
FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
||||
// values. We mostly follow the strategy in the link below, with some small
|
||||
// changes:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__device__ inline FragB dequant_8bit(int q) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
|
||||
FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Multiply dequantized values by the corresponding quantization scale; used
|
||||
// only for grouped quantization.
|
||||
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
|
||||
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
|
||||
frag_b[0] = __hmul2(frag_b[0], s);
|
||||
frag_b[1] = __hmul2(frag_b[1], s);
|
||||
}
|
||||
|
||||
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
|
||||
FragS& s0, float* c4, float* c5, float* c6,
|
||||
float* c7, FragS& s1) {
|
||||
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
|
||||
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
|
||||
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
|
||||
*c3 = __fmul_rn(*c3, __half2float(s0[1].y));
|
||||
|
||||
*c4 = __fmul_rn(*c4, __half2float(s1[0].x));
|
||||
*c5 = __fmul_rn(*c5, __half2float(s1[0].y));
|
||||
*c6 = __fmul_rn(*c6, __half2float(s1[1].x));
|
||||
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
|
||||
}
|
||||
|
||||
} // namespace marlin_24
|
1125
server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu
Normal file
File diff suppressed because it is too large
@@ -12,6 +12,7 @@ setup(
            "marlin_kernels/gptq_marlin.cu",
            "marlin_kernels/gptq_marlin_repack.cu",
            "marlin_kernels/marlin_cuda_kernel.cu",
            "marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
            "marlin_kernels/ext.cpp",
        ],
        extra_compile_args=extra_compile_args,
@@ -17,7 +17,12 @@ def get_test_model():
    tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

    model = TestModel(
        torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
        "test_model_id",
        torch.nn.Linear(1, 1),
        tokenizer,
        False,
        torch.float32,
        torch.device("cpu"),
    )
    return model

1152
server/tests/utils/test_weights.py
Normal file
File diff suppressed because it is too large
13
server/text_generation_server/adapters/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004

from text_generation_server.adapters.weights import (
    AdapterBatchData,
    AdapterBatchMetadata,
)

__all__ = [
    "AdapterBatchData",
    "AdapterBatchMetadata",
]
44
server/text_generation_server/adapters/config.py
Normal file
@@ -0,0 +1,44 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

import torch

from text_generation_server.adapters.weights import AdapterWeights

if TYPE_CHECKING:
    from text_generation_server.models.model import Model


@dataclass
class ModuleMap:
    module_name: str
    module_weights: Dict[str, Tuple[torch.Tensor, str]]


@dataclass
class AdapterConfig(ABC):
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        pass
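To make the contract concrete, here is a minimal, purely illustrative subclass of `AdapterConfig`; the class name `NoopConfig` is hypothetical and not part of the diff. The real implementation added by this PR is `LoraConfig` in `adapters/lora.py` below.

```python
# Purely illustrative sketch of the AdapterConfig interface; NoopConfig is a
# hypothetical adapter type. See LoraConfig in adapters/lora.py for the real one.
from dataclasses import dataclass
from typing import Dict, Optional, Set, Tuple

from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import AdapterWeights


@dataclass
class NoopConfig(AdapterConfig):
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        # Claim no weights: empty module map, no weight names consumed.
        return {}, set()

    def load_batched_adapter_weights(
        self, model, module_map, layer_type, unused_weight_names, dynamic
    ) -> Optional[AdapterWeights]:
        # Nothing to load for this adapter type.
        return None
```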
482
server/text_generation_server/adapters/lora.py
Normal file
@ -0,0 +1,482 @@
|
||||
# Origin: https://github.com/predibase/lorax
|
||||
# Path: lorax/server/lorax_server/adapters/lora.py
|
||||
# License: Apache License Version 2.0, January 2004
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig as _LoraConfig
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||
|
||||
from text_generation_server.adapters.weights import (
|
||||
AdapterBatchMetadata,
|
||||
AdapterWeights,
|
||||
BatchAdapterWeights,
|
||||
)
|
||||
from text_generation_server.utils.sgmv import (
|
||||
BGMV_MAX_RANK,
|
||||
MAX_RANK_CUSTOM,
|
||||
get_tmp_tensors,
|
||||
orient_for_rank,
|
||||
pad_rank,
|
||||
use_cutlass_shrink,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.models.model import Model
|
||||
|
||||
|
||||
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||
block_size = size // world_size
|
||||
start = offset + rank * block_size
|
||||
stop = offset + (rank + 1) * block_size
|
||||
return start, stop
|
||||
|
||||
|
||||
def shard_on_dim(
|
||||
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
|
||||
):
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
size = t.shape[dim]
|
||||
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
|
||||
|
||||
if dim == 0:
|
||||
tensor = t[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = t[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def shard_lora_weights(
|
||||
weights_a: List[torch.Tensor],
|
||||
weights_b: List[torch.Tensor],
|
||||
split_dim: int,
|
||||
process_group: ProcessGroup,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
# [hidden_size, r]
|
||||
weights_a = [
|
||||
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
|
||||
]
|
||||
|
||||
# [r, hidden_size]
|
||||
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
|
||||
|
||||
return weights_a, weights_b
|
||||
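For intuition, `get_start_stop_idxs_for_rank` simply splits a dimension into equal contiguous blocks, one per rank, and `shard_on_dim` slices with those bounds. The sketch below uses illustrative numbers (hidden size 8, two ranks); a real run takes rank and world size from the process group.

```python
# Illustrative sketch of the sharding arithmetic used by shard_on_dim
# (example numbers; real values come from the torch.distributed process group).
def get_start_stop(offset, size, rank, world_size):
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop

hidden, world_size = 8, 2
for rank in range(world_size):
    start, stop = get_start_stop(0, hidden, rank, world_size)
    print(f"rank {rank}: rows [{start}, {stop})")
# rank 0: rows [0, 4)
# rank 1: rows [4, 8)
```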
|
||||
|
||||
@dataclass
|
||||
class LoraConfig(AdapterConfig):
|
||||
r: int
|
||||
target_modules: Optional[Union[List[str], str]]
|
||||
fan_in_fan_out: bool
|
||||
lora_alpha: int
|
||||
use_rslora: bool
|
||||
|
||||
def map_weights_for_model(
|
||||
self,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
weight_names: Tuple[str],
|
||||
) -> Tuple[ModuleMap, Set[str]]:
|
||||
adapter_weight_names = set()
|
||||
module_map = {}
|
||||
for weight_name in weight_names:
|
||||
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
|
||||
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
|
||||
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
|
||||
continue
|
||||
|
||||
module_map[weight_name] = {
|
||||
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
|
||||
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
|
||||
}
|
||||
adapter_weight_names.add(lora_a_name)
|
||||
adapter_weight_names.add(lora_b_name)
|
||||
return module_map, adapter_weight_names
|
||||
|
||||
def load_batched_adapter_weights(
|
||||
self,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
dynamic: bool,
|
||||
) -> Optional[AdapterWeights]:
|
||||
return LoraWeights.load(
|
||||
self,
|
||||
model,
|
||||
module_map,
|
||||
layer_type,
|
||||
unused_weight_names,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
|
||||
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
|
||||
return cls(
|
||||
base_model_name_or_path=hf_config.base_model_name_or_path,
|
||||
r=hf_config.r,
|
||||
target_modules=hf_config.target_modules,
|
||||
fan_in_fan_out=hf_config.fan_in_fan_out,
|
||||
lora_alpha=hf_config.lora_alpha,
|
||||
use_rslora=(
|
||||
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LoraWeights(AdapterWeights):
|
||||
"""LoRA weights for a single adapter merged across all layers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights_a: List[torch.Tensor],
|
||||
weights_b: List[torch.Tensor],
|
||||
adapter_config: LoraConfig,
|
||||
):
|
||||
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
|
||||
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
|
||||
|
||||
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
|
||||
self._is_transposed = False
|
||||
|
||||
# [num_layers, hidden_size, r]
|
||||
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
|
||||
self._weights_a = torch.stack(weights_a)
|
||||
|
||||
# [num_layers, r, hidden_size]
|
||||
self._weights_b = torch.stack(weights_b)
|
||||
|
||||
self.adapter_config = adapter_config
|
||||
|
||||
@property
|
||||
def weights_a(self) -> torch.Tensor:
|
||||
if self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_a
|
||||
|
||||
@property
|
||||
def weights_b(self) -> torch.Tensor:
|
||||
if self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_b
|
||||
|
||||
@property
|
||||
def weights_a_t(self) -> torch.Tensor:
|
||||
if not self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_a
|
||||
|
||||
@property
|
||||
def weights_b_t(self) -> torch.Tensor:
|
||||
if not self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_b
|
||||
|
||||
def _transpose_weights(self):
|
||||
if self._use_cutlass_shrink:
|
||||
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
|
||||
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
|
||||
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
|
||||
self._is_transposed = not self._is_transposed
|
||||
|
||||
@classmethod
|
||||
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
|
||||
return [BatchLoraWeights]
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
config: LoraConfig,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
) -> Optional[AdapterWeights]:
|
||||
nlayers = model.get_num_layers_for_type(layer_type)
|
||||
lora_a_list = [None] * nlayers
|
||||
lora_b_list = [None] * nlayers
|
||||
|
||||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
weight_name, layer = model.target_to_layer[key]
|
||||
base_weight = layer.base_layer.linear.weight
|
||||
base_device = base_weight.device
|
||||
|
||||
if weight_name not in module_map:
|
||||
# There is no LoRA weight for this layer type in the adapter
|
||||
return None
|
||||
|
||||
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
|
||||
lora_a = lora_a.to(base_device, model.dtype)
|
||||
|
||||
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
|
||||
lora_b = lora_b.to(base_device, model.dtype)
|
||||
|
||||
scale = get_scaling_factor(
|
||||
config.lora_alpha,
|
||||
config.r,
|
||||
uses_rslora=config.use_rslora,
|
||||
)
|
||||
|
||||
unused_weight_names.discard(lora_a_name)
|
||||
unused_weight_names.discard(lora_b_name)
|
||||
|
||||
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
|
||||
# (A * B) * C = A * (B * C)
|
||||
lora_a_list[layer_id] = lora_a.transpose(0, 1)
|
||||
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
|
||||
|
||||
# pad lora ranks to be compatible with sgmv
|
||||
lora_a_list = [
|
||||
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
|
||||
]
|
||||
lora_b_list = [
|
||||
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
|
||||
]
|
||||
|
||||
if lora_a_list:
|
||||
# update rank if it was padded
|
||||
padded_rank = lora_a_list[0].size(1)
|
||||
config.r = padded_rank
|
||||
|
||||
return LoraWeights(
|
||||
*shard_lora_weights(
|
||||
weights_a=lora_a_list,
|
||||
weights_b=lora_b_list,
|
||||
split_dim=0 if model.is_row_parallel(layer_type) else 1,
|
||||
process_group=model.process_group,
|
||||
),
|
||||
config,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankSegments:
|
||||
rank: int
|
||||
|
||||
lora_a_ptr: torch.Tensor
|
||||
lora_b_ptr: torch.Tensor
|
||||
|
||||
# prefill (sgmv)
|
||||
tmp_shrink: torch.Tensor
|
||||
tmp_expand: torch.Tensor
|
||||
segment_starts: torch.Tensor
|
||||
segment_ends: torch.Tensor
|
||||
|
||||
# decode (bgmv)
|
||||
indices: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchLoraWeights(BatchAdapterWeights):
|
||||
lora_a: Dict[int, torch.Tensor]
|
||||
lora_b: Dict[int, torch.Tensor]
|
||||
adapter_index_configs: Dict[int, LoraConfig]
|
||||
rank_data: Dict[int, RankSegments]
|
||||
use_sgmv: bool
|
||||
|
||||
def has_adapter(self, adapter_index: int) -> bool:
|
||||
return adapter_index in self.adapter_index_configs
|
||||
|
||||
def can_vectorize(self, pg: ProcessGroup) -> bool:
|
||||
return all(
|
||||
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
|
||||
for rank_data in self.rank_data.values()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def key(cls) -> str:
|
||||
return "lora"
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
self,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
meta: AdapterBatchMetadata,
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> Optional["BatchLoraWeights"]:
|
||||
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
|
||||
adapter_weights = {
|
||||
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
|
||||
}
|
||||
if not adapter_weights:
|
||||
return None
|
||||
|
||||
first_weights = next(iter(adapter_weights.values()))
|
||||
device = first_weights.weights_a.device
|
||||
segment_indices = meta.segment_indices
|
||||
|
||||
lora_a = {
|
||||
idx: adapter_weights[idx].weights_a
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
lora_b = {
|
||||
idx: adapter_weights[idx].weights_b
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
|
||||
max_rank = max(
|
||||
(
|
||||
adapter_weights[idx].lora_a_r
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
),
|
||||
default=0,
|
||||
)
|
||||
|
||||
if prefill or max_rank > BGMV_MAX_RANK:
|
||||
use_sgmv = True
|
||||
lora_a_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_a.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
lora_b_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_b.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
use_sgmv = False
|
||||
lora_a_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_a_t.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
lora_b_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_b_t.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
adapter_index_configs = {
|
||||
idx: adapter_weights[idx].adapter_config
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
|
||||
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
|
||||
|
||||
rank_indices = defaultdict(list)
|
||||
for segment_idx, adapter_idx in enumerate(segment_indices):
|
||||
if adapter_idx not in adapter_weights:
|
||||
continue
|
||||
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
|
||||
|
||||
if prefill_head_indices is not None:
|
||||
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
|
||||
for head_index in prefill_head_indices:
|
||||
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
|
||||
if head_index < meta.adapter_segments[j]:
|
||||
prefill_head_segment_ends[-1] += 1
|
||||
else:
|
||||
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
|
||||
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
|
||||
j += 1
|
||||
|
||||
rank_data = {}
|
||||
for rank, indices in rank_indices.items():
|
||||
tmp_shrink = None
|
||||
tmp_expand = None
|
||||
segment_starts = None
|
||||
segment_ends = None
|
||||
batch_indices = None
|
||||
|
||||
if use_sgmv:
|
||||
lora_a_ptr_indices = lora_a_ptr[indices]
|
||||
tmp_shrink, tmp_expand = get_tmp_tensors(
|
||||
lora_a_ptr_indices.size(0), rank, device
|
||||
)
|
||||
segment_starts = meta.adapter_segments[indices]
|
||||
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
|
||||
if prefill_head_indices is not None:
|
||||
for i, segment_index in enumerate(indices):
|
||||
segment_starts[i] = prefill_head_segment_starts[segment_index]
|
||||
segment_ends[i] = prefill_head_segment_ends[segment_index]
|
||||
else:
|
||||
rank_indices = set(indices)
|
||||
batch_indices = [
|
||||
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
|
||||
]
|
||||
batch_indices = [
|
||||
idx if idx in rank_indices else -1 for idx in batch_indices
|
||||
]
|
||||
batch_indices = torch.tensor(
|
||||
batch_indices, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
rank_data[rank] = RankSegments(
|
||||
rank=rank,
|
||||
tmp_shrink=tmp_shrink,
|
||||
tmp_expand=tmp_expand,
|
||||
lora_a_ptr=lora_a_ptr[indices],
|
||||
lora_b_ptr=lora_b_ptr[indices],
|
||||
segment_starts=segment_starts,
|
||||
segment_ends=segment_ends,
|
||||
indices=batch_indices,
|
||||
)
|
||||
|
||||
return BatchLoraWeights(
|
||||
lora_a=lora_a,
|
||||
lora_b=lora_b,
|
||||
adapter_index_configs=adapter_index_configs,
|
||||
rank_data=rank_data,
|
||||
use_sgmv=use_sgmv,
|
||||
)
|
||||
|
||||
|
||||
def get_scaling_factor(
|
||||
lora_alpha: int,
|
||||
r: int,
|
||||
uses_rslora: bool = False,
|
||||
) -> float:
|
||||
"""Computes the scaling factor for the lora weights."""
|
||||
if uses_rslora:
|
||||
return lora_alpha / (r**0.5)
|
||||
return lora_alpha / r
|
||||
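A quick worked example of the scaling factor, with illustrative values: for lora_alpha=16 and r=8, the classic LoRA scaling is 16/8 = 2.0, while rsLoRA gives 16/sqrt(8) ≈ 5.66. As the comment in `LoraWeights.load` above notes, this factor is folded into `lora_b` once at load time rather than applied on every forward pass.

```python
# Illustrative arithmetic for get_scaling_factor (example values only).
lora_alpha, r = 16, 8
print(lora_alpha / r)         # 2.0    (standard LoRA scaling)
print(lora_alpha / (r**0.5))  # ~5.657 (rsLoRA scaling)
```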
|
||||
|
||||
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
|
||||
if hasattr(v, "lora_weights"):
|
||||
return v.lora_weights
|
||||
return v
|
158
server/text_generation_server/adapters/weights.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Origin: https://github.com/predibase/lorax
|
||||
# Path: lorax/server/lorax_server/adapters/weights.py
|
||||
# License: Apache License Version 2.0, January 2004
|
||||
|
||||
from abc import ABC, abstractclassmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterBatchMetadata:
|
||||
# [batch_size]
|
||||
adapter_indices: torch.Tensor
|
||||
|
||||
# [num_adapters]
|
||||
adapter_set: Set[int]
|
||||
|
||||
# [num_segments + 1]
|
||||
adapter_segments: torch.Tensor
|
||||
|
||||
# [num_segments]
|
||||
# maps from segment index to adapter index, i.e.:
|
||||
# segment_indices[s] == adapter_indices[i]
|
||||
segment_indices: List[int]
|
||||
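To illustrate how these fields relate to one another, the sketch below derives segment boundaries and per-segment adapter indices from `adapter_indices` using plain PyTorch; this is not the server's actual segmentation code, only a demonstration of the invariants described in the comments above.

```python
# Illustrative sketch of the relationship between adapter_indices,
# adapter_segments and segment_indices (not the server's segmentation helper).
import torch

adapter_indices = torch.tensor([0, 0, 1, 1, 1, 0])  # per-token adapter index

# Boundaries where the adapter index changes, plus both ends.
change = torch.nonzero(adapter_indices[1:] != adapter_indices[:-1]).flatten() + 1
adapter_segments = torch.cat(
    [torch.tensor([0]), change, torch.tensor([adapter_indices.numel()])]
)
segment_indices = adapter_indices[adapter_segments[:-1]].tolist()

print(adapter_segments.tolist())      # [0, 2, 5, 6]  -> num_segments + 1 entries
print(segment_indices)                # [0, 1, 0]     -> adapter per segment
print(set(adapter_indices.tolist()))  # adapter_set: {0, 1}
```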
|
||||
|
||||
class AdapterWeights(ABC):
|
||||
@abstractclassmethod
|
||||
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def speculative_tokens(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
class BatchAdapterWeights(ABC):
|
||||
@abstractclassmethod
|
||||
def has_adapter(self, adapter_index: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def key(cls) -> str:
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def load(
|
||||
cls,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
meta: "AdapterBatchMetadata",
|
||||
prefill: bool,
|
||||
prefill_head_indices: torch.Tensor,
|
||||
) -> Optional["BatchAdapterWeights"]:
|
||||
pass
|
||||
|
||||
|
||||
class LayerAdapterWeights:
|
||||
"""Adapter weights that apply to a particular layer."""
|
||||
|
||||
def __init__(self):
|
||||
self.adapter_weights: Dict[int, AdapterWeights] = {}
|
||||
|
||||
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
|
||||
self.adapter_weights[adapter_idx] = weights
|
||||
|
||||
def remove_adapter(self, adapter_idx: int):
|
||||
if adapter_idx not in self.adapter_weights:
|
||||
return
|
||||
del self.adapter_weights[adapter_idx]
|
||||
|
||||
@property
|
||||
def max_speculative_tokens(self) -> int:
|
||||
return max(
|
||||
adapter_weights.speculative_tokens
|
||||
for adapter_weights in self.adapter_weights.values()
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.adapter_weights) == 0
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
meta: AdapterBatchMetadata,
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> Dict[str, BatchAdapterWeights]:
|
||||
# bucket adapters by batch class
|
||||
adapter_batch_types: Dict[
|
||||
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
|
||||
] = defaultdict(dict)
|
||||
for adapter_index, adapter_weights in self.adapter_weights.items():
|
||||
for batch_type in adapter_weights.get_batch_types():
|
||||
adapter_batch_types[batch_type][adapter_index] = adapter_weights
|
||||
|
||||
batch_data = {}
|
||||
for batch_type, adapter_weights in adapter_batch_types.items():
|
||||
batched_weights = batch_type.load(
|
||||
adapter_weights, meta, prefill, prefill_head_indices
|
||||
)
|
||||
if batched_weights is not None:
|
||||
batch_data[batch_type.key()] = batched_weights
|
||||
return batch_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterBatchData:
|
||||
meta: AdapterBatchMetadata
|
||||
|
||||
# layer type -> adapter type -> batch weight data
|
||||
data: Dict[str, Dict[str, BatchAdapterWeights]]
|
||||
|
||||
prefill: bool
|
||||
|
||||
@staticmethod
|
||||
def from_meta(
|
||||
meta: AdapterBatchMetadata,
|
||||
weights: Dict[str, LayerAdapterWeights],
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> "AdapterBatchData":
|
||||
data = {}
|
||||
for k, v in weights.items():
|
||||
if v.is_empty():
|
||||
continue
|
||||
data[k] = v.get_data(
|
||||
meta, prefill, prefill_head_indices if k == "lm_head" else None
|
||||
)
|
||||
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
|
||||
|
||||
def ranks(self) -> Set[int]:
|
||||
# TODO(travis): refactor to be less coupled to lora implementation
|
||||
ranks = set()
|
||||
for layer_data in self.data.values():
|
||||
lora_data = layer_data.get("lora")
|
||||
if lora_data is None:
|
||||
continue
|
||||
|
||||
for rank_data in lora_data.rank_data.values():
|
||||
ranks.add(rank_data.rank)
|
||||
|
||||
return ranks
|
||||
|
||||
def layer_names(self) -> Set[str]:
|
||||
return set(self.data.keys())
|
||||
|
||||
def adapter_keys(self) -> Set[str]:
|
||||
adapter_keys = set()
|
||||
for layer_data in self.data.values():
|
||||
adapter_keys.update(layer_data.keys())
|
||||
return adapter_keys
|
||||
|
||||
@property
|
||||
def max_rank(self) -> int:
|
||||
ranks = self.ranks()
|
||||
return max(ranks) if len(ranks) > 0 else 0
|
@ -79,6 +79,18 @@ def serve(
|
||||
if otlp_endpoint is not None:
|
||||
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
|
||||
|
||||
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
|
||||
|
||||
# split on comma and strip whitespace
|
||||
lora_adapter_ids = (
|
||||
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
|
||||
)
|
||||
|
||||
if len(lora_adapter_ids) > 0:
|
||||
logger.warning(
|
||||
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
|
||||
)
|
||||
|
||||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = None if dtype is None else dtype.value
|
||||
@ -93,6 +105,7 @@ def serve(
|
||||
)
|
||||
server.serve(
|
||||
model_id,
|
||||
lora_adapter_ids,
|
||||
revision,
|
||||
sharded,
|
||||
quantize,
|
||||
@ -113,6 +126,7 @@ def download_weights(
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
merge_lora: bool = False,
|
||||
):
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
@ -143,18 +157,28 @@ def download_weights(
|
||||
) is not None
|
||||
|
||||
if not is_local_model:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(
|
||||
model_id, revision=revision, filename="adapter_config.json"
|
||||
)
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
is_local_model = True
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
# TODO: maybe reverse the default value of merge_lora?
|
||||
# currently by default we don't merge the weights with the base model
|
||||
if merge_lora:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(
|
||||
model_id, revision=revision, filename="adapter_config.json"
|
||||
)
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
is_local_model = True
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
utils.peft.download_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
import json
|
||||
|
@@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d

from text_generation_server.layers.lora import (
    LoraLinear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)
@@ -7,7 +7,7 @@ if SYSTEM == "cuda":
    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "xpu":
    from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@ -1,5 +1,6 @@
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
@ -56,8 +57,6 @@ def paged_attention(
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
):
|
||||
query = query.contiguous()
|
||||
block_size = value_cache.shape[3]
|
||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
@ -67,7 +66,7 @@ def paged_attention(
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
)
|
@ -1,6 +1,7 @@
|
||||
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import awq_inference_engine # with CUDA kernels
|
||||
@ -17,7 +18,9 @@ import awq_inference_engine # with CUDA kernels
|
||||
|
||||
|
||||
class WQLinear(nn.Module):
|
||||
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
|
||||
def __init__(
|
||||
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if w_bit not in [4]:
|
||||
@ -35,10 +38,7 @@ class WQLinear(nn.Module):
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
if bias:
|
||||
self.bias = bias
|
||||
else:
|
||||
self.bias = None
|
||||
self.bias = bias
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
|
@ -82,18 +82,20 @@ elif SYSTEM == "rocm":
|
||||
|
||||
return super().forward(hidden_states), residual
|
||||
|
||||
elif SYSTEM == "xpu":
|
||||
elif SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
res_out = hidden_states
|
||||
out = ipex.llm.functional.add_layer_norm(
|
||||
residual, hidden_states, self.weight, self.bias, self.eps, True
|
||||
residual,
|
||||
hidden_states,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.eps,
|
||||
residual is not None,
|
||||
)
|
||||
if residual is not None:
|
||||
res_out = residual
|
||||
return out, res_out
|
||||
return out, residual if residual is not None else hidden_states
|
||||
|
||||
|
||||
class FastRMSNorm(nn.Module):
|
||||
@ -109,19 +111,16 @@ class FastRMSNorm(nn.Module):
|
||||
return cls(weight, eps)
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if SYSTEM == "xpu":
|
||||
residual_out = hidden_states
|
||||
if SYSTEM == "ipex":
|
||||
out = ipex.llm.functional.add_rms_norm(
|
||||
residual,
|
||||
hidden_states,
|
||||
self.weight,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
True,
|
||||
residual is not None,
|
||||
)
|
||||
if residual is not None:
|
||||
residual_out = residual
|
||||
return out, residual_out
|
||||
return out, residual if residual is not None else hidden_states
|
||||
elif hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
|
@ -1,7 +1,6 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from text_generation_server.layers.marlin import GPTQMarlinLinear
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
@ -217,7 +216,7 @@ def get_linear(weight, bias, quantize):
|
||||
qweight=weight.qweight,
|
||||
qzeros=weight.qzeros,
|
||||
scales=weight.scales,
|
||||
bias=bias is not None,
|
||||
bias=bias,
|
||||
)
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
@ -225,6 +224,9 @@ def get_linear(weight, bias, quantize):
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Linear,
|
||||
GPTQMarlin24Weight,
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
MarlinLinear,
|
||||
MarlinWeight,
|
||||
@ -235,6 +237,11 @@ def get_linear(weight, bias, quantize):
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, GPTQMarlin24Weight):
|
||||
linear = GPTQMarlin24Linear(
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
)
|
||||
elif isinstance(weight, MarlinWeight):
|
||||
linear = MarlinLinear(weight=weight, bias=bias)
|
||||
else:
|
||||
|
286
server/text_generation_server/layers/lora.py
Normal file
@ -0,0 +1,286 @@
|
||||
import math
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from accelerate import init_empty_weights
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.utils.sgmv import (
|
||||
add_lora_a_bgmv,
|
||||
add_lora_b_bgmv,
|
||||
has_sgmv,
|
||||
lora_a_sgmv_cutlass,
|
||||
lora_b_sgmv_cutlass,
|
||||
orient_for_rank,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
def __init__(
|
||||
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
|
||||
):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.layer_id = layer_id
|
||||
self.process_group = process_group
|
||||
|
||||
def forward_layer_type(
|
||||
self,
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
adapter_data: "AdapterBatchData",
|
||||
layer_type: str,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
) -> torch.Tensor:
|
||||
if adapter_data is None:
|
||||
return result
|
||||
data = adapter_data.data.get(layer_type)
|
||||
data: Optional["BatchLoraWeights"] = (
|
||||
data.get("lora") if data is not None else None
|
||||
)
|
||||
|
||||
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||
# The 'result' tensor represents the full output, which can vary in size based on
|
||||
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
|
||||
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
|
||||
# This approach ensures accurate LoRA application across various layer sizes and
|
||||
# configurations, adapting to different model architectures and parallelization strategies.
|
||||
#
|
||||
# Example scenarios where this is necessary:
|
||||
# 1. The adapter's size doesn't evenly divide across GPUs.
|
||||
# 2. We're processing the last segment which might be smaller.
|
||||
# 3. Different projection layers (q, k, v) have different sizes.
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||
else:
|
||||
proj = result
|
||||
|
||||
for r, rank_segments in data.rank_data.items():
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
|
||||
if lora_a_ptr is None or lora_b_ptr is None:
|
||||
raise ValueError("LoRA data is missing")
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
v = lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
else:
|
||||
for adapter_index in adapter_data.meta.adapter_set:
|
||||
if data is not None and data.has_adapter(adapter_index):
|
||||
adapter_mask = (
|
||||
(adapter_data.meta.adapter_indices == adapter_index)
|
||||
.to(input.dtype)
|
||||
.view(-1, 1)
|
||||
)
|
||||
layer_result = self.forward_lora(
|
||||
input, data, adapter_index, adapter_mask
|
||||
)
|
||||
result[:, start_idx:end_idx] += layer_result
|
||||
|
||||
return result
|
||||
|
||||
def forward_lora(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
data: "BatchLoraWeights",
|
||||
adapter_index: int,
|
||||
adapter_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
|
||||
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
|
||||
|
||||
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||
|
||||
a_out = input @ lora_a
|
||||
if self.process_group.size() > 1:
|
||||
a_out = self.collect_lora_a(a_out)
|
||||
|
||||
result = (a_out @ lora_b) * adapter_mask
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("Implemented in subclasses")
|
||||
|
||||
|
||||
class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_names = layer_names
|
||||
self.sizes = sizes
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
return TensorParallelMultiAdapterLinear(
|
||||
base_layer, layer_id, layer_names, sizes, process_group
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
# noop if no layer names are provided (e.g. for models without adapters)
|
||||
if self.layer_names is None:
|
||||
return result
|
||||
|
||||
# handle models like Bloom that have inputs of shape
|
||||
# (batch_size, sequence_length, hidden_size)
|
||||
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||
# for the LoRA computation, then reshape back
|
||||
prev_shape = result.shape
|
||||
is_3d = len(input.shape) >= 3
|
||||
if is_3d:
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
result = result.reshape(-1, result.shape[-1])
|
||||
|
||||
offset = 0
|
||||
for i, layer_name in enumerate(self.layer_names):
|
||||
start_idx = offset // self.process_group.size()
|
||||
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
|
||||
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
|
||||
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
|
||||
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
|
||||
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
|
||||
# different projection sizes across layers and model architectures.
|
||||
if self.sizes is not None:
|
||||
offset += self.sizes[i]
|
||||
end_idx = offset // self.process_group.size()
|
||||
else:
|
||||
end_idx = result.shape[1]
|
||||
|
||||
result = self.forward_layer_type(
|
||||
result, input, adapter_data, layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
if is_3d:
|
||||
result = result.reshape(prev_shape)
|
||||
|
||||
return result
|
||||
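To make the per-projection slicing in `forward()` above concrete, here is the offset arithmetic for a fused q/k/v projection under tensor parallelism, with illustrative sizes (grouped-query attention makes k_proj and v_proj smaller than q_proj); real values come from the model config.

```python
# Illustrative sketch of the start_idx/end_idx arithmetic in forward() above
# (example sizes only; real values come from the model configuration).
sizes = [4096, 1024, 1024]   # q_proj, k_proj, v_proj output dims (GQA example)
world_size = 2               # tensor-parallel degree

offset = 0
for name, size in zip(["q_proj", "k_proj", "v_proj"], sizes):
    start_idx = offset // world_size
    offset += size
    end_idx = offset // world_size
    print(f"{name}: result[:, {start_idx}:{end_idx}]")
# q_proj: result[:, 0:2048]
# k_proj: result[:, 2048:2560]
# v_proj: result[:, 2560:3072]
```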
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
|
||||
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
gathered_tensors = [
|
||||
torch.empty_like(a_out) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(gathered_tensors, a_out)
|
||||
return torch.cat(gathered_tensors, dim=1)
|
||||
|
||||
|
||||
class TensorParallelAdapterRowLinear(LoraLinear):
|
||||
def __init__(self, base_layer, layer_id, layer_name, process_group):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_name = layer_name
|
||||
|
||||
@classmethod
|
||||
def load(cls, base_layer, layer_id, layer_name, process_group):
|
||||
return cls(base_layer, layer_id, layer_name, process_group)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
if self.layer_name is None:
|
||||
return result
|
||||
|
||||
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
|
||||
stride = result.shape[-1] // self.process_group.size()
|
||||
start_idx = self.process_group.rank() * stride
|
||||
end_idx = (self.process_group.rank() + 1) * stride
|
||||
|
||||
self.forward_layer_type(
|
||||
result, input, adapter_data, self.layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
|
||||
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
torch.distributed.all_reduce(a_out, group=self.process_group)
|
||||
return a_out
|
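The two `collect_lora_a` variants above implement the tensor-parallel LoRA product X @ A @ B described in their comments: column-sharded layers all-gather the low-rank activation X @ A across ranks, row-sharded layers all-reduce it. Below is a minimal single-process sketch of the column-sharded case only; the "ranks" are simulated with tensor slices, the all-gather with a plain concatenation, and all shapes and the world size are made-up illustration values.

# Minimal single-process sketch of the column-sharded LoRA math that
# `collect_lora_a` (all-gather variant) relies on. No distributed backend is
# used: the "ranks" are simulated by slicing A and B, and the all-gather is a
# plain torch.cat. All shapes and the world size are illustration values.
import torch

torch.manual_seed(0)
world_size = 2                 # pretend tensor-parallel world size
n, h, r, o = 4, 8, 6, 10       # tokens, hidden size, LoRA rank, output size

X = torch.randn(n, h)
A = torch.randn(h, r)          # lora_A: hidden -> rank
B = torch.randn(r, o)          # lora_B: rank -> output

# Column-wise shards: rank k holds the k-th column slice of A and of B.
A_shards = A.chunk(world_size, dim=1)
B_shards = B.chunk(world_size, dim=1)

# Each rank computes its partial X @ A_k; gathering the partials along dim=1
# (what torch.distributed.all_gather + torch.cat do above) rebuilds X @ A.
a_out = torch.cat([X @ A_k for A_k in A_shards], dim=1)
assert torch.allclose(a_out, X @ A, atol=1e-5)

# Rank k then applies its own column shard of B, producing exactly the slice
# of the LoRA delta that lines up with its shard of the base layer's output.
delta_shards = [a_out @ B_k for B_k in B_shards]
assert torch.allclose(torch.cat(delta_shards, dim=1), X @ A @ B, atol=1e-5)

The row-sharded case follows the same shape reasoning, except that the input is already split along the hidden dimension, so the per-rank partial products are summed with an all-reduce instead of concatenated.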
@ -1,9 +1,8 @@
from dataclasses import dataclass
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM

try:
@ -177,12 +176,12 @@ class GPTQMarlinLinear(nn.Module):
        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        self.register_buffer("qweight", weight.qweight)
        self.register_buffer("scales", weight.scales)
        self.register_buffer("g_idx", weight.g_idx)
        self.register_buffer("perm", weight.perm)
        self.qweight = weight.qweight
        self.scales = weight.scales
        self.g_idx = weight.g_idx
        self.perm = weight.perm
        if bias is not None:
            self.register_buffer("bias", bias)
            self.bias = bias
        else:
            self.bias = None

@ -215,6 +214,116 @@ class GPTQMarlinLinear(nn.Module):
        return C


GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]


@dataclass
class GPTQMarlin24Weight:
    """
    GPTQ-Marlin 2:4 weights.

    Attributes:
        B (torch.Tensor): int4-quantized weights packed into int32.
        B_meta (torch.Tensor): metadata for 2:4 sparsity.
        s (torch.Tensor): float16 scales.
        bits: quantized weight size (number of bits).
    """

    B: torch.Tensor
    B_meta: torch.Tensor
    s: torch.Tensor
    bits: int

    def __post_init__(self):
        assert self.B.dtype == torch.int32
        assert self.B_meta.dtype == torch.int16
        assert self.s.dtype == torch.float16


class GPTQMarlin24Linear(nn.Module):
    def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        if weight.bits not in GPTQ_MARLIN_BITS:
            supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
            raise RuntimeError(
                f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
            )

        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
        out_features = weight.s.shape[1]
        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]

        if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
            supported_sizes = ", ".join(
                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
            )
            raise RuntimeError(
                f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
            )

        self.bits = weight.bits
        weights_per_int32 = 32 // self.bits

        assert (
            out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
        ), f"Number of output features ({out_features}) not divisible by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
        assert (
            out_features % weights_per_int32 == 0
        ), f"Number of output features ({out_features}) not divisible by weights per int32 ({weights_per_int32})"

        assert (
            in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
        ), f"Number of input features ({in_features}) not divisible by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
        if groupsize != -1 and in_features % groupsize != 0:
            raise ValueError(
                f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
            )

        self.B = weight.B
        self.B_meta = weight.B_meta
        self.s = weight.s
        if bias is not None:
            self.bias = bias
        else:
            self.bias = None

        self.workspace = torch.zeros(
            (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
            dtype=torch.int,
            device=weight.B.device,
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        C = marlin_kernels.gptq_marlin_24_gemm(
            A.view(-1, A.shape[-1]),
            self.B,
            self.B_meta,
            self.s,
            self.workspace,
            self.bits,
            A.shape[0],
            self.s.shape[1],
            A.shape[1],
        )

        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C


@dataclass
class MarlinWeight:
    """
@ -262,10 +371,10 @@ class MarlinLinear(nn.Module):
            128,
        }, f"Group size must be -1 or 128, was {groupsize}"

        self.register_buffer("B", weight.B)
        self.register_buffer("s", weight.s)
        self.B = weight.B
        self.s = weight.s
        if bias is not None:
            self.register_buffer("bias", bias)
            self.bias = bias
        else:
            self.bias = None

@ -9,7 +9,7 @@ if SYSTEM == "cuda":
    import rotary_emb
elif SYSTEM == "rocm":
    from vllm._C import ops
elif SYSTEM == "xpu":
elif SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex


@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):

            # Inplace operation, updating query and key.
            ops.rotary_embedding(query, key, head_size, cos, sin, True)
        elif SYSTEM == "xpu":
        elif SYSTEM == "ipex":
            ipex.llm.functional.rotary_embedding(
                query, key, sin, cos, query.size(-1), True
            )
@ -3,6 +3,10 @@ from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex


class LayerConcat(torch.nn.Module):
@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer):
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )
            if SYSTEM == "ipex":
                ipex.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )
            else:
                torch.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )

            if input.shape[0] == 1:
                return world_out
@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer):
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        if SYSTEM == "ipex":
            ipex.distributed.all_gather(world_output, output, group=self.process_group)
        else:
            torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output

@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer):
    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            torch.distributed.all_reduce(out, group=self.process_group)
            if SYSTEM == "ipex":
                ipex.distributed.all_reduce(out, group=self.process_group)
            else:
                torch.distributed.all_reduce(out, group=self.process_group)
        return out


@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
            if SYSTEM == "ipex":
                ipex.distributed.all_reduce(out, group=self.process_group)
            else:
                torch.distributed.all_reduce(out, group=self.process_group)
        return out
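Each collective above now branches on SYSTEM so that the ipex backend uses `ipex.distributed` while every other backend keeps `torch.distributed`. A minimal sketch of the same dispatch, factored into a helper; the wrapper function itself is hypothetical and not part of this diff, only the two underlying calls mirror it.

# Sketch of the SYSTEM-based dispatch used by the layers above, centralized
# in one place instead of repeated at every call site.
import torch
import torch.distributed

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex

def all_reduce(tensor: torch.Tensor, group) -> None:
    # ipex ships its own collectives for XPU devices; every other backend
    # falls back to torch.distributed, exactly as the layers above do.
    if SYSTEM == "ipex":
        ipex.distributed.all_reduce(tensor, group=group)
    else:
        torch.distributed.all_reduce(tensor, group=group)

Keeping the branch at the call sites, as the diff does, avoids an extra indirection on the hot path; the helper form is just a compact way to read the pattern.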
@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from typing import Optional, List
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
@ -255,6 +255,7 @@ for data in ModelType:

def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
@ -602,6 +603,7 @@ def get_model(
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
@ -538,6 +538,7 @@ class CausalLM(Model):
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
@ -22,7 +22,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM != "xpu":
if SYSTEM != "ipex":
    from vllm.model_executor.layers.fused_moe import fused_moe

from text_generation_server.layers.attention import (
@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(position_ids)
@ -19,6 +19,7 @@
# limitations under the License.

from typing import List, Optional, Tuple

import torch
import torch.distributed

@ -37,6 +38,8 @@ from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -50,43 +53,61 @@ if SYSTEM == "rocm":
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


def load_attention(config, prefix, weights):
def load_attention(config, prefix, weights, layer_id):
    # Only defined in granite.
    bias = getattr(config, "attention_bias", False)
    head_size = config.hidden_size // config.num_attention_heads
    sizes = None
    prefixes = None

    # if specific model type, load the correct attention
    if config.model_type == "phi3":
        return TensorParallelColumnLinear.load_qkv(
        prefix = f"{prefix}.qkv_proj"
        base_layer = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv_proj",
            prefix=prefix,
            weights=weights,
            bias=bias,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
        )
    elif config.model_type == "baichuan":
        return TensorParallelColumnLinear.load_qkv(
        prefix = f"{prefix}.W_pack"
        base_layer = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.W_pack",
            prefix=prefix,
            weights=weights,
            bias=bias,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
        )
    else:
        prefixes = ["q_proj", "k_proj", "v_proj"]
        sizes = [
            head_size * config.num_attention_heads,
            head_size * config.num_key_value_heads,
            head_size * config.num_key_value_heads,
        ]
        base_layer = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=bias,
        )

    # otherwise, load the default attention based on the number of heads
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
        weights=weights,
        bias=bias,
    return TensorParallelMultiAdapterLinear.load(
        base_layer=base_layer,
        layer_id=layer_id,
        layer_names=prefixes,
        sizes=sizes,
        process_group=weights.process_group,
    )
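As a quick check of the `sizes` computed in the default branch above, here is the arithmetic for a grouped-query-attention configuration; the numbers are illustrative (a Llama-3-8B-like shape), not read from any checkpoint.

# Worked example of the q/k/v `sizes` list under grouped-query attention.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

head_size = hidden_size // num_attention_heads  # 128
sizes = [
    head_size * num_attention_heads,   # q_proj output: 4096
    head_size * num_key_value_heads,   # k_proj output: 1024
    head_size * num_key_value_heads,   # v_proj output: 1024
]
print(sizes)  # [4096, 1024, 1024]

With tensor parallelism, `TensorParallelMultiAdapterLinear.forward` divides the running offset by the process-group size, so each rank slices only its own shard of the fused q/k/v output before applying the per-projection LoRA adapters.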
class FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,
        index: int,
        prefix: str,
        config,
        weights,
@ -120,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module):
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)
        self.query_key_value = load_attention(config, prefix, weights, index)
        self.index = index

        self.o_proj = TensorParallelRowLinear.load(
        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            index,
            "o_proj",
            process_group=weights.process_group,
        )

        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -144,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module):
        slots,
        input_lengths,
        max_s,
        adapter_data,
    ):
        qkv = self.query_key_value(hidden_states)
        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
@ -188,11 +219,14 @@ class FlashLlamaAttention(torch.nn.Module):
            input_lengths,
            max_s,
        )
        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )


class LlamaMLP(nn.Module):
    def __init__(self, prefix, config, weights):
    def __init__(self, prefix, config, weights, index):
        super().__init__()
        self.hidden_act = config.hidden_act
        self.act = (
@ -207,29 +241,54 @@ class LlamaMLP(nn.Module):
                ),
            )
        )
        prefixes = None
        sizes = None

        # Fuse gate and up proj
        bias = getattr(config, "mlp_bias", False)
        if config.model_type == "phi3":
            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                config,
                prefix=f"{prefix}.gate_up_proj",
                weights=weights,
                bias=bias,
            )
        else:
            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            prefixes = [f"gate_proj", f"up_proj"]
            sizes = [
                config.intermediate_size,
                config.intermediate_size,
            ]
            gate_up_proj = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                weights=weights,
                dim=0,
                bias=bias,
            )
        self.down_proj = TensorParallelRowLinear.load(

        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            gate_up_proj,
            index,
            layer_names=prefixes,
            sizes=sizes,
            process_group=weights.process_group,
        )

        down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=bias,
        )

        self.down_proj = TensorParallelAdapterRowLinear.load(
            down_proj,
            index,
            "down_proj",
            process_group=weights.process_group,
        )

        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
@ -237,7 +296,7 @@ class LlamaMLP(nn.Module):
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states):
    def forward(self, hidden_states, adapter_data):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
@ -251,20 +310,27 @@ class LlamaMLP(nn.Module):
                device="cuda",
            )
            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
            return self.down_proj(out)
            return self.down_proj(out, adapter_data)
        else:
            gate_up_states = self.gate_up_proj(hidden_states)
            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
            return self.down_proj(
                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
            )


class FlashLlamaLayer(nn.Module):
    def __init__(self, prefix, config, weights):
    def __init__(self, index, prefix, config, weights):
        super().__init__()
        self.self_attn = FlashLlamaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
            index=index,
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
        )
        self.mlp = LlamaMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
        )
        self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -287,6 +353,7 @@ class FlashLlamaLayer(nn.Module):
        slots,
        input_lengths,
        max_s,
        adapter_data,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@ -301,6 +368,7 @@ class FlashLlamaLayer(nn.Module):
            slots,
            input_lengths,
            max_s,
            adapter_data,
        )

        # faster post attention rms norm
@ -308,7 +376,7 @@ class FlashLlamaLayer(nn.Module):
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)
        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

        return mlp_output, attn_res

@ -323,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module):
        self.layers = nn.ModuleList(
            [
                FlashLlamaLayer(
                    index=layer_id,
                    prefix=(
                        f"model.layers.{layer_id}"
                        if not prefix
@ -358,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module):
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        adapter_data,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

@ -380,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
                slots,
                input_lengths,
                max_s,
                adapter_data,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
@ -421,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
@ -434,9 +506,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            max_s,
            true_max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
@ -38,6 +38,8 @@ from text_generation_server.layers import (
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):


class MistralAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
    def __init__(self, prefix: str, config, weights, layer_id):
        super().__init__()
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = TensorParallelColumnLinear.load_multi(
        query_key_value = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
            bias=False,
        )

        self.o_proj = TensorParallelRowLinear.load(
        head_size = config.hidden_size // config.num_attention_heads
        self.query_key_value = TensorParallelMultiAdapterLinear.load(
            query_key_value,
            layer_id,
            ["q_proj", "k_proj", "v_proj"],
            sizes=[
                head_size * config.num_attention_heads,
                head_size * config.num_key_value_heads,
                head_size * config.num_key_value_heads,
            ],
            process_group=weights.process_group,
        )

        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            layer_id,
            "o_proj",
            process_group=weights.process_group,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
        input_lengths,
        max_s,
        prefill_cache_indices,
        adapter_data,
    ):
        qkv = self.query_key_value(hidden_states)
        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
            max_s,
        )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )


class MistralMLP(nn.Module):
    def __init__(self, prefix, config, weights):
    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        self.hidden_act = config.hidden_act
        self.act = (
@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
        gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            gate_up_proj,
            layer_id,
            ["gate_proj", "up_proj"],
            sizes=[
                config.intermediate_size,
                config.intermediate_size,
            ],
            process_group=weights.process_group,
        )

        down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

        self.down_proj = TensorParallelAdapterRowLinear.load(
            down_proj,
            layer_id,
            "down_proj",
            process_group=weights.process_group,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states):
    def forward(self, hidden_states, adapter_data):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
                device="cuda",
            )
            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
            return self.down_proj(out)
            return self.down_proj(out, adapter_data)
        else:
            gate_up_states = self.gate_up_proj(hidden_states)
            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
            return self.down_proj(
                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
            )


class MistralLayer(nn.Module):
    def __init__(self, prefix, config, weights):
    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        self.self_attn = MistralAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            layer_id=layer_id,
        )
        self.mlp = MistralMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
        )
        self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
        input_lengths,
        max_s,
        prefill_cache_indices,
        adapter_data,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
            input_lengths,
            max_s,
            prefill_cache_indices,
            adapter_data,
        )

        # faster post attention rms norm
@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)
        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

        return mlp_output, attn_res

@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    layer_id=layer_id,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        adapter_data: Optional[torch.Tensor] = None,
    ):
        hidden_states = inputs_embeds
        # Get rotary cos and sin for this forward
@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
                input_lengths,
                max_s,
                prefill_cache_indices,
                adapter_data,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
            max_s,
            true_max_s,
            prefill_cache_indices,
            adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
@ -26,7 +26,7 @@ import numpy as np
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM != "xpu":
if SYSTEM != "ipex":
    from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.gpt_neox(
            input_ids,
@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
        # Unused here
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO This is odd but apparently pali gemma position ids start at 1.
@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
Some files were not shown because too many files have changed in this diff.