Merge branch 'main' into ci_amd3

fxmarty 2024-06-26 12:08:42 +02:00
commit 227f78f3fe
143 changed files with 6074 additions and 238 deletions

View File

@ -204,6 +204,8 @@ jobs:
needs: [build-and-push, prepare_integration_tests]
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
@ -250,4 +252,4 @@ jobs:
echo "DOCKER_VOLUME:"
echo $DOCKER_VOLUME
pytest -s -vvvvv integration-tests
pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}

View File

@ -22,5 +22,5 @@ jobs:
- name: Run tests
run: |
pip install pytest pytest-asyncio
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
make python-client-tests

View File

@ -37,5 +37,5 @@ jobs:
export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ inputs.docker_image }}
export DOCKER_DEVICES=${{ inputs.docker_devices }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv integration-tests

View File

@ -28,7 +28,7 @@ jobs:
- name: Start starcoder
run: |
docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
sleep 10
wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health

View File

@ -72,7 +72,7 @@ jobs:
- name: Run server tests
run: |
pip install pytest
export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests
- name: Pre-commit checks
run: |

View File

@ -145,6 +145,13 @@ COPY server/marlin/ .
# Build marlin kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Lorax Punica kernels
FROM kernel-builder as lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build punica kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
# CMD ["--json-output"]

View File

@ -1,3 +1,5 @@
ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
@ -37,7 +39,8 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu
USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
# Install server
COPY proto proto
@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Final image
FROM base
# Text Generation Inference base image for Intel-cpu
FROM ubuntu:22.04 as cpu
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \
ca-certificates \
make \
g++ \
git \
wget \
cmake
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
# TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
WORKDIR /usr/src
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
ENV KMP_BLOCKTIME=1
ENV KMP_TPAUSE=0
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_intel.txt && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} as final
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@ -153,7 +153,7 @@ this will impact performance.
### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
overridden with the `--otlp-service-name` argument
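For example, a launch with tracing enabled might look like the following sketch (the collector address and service name are placeholders, not defaults; the image, volume, and model id mirror the examples used elsewhere in this repository):

```shell
docker run --gpus all -p 3000:80 -v /mnt/cache:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id bigcode/starcoder \
    --otlp-endpoint http://my-otel-collector:4317 \
    --otlp-service-name text-generation-inference.demo
```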
### Architecture

View File

@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_id: None,
})
.collect();

View File

@ -147,7 +147,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Downloading tokenizer");
// Parse Huggingface hub token
let auth_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
let auth_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime

View File

@ -455,6 +455,6 @@ class DeployedModel(BaseModel):
# Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
# with model_ prefixes, since this disables guardrails for colliding fields:
# https://github.com/pydantic/pydantic/issues/9177
model_config = ConfigDict(protected_namespaces=())
model_config = ConfigDict(protected_namespaces=())
model_id: str
sha: str

View File

@ -60,6 +60,9 @@
- local: conceptual/speculation
title: Speculation (Medusa, ngram)
- local: conceptual/guidance
title: How Guidance Works (via outlines)
title: How Guidance Works (via outlines)
- local: conceptual/lora
title: LoRA (Low-Rank Adaptation)
title: Conceptual Guides

View File

@ -416,6 +416,14 @@ Options:
[env: MAX_CLIENT_BATCH_SIZE=]
[default: 4]
```
## LORA_ADAPTERS
```shell
--lora-adapters <LORA_ADAPTERS>
Lora Adapters: a comma-separated list of adapter ids (e.g. `repo/adapter1,repo/adapter2`) to load during startup; they will be available to callers via the `adapter_id` field in a request
[env: LORA_ADAPTERS=]
```
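As an illustrative sketch (the adapter ids below are taken from the LoRA guide added in this change; the base model id is only a placeholder):

```shell
text-generation-launcher --model-id mistralai/Mistral-7B-v0.1 \
    --lora-adapters predibase/customer_support,predibase/dbpedia

# or, equivalently, through the environment variable
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia \
    text-generation-launcher --model-id mistralai/Mistral-7B-v0.1
```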
## HELP
```shell

View File

@ -0,0 +1,65 @@
# LoRA (Low-Rank Adaptation)
## What is LoRA?
LoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.
LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.
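Concretely, in the standard LoRA formulation (stated here for intuition, not as a description of TGI internals), a frozen base weight matrix $W_0$ is augmented with a low-rank update:

$$h = W_0 x + B A x, \qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k},\; r \ll \min(d, k)$$

Only $A$ and $B$ are trained and shipped with the adapter, while $W_0$ stays unchanged.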
## How is it used?
LoRA can be used in many ways and the community is always finding new ways to use it. Technically, LoRA fine-tunes a large language model on a small dataset, but in practice this covers a wide range of applications, such as:
- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels
## Optimizing Inference with LoRA
LoRA adapters can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but due to the awesome work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make it more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with multiple LoRA models.
## Serving multiple LoRA adapters with TGI
Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.
In practice, it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.
Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.
### Specifying LoRA models
To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:
```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
In the server logs, you will see the following message:
```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```
## Generate text
You can then use these adapters in generation requests by specifying the `adapter_id` parameter in the request payload. For example:
```bash
curl 127.0.0.1:3000/generate \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"inputs": "Hello who are you?",
"parameters": {
"max_new_tokens": 40,
"adapter_id": "predibase/customer_support"
}
}'
```
> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.
An updated tutorial with detailed examples will be published soon. Stay tuned!

View File

@ -43,6 +43,26 @@ if SYSTEM is None:
)
def pytest_addoption(parser):
parser.addoption(
"--release", action="store_true", default=False, help="run release tests"
)
def pytest_configure(config):
config.addinivalue_line("markers", "release: mark test as a release-only test")
def pytest_collection_modifyitems(config, items):
if config.getoption("--release"):
# --release given in cli: do not skip release tests
return
skip_release = pytest.mark.skip(reason="need --release option to run")
for item in items:
if "release" in item.keywords:
item.add_marker(skip_release)
class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
ignore_logprob = False

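With the hooks above, tests marked `release` are skipped by default and only run when the flag is passed explicitly, e.g. (sketch):

```shell
# regular run: release-marked tests are skipped
pytest -s -vv integration-tests

# include release tests (what CI does on main via PYTEST_FLAGS)
pytest -s -vv integration-tests --release
```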
View File

@ -15,6 +15,7 @@ async def bloom_560(bloom_560_handle):
return bloom_560_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m(bloom_560, response_snapshot):
@ -31,6 +32,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_all_params(bloom_560, response_snapshot):
@ -55,6 +57,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):

View File

@ -15,6 +15,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
return bloom_560m_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
@ -31,6 +32,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_sharded_load(
bloom_560m_sharded, generate_load, response_snapshot

View File

@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
@pytest.mark.release
def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot
):
@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
assert response == response_snapshot
@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):

View File

@ -27,6 +27,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
return flash_llama_awq_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
@ -41,6 +42,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
@ -62,6 +64,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -26,6 +26,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
@is_flaky_async(max_attempts=5)
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda", "rocm")
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
@ -47,6 +48,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, generate_load, response_snapshot
@ -57,6 +59,7 @@ async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all(
[
r.generated_text

View File

@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
return flash_falcon_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot):
@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):

View File

@ -17,6 +17,7 @@ async def flash_gemma(flash_gemma_handle):
return flash_gemma_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@ -29,6 +30,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@ -53,6 +55,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")

View File

@ -17,6 +17,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
return flash_gemma_gptq_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@ -31,6 +32,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@ -57,6 +59,7 @@ async def test_flash_gemma_gptq_all_params(
assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")

View File

@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
return flash_gpt2_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
response = await flash_gpt2.generate(
@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -23,6 +23,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@ -35,6 +36,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
@ -62,6 +64,7 @@ async def test_flash_llama_exl2_all_params(
@require_backend_async("cuda")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(

View File

@ -15,6 +15,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
return flash_llama_gptq_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@is_flaky_async(max_attempts=5)
@ -31,6 +32,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda")
@ -55,6 +57,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda")

View File

@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
return flash_llama_gptq_marlin_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(

View File

@ -20,6 +20,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
return flash_llama_marlin_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
@ -31,6 +32,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
@ -53,6 +55,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(

View File

@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):

View File

@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
return flash_neox_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox(flash_neox_sharded, response_snapshot):
response = await flash_neox_sharded.generate(
@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -38,6 +38,7 @@ def get_cow_beach():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@require_backend_async("cuda", "xpu")
@ -50,6 +51,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):

View File

@ -17,6 +17,7 @@ async def flash_phi(flash_phi_handle):
return flash_phi_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi(flash_phi, response_snapshot):
@ -29,6 +30,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi_all_params(flash_phi, response_snapshot):
@ -53,6 +55,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):

View File

@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
return flash_qwen2_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate(
@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate(
@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)

View File

@ -15,6 +15,7 @@ async def flash_santacoder(flash_santacoder_handle):
return flash_santacoder_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda", "xpu")
async def test_flash_santacoder(flash_santacoder, response_snapshot):
@ -27,6 +28,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_santacoder_load(
flash_santacoder, generate_load, response_snapshot

View File

@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle):
return flash_starcoder_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder(flash_starcoder, response_snapshot):
@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot):

View File

@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle):
return flash_starcoder2_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(

View File

@ -15,6 +15,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
return flash_starcoder_gptq_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@is_flaky_async(max_attempts=10)
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
@ -33,6 +34,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
assert response == generous_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@is_flaky_async(max_attempts=10)
async def test_flash_starcoder_gptq_default_params(
@ -55,6 +57,7 @@ async def test_flash_starcoder_gptq_default_params(
assert response == generous_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_flash_starcoder_gptq_load(

View File

@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
return non_flash_llama_grammar_handle.client
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):

View File

@ -22,6 +22,7 @@ async def llama_grammar(llama_grammar_handle):
return llama_grammar_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
@ -62,6 +63,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
assert chat_completion == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
llama_grammar,

View File

@ -45,6 +45,7 @@ async def test_idefics(idefics, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
@ -60,6 +61,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot):
chicken = get_chicken()

View File

@ -28,6 +28,7 @@ async def flash_llava_next(flash_llava_next_handle):
return flash_llava_next_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
@ -43,6 +44,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
@ -66,6 +68,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(

View File

@ -15,6 +15,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
return fused_kernel_mamba_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba(fused_kernel_mamba, response_snapshot):
@ -27,6 +28,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@ -54,6 +56,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@require_backend_async("cuda")
async def test_mamba_load(

View File

@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle):
return mpt_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
response = await mpt_sharded.generate(
@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -14,6 +14,7 @@ async def mt0_base(mt0_base_handle):
return mt0_base_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base(mt0_base, response_snapshot):
response = await mt0_base.generate(
@ -28,6 +29,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_all_params(mt0_base, response_snapshot):
response = await mt0_base.generate(
@ -50,6 +52,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -15,6 +15,7 @@ async def neox(neox_handle):
return neox_handle.client
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):

View File

@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle):
return neox_sharded_handle.client
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):

View File

@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle):
return t5_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded(t5_sharded, response_snapshot):
response = await t5_sharded.generate(
@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -452,6 +452,11 @@ struct Args {
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
/// Lora Adapters: a comma-separated list of adapter ids (e.g. `repo/adapter1,repo/adapter2`) to load
/// during startup; they will be available to callers via the `adapter_id` field in a request.
#[clap(long, env)]
lora_adapters: Option<String>,
}
#[derive(Debug)]
@ -485,6 +490,7 @@ fn shard_manager(
max_total_tokens: usize,
max_batch_size: Option<usize>,
max_input_tokens: usize,
lora_adapters: Option<String>,
otlp_endpoint: Option<String>,
otlp_service_name: String,
log_level: LevelFilter,
@ -620,6 +626,11 @@ fn shard_manager(
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
}
// Lora Adapters
if let Some(lora_adapters) = lora_adapters {
envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -762,7 +773,7 @@ fn num_cuda_devices() -> Option<usize> {
Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
Ok(devices) => devices,
Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
}
},
};
let n_devices = devices.split(',').count();
Some(n_devices)
@ -1060,6 +1071,7 @@ fn spawn_shards(
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
let max_batch_size = args.max_batch_size;
let lora_adapters = args.lora_adapters.clone();
thread::spawn(move || {
shard_manager(
model_id,
@ -1085,6 +1097,7 @@ fn spawn_shards(
max_total_tokens,
max_batch_size,
max_input_tokens,
lora_adapters,
otlp_endpoint,
otlp_service_name,
max_log_level,
@ -1225,7 +1238,6 @@ fn spawn_webserver(
router_args.push("--otlp-service-name".to_string());
router_args.push(otlp_service_name);
// CORS origins
for origin in args.cors_allow_origin.into_iter() {
router_args.push("--cors-allow-origin".to_string());

View File

@ -134,6 +134,8 @@ message Request {
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
}
message Batch {

View File

@ -177,6 +177,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;

View File

@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,

View File

@ -429,6 +429,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -351,6 +351,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -491,6 +492,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -1,15 +1,15 @@
use crate::infer::Infer;
use crate::{
default_parameters,
server::{generate_internal, ComputeType},
Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Serialize, ToSchema,
};
use axum::extract::{Extension, Path};
use axum::response::{IntoResponse, Response};
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use reqwest::header::HeaderMap;
use reqwest::StatusCode;
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
@ -64,8 +64,6 @@ pub struct MetadataServerResponse {
pub extensions: Vec<String>,
}
// Routes
#[utoipa::path(
post,
tag = "Text Generation Inference",
@ -76,13 +74,13 @@ pub struct MetadataServerResponse {
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
pub async fn kserve_health_live() -> Json<LiveResponse> {
let data = LiveResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
Json(data)
}
#[utoipa::path(
post,
get,
tag = "Text Generation Inference",
path = "/v2/health/ready",
responses(
@ -91,9 +89,9 @@ pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorRes
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
pub async fn kserve_health_ready() -> Json<ReadyResponse> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
Json(data)
}
#[utoipa::path(
@ -106,7 +104,7 @@ pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorRe
example = json!({"error": "No response"}))
)
)]
pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
pub async fn kerve_server_metadata() -> Json<MetadataServerResponse> {
let data = MetadataServerResponse {
name: "text-generation-inference".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
@ -116,7 +114,7 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<Error
"metrics".to_string(),
],
};
Ok((HeaderMap::new(), Json(data)).into_response())
Json(data)
}
#[utoipa::path(
@ -131,13 +129,30 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<Error
)]
pub async fn kserve_model_metadata(
Path((model_name, model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
) -> Json<MetadataServerResponse> {
let data = MetadataServerResponse {
name: model_name,
version: model_version,
extensions: vec!["infer".to_string(), "ready".to_string()],
};
Ok((HeaderMap::new(), Json(data)).into_response())
Json(data)
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/ready",
responses(
(status = 200, description = "Model version is ready", body = ReadyResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata_ready(
Path((_model_name, _model_version)): Path<(String, String)>,
) -> Json<ReadyResponse> {
let data = ReadyResponse { live: true };
Json(data)
}
#[utoipa::path(
@ -155,7 +170,7 @@ pub async fn kserve_model_infer(
infer: Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
let id = payload.id.clone();
let str_inputs = payload
.inputs
@ -226,22 +241,5 @@ pub async fn kserve_model_infer(
outputs: output_chunks,
};
Ok((HeaderMap::new(), Json(inference_output)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/ready",
responses(
(status = 200, description = "Model version is ready", body = ReadyResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata_ready(
Path((_model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
Ok((HeaderMap::new(), Json(inference_output)))
}

View File

@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>,
/// Lora adapter id
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
}
fn default_max_new_tokens() -> Option<u32> {
@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
adapter_id: None,
}
}

View File

@ -159,7 +159,9 @@ async fn main() -> Result<(), RouterError> {
});
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads

View File

@ -673,6 +673,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
},
})
.collect();
@ -1115,6 +1116,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
},
};
@ -1764,12 +1766,12 @@ pub async fn run(
#[derive(OpenApi)]
#[openapi(
paths(
kserve_model_infer,
kserve_health_live,
kserve_health_ready,
kerve_server_metadata,
kserve_model_metadata,
kserve_model_metadata_ready,
kserve_model_infer,
),
components(schemas(
InferenceOutput,

View File

@ -202,6 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
adapter_id,
..
} = request.parameters;
@ -383,6 +384,7 @@ impl Validation {
parameters,
stopping_parameters,
top_n_tokens,
adapter_id,
})
}
@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
pub adapter_id: Option<String>,
}
#[derive(Error, Debug)]

View File

@ -4,6 +4,7 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
unit-tests:
pytest -s -vv -m "not private" tests

View File

@ -0,0 +1,12 @@
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
build-lorax-punica:
if [ ! -d 'lorax-punica' ]; then \
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
fi
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
cd lorax-punica && git submodule update --init --recursive
cd lorax-punica/server/punica_kernels && python setup.py build
install-lorax-punica: build-lorax-punica
cd lorax-punica/server/punica_kernels && python setup.py install

View File

@ -19,6 +19,23 @@ def gptq_marlin_gemm(
"""
...
def gptq_marlin_24_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_meta: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
"""
Matrix multiplication using Marlin kernels. This is an extension of
`marlin_gemm` that supports 2:4 sparsity.
"""
...
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,

View File

@ -5,6 +5,7 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"Marlin gemm with GPTQ compatibility");
m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");

View File

@ -12,6 +12,13 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_meta,
torch::Tensor &b_scales,
torch::Tensor &workspace, int64_t num_bits,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);

View File

@ -0,0 +1,51 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace marlin_24 {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
template <int M_, int N_, int K_>
struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_;
};
using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
} // namespace marlin_24

View File

@ -0,0 +1,136 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
const void* glob_ptr,
bool pred = true,
const bool zfill = false) {
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
: "r"(smem));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
} // namespace marlin_24

View File

@ -0,0 +1,191 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
#include <cudaTypedefs.h>
namespace marlin_24 {
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier .sp::ordered_metadata should be used on instruction
// | mma instead of modifier .sp as it is expected to have substantially
// | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
#define MMA_SP_INST \
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
const FragA& frag_b, FragC& frag_c, FragM& frag_m,
const int psel) {
const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) {
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
} else {
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
float c3) {
uint2 r;
asm("{\n\t"
".reg .f16 a, b, c, d; \n\t"
"cvt.rn.f16.f32 a, %2; \n\t"
"cvt.rn.f16.f32 b, %3; \n\t"
"cvt.rn.f16.f32 c, %4; \n\t"
"cvt.rn.f16.f32 d, %5; \n\t"
"mov.b32 %0, {a, b}; \n\t"
"mov.b32 %1, {c, d}; \n\t"
"}"
: "=r"(r.x), "=r"(r.y)
: "f"(c0), "f"(c1), "f"(c2), "f"(c3));
return r;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
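// A rough sketch of the bit trick below (plain IEEE fp16 arithmetic): OR-ing a
// low nibble q into the mantissa of EX = 0x6400 (= 1024.0) produces the half
// 1024 + q, so subtracting SUB = 0x6408 (= 1032.0) yields q - 8. The high
// nibble h sits four bits up, giving 1024 + 16*h, which the fused multiply-add
// rescales by MUL = 0x2c00 (= 1/16) and offsets by ADD = 0xd480 (= -72.0) to
// arrive at h - 8 as well.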
__device__ inline FragB dequant_4bit(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
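// Same trick for 8-bit values: prmt interleaves each source byte b with the
// 0x64 exponent byte, so every half encodes 1024 + b, and subtracting the
// magic number 0x6480 (= 1152.0) leaves b - 128.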
__device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
FragS& s0, float* c4, float* c5, float* c6,
float* c7, FragS& s1) {
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
*c3 = __fmul_rn(*c3, __half2float(s0[1].y));
*c4 = __fmul_rn(*c4, __half2float(s1[0].x));
*c5 = __fmul_rn(*c5, __half2float(s1[0].y));
*c6 = __fmul_rn(*c6, __half2float(s1[1].x));
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}
} // namespace marlin_24

File diff suppressed because it is too large

View File

@ -12,6 +12,7 @@ setup(
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu",
"marlin_kernels/sparse/marlin_24_cuda_kernel.cu",
"marlin_kernels/ext.cpp",
],
extra_compile_args=extra_compile_args,

View File

@ -17,7 +17,12 @@ def get_test_model():
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
"test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
)
return model

File diff suppressed because it is too large

View File

@ -0,0 +1,13 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
]

View File

@ -0,0 +1,44 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
if TYPE_CHECKING:
from text_generation_server.models.model import Model
@dataclass
class ModuleMap:
module_name: str
module_weights: Dict[str, Tuple[torch.Tensor, str]]
@dataclass
class AdapterConfig(ABC):
base_model_name_or_path: str
@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass
@abstractmethod
def load_batched_adapter_weights(
self,
model: "Model",
module_map: ModuleMap,
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
pass

View File

@ -0,0 +1,482 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)
if TYPE_CHECKING:
from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
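# e.g. offset=0, size=4096, world_size=4: rank 1 covers [1024, 2048)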
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
def shard_lora_weights(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
split_dim: int,
process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
]
# [r, hidden_size]
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
return weights_a, weights_b
@dataclass
class LoraConfig(AdapterConfig):
r: int
target_modules: Optional[Union[List[str], str]]
fan_in_fan_out: bool
lora_alpha: int
use_rslora: bool
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
}
adapter_weight_names.add(lora_a_name)
adapter_weight_names.add(lora_b_name)
return module_map, adapter_weight_names
def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
return LoraWeights.load(
self,
model,
module_map,
layer_type,
unused_weight_names,
)
@classmethod
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
return cls(
base_model_name_or_path=hf_config.base_model_name_or_path,
r=hf_config.r,
target_modules=hf_config.target_modules,
fan_in_fan_out=hf_config.fan_in_fan_out,
lora_alpha=hf_config.lora_alpha,
use_rslora=(
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
),
)
class LoraWeights(AdapterWeights):
"""LoRA weights for a single adapter merged across all layers."""
def __init__(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
self.adapter_config = adapter_config
@property
def weights_a(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_b
@property
def weights_a_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_b
def _transpose_weights(self):
if self._use_cutlass_shrink:
            # Without cutlass shrink, SGMV and BGMV already share the same
            # orientation for A, so we only transpose it when cutlass shrink is used.
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
self._is_transposed = not self._is_transposed
@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
@classmethod
def load(
cls,
config: LoraConfig,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
) -> Optional[AdapterWeights]:
nlayers = model.get_num_layers_for_type(layer_type)
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return None
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, model.dtype)
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, model.dtype)
scale = get_scaling_factor(
config.lora_alpha,
config.r,
uses_rslora=config.use_rslora,
)
unused_weight_names.discard(lora_a_name)
unused_weight_names.discard(lora_b_name)
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv
lora_a_list = [
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
]
lora_b_list = [
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
]
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
return LoraWeights(
*shard_lora_weights(
weights_a=lora_a_list,
weights_b=lora_b_list,
split_dim=0 if model.is_row_parallel(layer_type) else 1,
process_group=model.process_group,
),
config,
)
@dataclass
class RankSegments:
rank: int
lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
# prefill (sgmv)
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor
# decode (bgmv)
indices: torch.Tensor
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool
def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs
def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
for rank_data in self.rank_data.values()
)
@classmethod
def key(cls) -> str:
return "lora"
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None
first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
max_rank = max(
(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
),
default=0,
)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
rank_data = {}
for rank, indices in rank_indices.items():
tmp_shrink = None
tmp_expand = None
segment_starts = None
segment_ends = None
batch_indices = None
if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device
)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
else:
                rank_indices_set = set(indices)
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
batch_indices = [
                    idx if idx in rank_indices_set else -1 for idx in batch_indices
]
batch_indices = torch.tensor(
batch_indices, dtype=torch.int64, device=device
)
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
)
return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
def get_scaling_factor(
lora_alpha: int,
r: int,
uses_rslora: bool = False,
) -> float:
"""Computes the scaling factor for the lora weights."""
if uses_rslora:
return lora_alpha / (r**0.5)
return lora_alpha / r
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
if hasattr(v, "lora_weights"):
return v.lora_weights
return v

View File

@ -0,0 +1,158 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type
import torch
@dataclass
class AdapterBatchMetadata:
# [batch_size]
adapter_indices: torch.Tensor
# [num_adapters]
adapter_set: Set[int]
# [num_segments + 1]
adapter_segments: torch.Tensor
# [num_segments]
    # maps each segment index to the adapter index used by that segment, i.e.
    # segment_indices[s] == adapter_indices[i] for every token position i with
    # adapter_segments[s] <= i < adapter_segments[s + 1]
segment_indices: List[int]
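# Illustrative example (not part of the upstream module): for a hypothetical batch
# of 6 tokens routed to two adapters as [0, 0, 1, 1, 1, 0], the metadata would be
#   adapter_indices  = tensor([0, 0, 1, 1, 1, 0])
#   adapter_set      = {0, 1}
#   adapter_segments = tensor([0, 2, 5, 6])  # run boundaries
#   segment_indices  = [0, 1, 0]             # adapter used by each run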
class AdapterWeights(ABC):
@abstractclassmethod
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
pass
@property
def speculative_tokens(self) -> int:
return 0
class BatchAdapterWeights(ABC):
@abstractclassmethod
def has_adapter(self, adapter_index: int) -> bool:
pass
@abstractclassmethod
def key(cls) -> str:
pass
@abstractclassmethod
def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
meta: "AdapterBatchMetadata",
prefill: bool,
prefill_head_indices: torch.Tensor,
) -> Optional["BatchAdapterWeights"]:
pass
class LayerAdapterWeights:
"""Adapter weights that apply to a particular layer."""
def __init__(self):
self.adapter_weights: Dict[int, AdapterWeights] = {}
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
self.adapter_weights[adapter_idx] = weights
def remove_adapter(self, adapter_idx: int):
if adapter_idx not in self.adapter_weights:
return
del self.adapter_weights[adapter_idx]
@property
def max_speculative_tokens(self) -> int:
return max(
adapter_weights.speculative_tokens
for adapter_weights in self.adapter_weights.values()
)
def is_empty(self) -> bool:
return len(self.adapter_weights) == 0
def get_data(
self,
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
] = defaultdict(dict)
for adapter_index, adapter_weights in self.adapter_weights.items():
for batch_type in adapter_weights.get_batch_types():
adapter_batch_types[batch_type][adapter_index] = adapter_weights
batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(
adapter_weights, meta, prefill, prefill_head_indices
)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
return batch_data
@dataclass
class AdapterBatchData:
meta: AdapterBatchMetadata
# layer type -> adapter type -> batch weight data
data: Dict[str, Dict[str, BatchAdapterWeights]]
prefill: bool
@staticmethod
def from_meta(
meta: AdapterBatchMetadata,
weights: Dict[str, LayerAdapterWeights],
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> "AdapterBatchData":
data = {}
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(
meta, prefill, prefill_head_indices if k == "lm_head" else None
)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get("lora")
if lora_data is None:
continue
for rank_data in lora_data.rank_data.values():
ranks.add(rank_data.rank)
return ranks
def layer_names(self) -> Set[str]:
return set(self.data.keys())
def adapter_keys(self) -> Set[str]:
adapter_keys = set()
for layer_data in self.data.values():
adapter_keys.update(layer_data.keys())
return adapter_keys
@property
def max_rank(self) -> int:
ranks = self.ranks()
return max(ranks) if len(ranks) > 0 else 0

View File

@ -79,6 +79,18 @@ def serve(
if otlp_endpoint is not None:
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
# split on comma and strip whitespace
lora_adapter_ids = (
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
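    # e.g. LORA_ADAPTERS="org/adapter-a,org/adapter-b" (illustrative adapter ids)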
if len(lora_adapter_ids) > 0:
logger.warning(
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
@ -93,6 +105,7 @@ def serve(
)
server.serve(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
@ -113,6 +126,7 @@ def download_weights(
logger_level: str = "INFO",
json_output: bool = False,
trust_remote_code: bool = False,
merge_lora: bool = False,
):
# Remove default handler
logger.remove()
@ -143,18 +157,28 @@ def download_weights(
) is not None
if not is_local_model:
try:
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# TODO: maybe reverse the default value of merge_lora?
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
else:
try:
utils.peft.download_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
except Exception:
pass
try:
import json

View File

@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.lora import (
LoraLinear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)

View File

@ -7,7 +7,7 @@ if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "xpu":
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

View File

@ -1,5 +1,6 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
SUPPORTS_WINDOWING = False
@ -56,8 +57,6 @@ def paged_attention(
input_lengths: torch.Tensor,
max_s: int,
):
query = query.contiguous()
block_size = value_cache.shape[3]
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
@ -67,7 +66,7 @@ def paged_attention(
softmax_scale,
block_tables,
input_lengths,
block_size,
BLOCK_SIZE,
max_s,
None,
)

View File

@ -1,6 +1,7 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
import math
from typing import Optional
import torch
import torch.nn as nn
import awq_inference_engine # with CUDA kernels
@ -17,7 +18,9 @@ import awq_inference_engine # with CUDA kernels
class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
@ -35,10 +38,7 @@ class WQLinear(nn.Module):
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
if bias:
self.bias = bias
else:
self.bias = None
self.bias = bias
@torch.no_grad()
def forward(self, x):

View File

@ -82,18 +82,20 @@ elif SYSTEM == "rocm":
return super().forward(hidden_states), residual
elif SYSTEM == "xpu":
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
res_out = hidden_states
out = ipex.llm.functional.add_layer_norm(
residual, hidden_states, self.weight, self.bias, self.eps, True
residual,
hidden_states,
self.weight,
self.bias,
self.eps,
residual is not None,
)
if residual is not None:
res_out = residual
return out, res_out
return out, residual if residual is not None else hidden_states
class FastRMSNorm(nn.Module):
@ -109,19 +111,16 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
if SYSTEM == "xpu":
residual_out = hidden_states
if SYSTEM == "ipex":
out = ipex.llm.functional.add_rms_norm(
residual,
hidden_states,
self.weight,
None,
self.variance_epsilon,
True,
residual is not None,
)
if residual is not None:
residual_out = residual
return out, residual_out
return out, residual if residual is not None else hidden_states
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual

View File

@ -1,7 +1,6 @@
from typing import Optional
import torch
from torch.nn import functional as F
from text_generation_server.layers.marlin import GPTQMarlinLinear
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "rocm":
@ -217,7 +216,7 @@ def get_linear(weight, bias, quantize):
qweight=weight.qweight,
qzeros=weight.qzeros,
scales=weight.scales,
bias=bias is not None,
bias=bias,
)
except ImportError:
raise NotImplementedError(
@ -225,6 +224,9 @@ def get_linear(weight, bias, quantize):
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Linear,
GPTQMarlin24Weight,
GPTQMarlinLinear,
GPTQMarlinWeight,
MarlinLinear,
MarlinWeight,
@ -235,6 +237,11 @@ def get_linear(weight, bias, quantize):
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlin24Linear(
weight=weight,
bias=bias,
)
elif isinstance(weight, MarlinWeight):
linear = MarlinLinear(weight=weight, bias=bias)
else:

View File

@ -0,0 +1,286 @@
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple, List
import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
add_lora_a_bgmv,
add_lora_b_bgmv,
has_sgmv,
lora_a_sgmv_cutlass,
lora_b_sgmv_cutlass,
orient_for_rank,
)
if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.adapters.lora import BatchLoraWeights
class LoraLinear(nn.Module):
def __init__(
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
):
super().__init__()
self.base_layer = base_layer
self.layer_id = layer_id
self.process_group = process_group
def forward_layer_type(
self,
result: torch.Tensor,
input: torch.Tensor,
adapter_data: "AdapterBatchData",
layer_type: str,
start_idx: int,
end_idx: int,
) -> torch.Tensor:
if adapter_data is None:
return result
data = adapter_data.data.get(layer_type)
data: Optional["BatchLoraWeights"] = (
data.get("lora") if data is not None else None
)
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
# The 'result' tensor represents the full output, which can vary in size based on
# the layer type (e.g., attention vs. feed-forward layers). We define the current
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
# This approach ensures accurate LoRA application across various layer sizes and
# configurations, adapting to different model architectures and parallelization strategies.
#
# Example scenarios where this is necessary:
# 1. The adapter's size doesn't evenly divide across GPUs.
# 2. We're processing the last segment which might be smaller.
# 3. Different projection layers (q, k, v) have different sizes.
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill
v = lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
for adapter_index in adapter_data.meta.adapter_set:
if data is not None and data.has_adapter(adapter_index):
adapter_mask = (
(adapter_data.meta.adapter_indices == adapter_index)
.to(input.dtype)
.view(-1, 1)
)
layer_result = self.forward_lora(
input, data, adapter_index, adapter_mask
)
result[:, start_idx:end_idx] += layer_result
return result
def forward_lora(
self,
input: torch.Tensor,
data: "BatchLoraWeights",
adapter_index: int,
adapter_mask: torch.Tensor,
) -> torch.Tensor:
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
lora_a = orient_for_rank(lora_a, lora_b.size(0))
a_out = input @ lora_a
if self.process_group.size() > 1:
a_out = self.collect_lora_a(a_out)
result = (a_out @ lora_b) * adapter_mask
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("Implemented in subclasses")
class TensorParallelMultiAdapterLinear(LoraLinear):
def __init__(
self,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
super().__init__(base_layer, layer_id, process_group)
self.layer_names = layer_names
self.sizes = sizes
@classmethod
def load(
cls,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
return TensorParallelMultiAdapterLinear(
base_layer, layer_id, layer_names, sizes, process_group
)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
# noop if no layer names are provided (e.g. for models without adapters)
if self.layer_names is None:
return result
# handle models like Bloom that have inputs of shape
# (batch_size, sequence_length, hidden_size)
# we need to reshape them to (batch_size * sequence_length, hidden_size)
# for the LoRA computation, then reshape back
prev_shape = result.shape
is_3d = len(input.shape) >= 3
if is_3d:
input = input.reshape(-1, input.shape[-1])
result = result.reshape(-1, result.shape[-1])
offset = 0
for i, layer_name in enumerate(self.layer_names):
start_idx = offset // self.process_group.size()
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
# different projection sizes across layers and model architectures.
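            # For example (assumed numbers), a llama-style config with hidden_size
            # 4096, 32 query heads and 8 KV heads gives head_size 128 and
            # sizes = [4096, 1024, 1024] for q/k/v.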
if self.sizes is not None:
offset += self.sizes[i]
end_idx = offset // self.process_group.size()
else:
end_idx = result.shape[1]
result = self.forward_layer_type(
result, input, adapter_data, layer_name, start_idx, end_idx
)
if is_3d:
result = result.reshape(prev_shape)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
gathered_tensors = [
torch.empty_like(a_out) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(gathered_tensors, a_out)
return torch.cat(gathered_tensors, dim=1)
class TensorParallelAdapterRowLinear(LoraLinear):
def __init__(self, base_layer, layer_id, layer_name, process_group):
super().__init__(base_layer, layer_id, process_group)
self.layer_name = layer_name
@classmethod
def load(cls, base_layer, layer_id, layer_name, process_group):
return cls(base_layer, layer_id, layer_name, process_group)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
if self.layer_name is None:
return result
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
stride = result.shape[-1] // self.process_group.size()
start_idx = self.process_group.rank() * stride
end_idx = (self.process_group.rank() + 1) * stride
self.forward_layer_type(
result, input, adapter_data, self.layer_name, start_idx, end_idx
)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
torch.distributed.all_reduce(a_out, group=self.process_group)
return a_out

View File

@ -1,9 +1,8 @@
from dataclasses import dataclass
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
try:
@ -177,12 +176,12 @@ class GPTQMarlinLinear(nn.Module):
self.bits = weight.bits
self.is_full_k = weight.is_full_k
self.register_buffer("qweight", weight.qweight)
self.register_buffer("scales", weight.scales)
self.register_buffer("g_idx", weight.g_idx)
self.register_buffer("perm", weight.perm)
self.qweight = weight.qweight
self.scales = weight.scales
self.g_idx = weight.g_idx
self.perm = weight.perm
if bias is not None:
self.register_buffer("bias", bias)
self.bias = bias
else:
self.bias = None
@ -215,6 +214,116 @@ class GPTQMarlinLinear(nn.Module):
return C
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
@dataclass
class GPTQMarlin24Weight:
"""
GPTQ-Marlin 2:4 weights.
Attributes:
        B (torch.Tensor): quantized weights packed into int32.
B_meta (torch.Tensor): metadata for 2:4 sparsity.
s (torch.Tensor): float16 scales.
        bits (int): number of bits per quantized weight (4 or 8).
"""
B: torch.Tensor
B_meta: torch.Tensor
s: torch.Tensor
bits: int
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
        if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
            supported_bits = ", ".join(
                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
            )
raise RuntimeError(
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
)
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
out_features = weight.s.shape[1]
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
supported_sizes = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
raise RuntimeError(
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)
self.bits = weight.bits
weights_per_int32 = 32 // self.bits
assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
assert (
out_features % weights_per_int32 == 0
), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"
assert (
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
if groupsize != -1 and in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
)
self.B = weight.B
self.B_meta = weight.B_meta
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
dtype=torch.int,
device=weight.B.device,
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.B_meta,
self.s,
self.workspace,
self.bits,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C
@dataclass
class MarlinWeight:
"""
@ -262,10 +371,10 @@ class MarlinLinear(nn.Module):
128,
}, f"Group size must be -1 or 128, was {groupsize}"
self.register_buffer("B", weight.B)
self.register_buffer("s", weight.s)
self.B = weight.B
self.s = weight.s
if bias is not None:
self.register_buffer("bias", bias)
self.bias = bias
else:
self.bias = None

View File

@ -9,7 +9,7 @@ if SYSTEM == "cuda":
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
elif SYSTEM == "xpu":
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif SYSTEM == "xpu":
elif SYSTEM == "ipex":
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)

View File

@ -3,6 +3,10 @@ from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
class LayerConcat(torch.nn.Module):
@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer):
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if SYSTEM == "ipex":
ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1:
return world_out
@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer):
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
torch.distributed.all_reduce(out, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out
@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out

View File

@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from typing import Optional, List
from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
@ -255,6 +255,7 @@ for data in ModelType:
def get_model(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@ -602,6 +603,7 @@ def get_model(
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))

View File

@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -538,6 +538,7 @@ class CausalLM(Model):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,

View File

@ -22,7 +22,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.layers.attention import (
@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,

View File

@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
token_embeds = self.embed_tokens(input_ids)
position_embeds = self.embed_positions(position_ids)

View File

@ -19,6 +19,7 @@
# limitations under the License.
from typing import List, Optional, Tuple
import torch
import torch.distributed
@ -37,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -50,43 +53,61 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix, weights):
def load_attention(config, prefix, weights, layer_id):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads
sizes = None
prefixes = None
# if specific model type, load the correct attention
if config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
prefix = f"{prefix}.qkv_proj"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
prefix=prefix,
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
prefix = f"{prefix}.W_pack"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
prefix=prefix,
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
else:
prefixes = ["q_proj", "k_proj", "v_proj"]
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
)
# otherwise, load the default attention based on the number of heads
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
index: int,
prefix: str,
config,
weights,
@ -120,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.o_proj = TensorParallelRowLinear.load(
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -144,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -188,11 +219,14 @@ class FlashLlamaAttention(torch.nn.Module):
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, index):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -207,29 +241,54 @@ class LlamaMLP(nn.Module):
),
)
)
prefixes = None
sizes = None
# Fuse gate and up proj
bias = getattr(config, "mlp_bias", False)
if config.model_type == "phi3":
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
config,
prefix=f"{prefix}.gate_up_proj",
weights=weights,
bias=bias,
)
else:
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
prefixes = [f"gate_proj", f"up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=bias,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=bias,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
@ -237,7 +296,7 @@ class LlamaMLP(nn.Module):
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states):
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
@ -251,20 +310,27 @@ class LlamaMLP(nn.Module):
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashLlamaLayer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, index, prefix, config, weights):
super().__init__()
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -287,6 +353,7 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -301,6 +368,7 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
)
# faster post attention rms norm
@ -308,7 +376,7 @@ class FlashLlamaLayer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res
@ -323,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
@ -358,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -380,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -421,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@ -434,9 +506,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -38,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):
class MistralAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load_multi(
query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
head_size = config.hidden_size // config.num_attention_heads
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states):
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
)
self.mlp = MistralMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
)
# faster post attention rms norm
@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res
@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
layer_id=layer_id,
)
for layer_id in range(config.num_hidden_layers)
]
@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
max_s,
true_max_s,
prefill_cache_indices,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
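
The Mistral hunks above follow one pattern: each fused projection (q/k/v, o_proj, gate/up, down_proj) is wrapped in an adapter-aware linear keyed by layer_id, and an adapter_data argument is threaded through every forward call. A minimal single-GPU sketch of that idea follows; the class name MultiAdapterLinear and its per-token gather are illustrative stand-ins, not TGI's sharded, Punica-kernel-backed TensorParallelMultiAdapterLinear.

import torch
import torch.nn as nn


class MultiAdapterLinear(nn.Module):
    """Base projection plus per-token LoRA deltas selected by an adapter index tensor."""

    def __init__(self, base: nn.Linear, num_adapters: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        # One (A, B) low-rank pair per adapter slot; slot 0 can stay at zero for "no adapter".
        self.lora_a = nn.Parameter(torch.zeros(num_adapters, base.in_features, rank))
        self.lora_b = nn.Parameter(torch.zeros(num_adapters, rank, base.out_features))
        self.scaling = alpha / rank

    def forward(self, hidden_states: torch.Tensor, adapter_data: torch.Tensor) -> torch.Tensor:
        # adapter_data: [num_tokens] integer tensor mapping each token to an adapter slot,
        # analogous to the adapter_data plumbed through the layers above.
        out = self.base(hidden_states)
        a = self.lora_a[adapter_data]  # [num_tokens, in_features, rank]
        b = self.lora_b[adapter_data]  # [num_tokens, rank, out_features]
        delta = torch.bmm(torch.bmm(hidden_states.unsqueeze(1), a), b).squeeze(1)
        return out + self.scaling * delta


# Usage: tokens from different requests in the same batch can target different adapters.
proj = MultiAdapterLinear(nn.Linear(32, 64, bias=False), num_adapters=3)
x = torch.randn(5, 32)
adapter_data = torch.tensor([0, 1, 1, 2, 0])
y = proj(x, adapter_data)  # [5, 64]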

View File

@ -26,7 +26,7 @@ import numpy as np
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
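
Besides adding the same optional adapter_data argument to the forward signature (a pattern repeated in the files below), the Mixtral hunks above narrow the fused-MoE import guard from "xpu" to "ipex". A hedged sketch of that conditional-import pattern follows; SYSTEM here is a stand-in string, not the real value from text_generation_server.utils.import_utils.

SYSTEM = "cuda"  # illustrative; the real value is detected at import time

if SYSTEM != "ipex":
    try:
        # vLLM's fused MoE kernel is only pulled in on non-IPEX backends.
        from vllm.model_executor.layers.fused_moe import fused_moe
    except ImportError:
        fused_moe = None  # treated as an optional dependency in this sketch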

View File

@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.gpt_neox(
input_ids,

View File

@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.

View File

@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,

View File

@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:

View File

@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
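
The GPT-NeoX, PaliGemma, Phi, Qwen2 and RW hunks above only touch the forward signature: models that never consume adapters still accept an optional adapter_data argument so the shared batching code can call every flash causal LM the same way. A minimal illustration of that convention is below; the class and its body are placeholders, not the real models.

from typing import Optional

import torch


class NoAdapterCausalLM(torch.nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,  # accepted for interface parity, unused here
    ) -> torch.Tensor:
        hidden_states = input_ids.float().unsqueeze(-1)  # placeholder for the real transformer body
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        return hidden_states


# Both call styles work, so adapter-aware and adapter-free models share one dispatch path.
model = NoAdapterCausalLM()
ids = torch.arange(4)
model(ids)
model(ids, adapter_data=torch.zeros(4, dtype=torch.int64))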

Some files were not shown because too many files have changed in this diff.