diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 22fa06e3..80f258fa 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -156,6 +156,8 @@ jobs:
     needs: build-and-push
     runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
     if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
+    env:
+      PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -178,6 +180,6 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
           export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           echo $DOCKER_IMAGE
-          pytest -s -vv integration-tests
+          pytest -s -vv integration-tests ${PYTEST_FLAGS}
diff --git a/.github/workflows/client-tests.yaml b/.github/workflows/client-tests.yaml
index ef7c217c..ff2928c4 100644
--- a/.github/workflows/client-tests.yaml
+++ b/.github/workflows/client-tests.yaml
@@ -22,5 +22,5 @@ jobs:
       - name: Run tests
         run: |
           pip install pytest pytest-asyncio
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           make python-client-tests
diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml
index 4e111afe..59a8d304 100644
--- a/.github/workflows/integration_tests.yaml
+++ b/.github/workflows/integration_tests.yaml
@@ -37,5 +37,5 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ inputs.docker_image }}
           export DOCKER_DEVICES=${{ inputs.docker_devices }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv integration-tests
diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml
index fd22e395..637df472 100644
--- a/.github/workflows/load_test.yaml
+++ b/.github/workflows/load_test.yaml
@@ -11,66 +11,24 @@ on:
       - 'main'
 jobs:
-  start-runner:
-    name: Start self-hosted EC2 runner
-    runs-on: ubuntu-latest
-    env:
-      AWS_REGION: eu-central-1
-      EC2_AMI_ID: ami-0ab09c07cfd194259
-      EC2_INSTANCE_TYPE: g5.12xlarge
-      EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
-      EC2_SECURITY_GROUP: sg-072f92ae3082936c6
-    outputs:
-      label: ${{ steps.start-ec2-runner.outputs.label }}
-      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
-    steps:
-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@v1
-        with:
-          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          aws-region: ${{ env.AWS_REGION }}
-      - name: Start EC2 runner
-        id: start-ec2-runner
-        uses: philschmid/philschmid-ec2-github-runner@main
-        with:
-          mode: start
-          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ${{ env.EC2_AMI_ID }}
-          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
-          subnet-id: ${{ env.EC2_SUBNET_ID }}
-          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
-          aws-resource-tags: > # optional, requires additional permissions
-            [
-              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
-              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
-            ]
-
   load-tests:
     concurrency:
       group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    needs: start-runner # required to start the main job when the runner is ready
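The workflow changes above switch CI from `HUGGING_FACE_HUB_TOKEN` to `HF_TOKEN`. For scripts outside CI that may still export the legacy name, here is a minimal Python sketch of the same fallback that the router and benchmarker adopt later in this diff; the helper name is illustrative and not part of the patch.

```python
import os


def resolve_hf_token() -> str | None:
    """Illustrative helper: prefer HF_TOKEN, fall back to the legacy variable."""
    # Mirrors the Rust-side `std::env::var("HF_TOKEN").or_else(...)` change below.
    return os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")


if __name__ == "__main__":
    print("token configured:", resolve_hf_token() is not None)
```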
- runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci] env: DOCKER_VOLUME: /cache steps: - name: Checkout repository uses: actions/checkout@v3 - - name: Prepare disks - run: | - sudo mkfs -t ext4 /dev/nvme1n1 - sudo mkdir ${{ env.DOCKER_VOLUME }} - sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }} - - name: Install k6 run: | curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1 - name: Start starcoder run: | - docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v ${{ env.DOCKER_VOLUME }}:/data -e HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768 + docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768 sleep 10 wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health @@ -82,27 +40,3 @@ jobs: if: ${{ always() }} run: | docker stop tgi-starcoder || true - - stop-runner: - name: Stop self-hosted EC2 runner - needs: - - start-runner - - load-tests - runs-on: ubuntu-latest - env: - AWS_REGION: eu-central-1 - if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Stop EC2 runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: stop - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index eb5a4657..f983b6ed 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -72,7 +72,7 @@ jobs: - name: Run server tests run: | pip install pytest - export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }} + export HF_TOKEN=${{ secrets.HF_TOKEN }} pytest -s -vv server/tests - name: Pre-commit checks run: | diff --git a/Dockerfile_intel b/Dockerfile_intel index 131f49ba..a41fbc1e 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,3 +1,5 @@ +ARG PLATFORM=xpu + FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src @@ -37,7 +39,8 @@ RUN cargo build --profile release-opt # Text Generation Inference base image for Intel -FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base + +FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu USER root # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it @@ -49,7 +52,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list -RUN apt-get update && apt install -y intel-basekit 
xpu-smi cmake python3-dev ninja-build +RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils # Text Generation Inference base env ENV HUGGINGFACE_HUB_CACHE=/data \ @@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl -RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed # Install server COPY proto proto @@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca # Install launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher -# Final image -FROM base +# Text Generation Inference base image for Intel-cpu +FROM ubuntu:22.04 as cpu + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + make \ + g++ \ + git \ + wget \ + cmake + +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. +# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN conda install -c conda-forge gperftools mkl + +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl + +WORKDIR /usr/src + +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a + +RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131 + +RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install + +RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . 
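The new `cpu` stage above builds the PyTorch nightly wheels, intel-extension-for-pytorch, and torch-ccl from source. A hedged smoke-test sketch (not part of the PR) that could be run inside the resulting image to confirm the stack imports; the version prints are only for inspection.

```python
# Hypothetical check script for the cpu image; module names follow the builds above.
import torch
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch  # provided by the torch-ccl build

print("torch:", torch.__version__)
print("ipex:", ipex.__version__)
# The oneCCL bindings register the "ccl" backend for torch.distributed on CPU.
print("distributed available:", torch.distributed.is_available())
```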
+ +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so +ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch +ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch +ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric +ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib +ENV KMP_BLOCKTIME=1 +ENV KMP_TPAUSE=0 +ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist +ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist +ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_intel.txt && \ + pip install ".[accelerate, peft, outlines]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher + +FROM ${PLATFORM} as final ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/README.md b/README.md index 74616748..d60c7cde 100644 --- a/README.md +++ b/README.md @@ -105,14 +105,14 @@ The Swagger UI is also available at: [https://huggingface.github.io/text-generat ### Using a private or gated model -You have the option to utilize the `HUGGING_FACE_HUB_TOKEN` environment variable for configuring the token employed by +You have the option to utilize the `HF_TOKEN` environment variable for configuring the token employed by `text-generation-inference`. This allows you to gain access to protected resources. For example, if you want to serve the gated Llama V2 model variants: 1. Go to https://huggingface.co/settings/tokens 2. Copy your cli READ token -3. Export `HUGGING_FACE_HUB_TOKEN=` +3. Export `HF_TOKEN=` or with Docker: @@ -121,7 +121,7 @@ model=meta-llama/Llama-2-7b-chat-hf volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= -docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model +docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model ``` ### A note on Shared Memory (shm) @@ -153,7 +153,8 @@ this will impact performance. ### Distributed Tracing `text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature -by setting the address to an OTLP collector with the `--otlp-endpoint` argument. +by setting the address to an OTLP collector with the `--otlp-endpoint` argument. 
The default service name can be +overridden with the `--otlp-service-name` argument ### Architecture diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index b9d80b7a..2ee3d7c5 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -147,7 +147,9 @@ fn main() -> Result<(), Box> { tracing::info!("Downloading tokenizer"); // Parse Huggingface hub token - let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + let auth_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); // Download and instantiate tokenizer // We need to download it outside of the Tokio runtime diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index eb872ee6..a56edaca 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,5 +1,5 @@ from enum import Enum -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, ConfigDict from typing import Optional, List, Union, Any from text_generation.errors import ValidationError @@ -452,5 +452,9 @@ class StreamResponse(BaseModel): # Inference API currently deployed model class DeployedModel(BaseModel): + # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members + # with model_ prefixes, since this disables guardrails for colliding fields: + # https://github.com/pydantic/pydantic/issues/9177 + model_config = ConfigDict(protected_namespaces=()) model_id: str sha: str diff --git a/docs/source/architecture.md b/docs/source/architecture.md index b7885879..a8418817 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -70,6 +70,8 @@ Options: [env: JSON_OUTPUT=] --otlp-endpoint [env: OTLP_ENDPOINT=] + --otlp-service-name + [env: OTLP_SERVICE_NAME=] --cors-allow-origin [env: CORS_ALLOW_ORIGIN=] --ngrok @@ -138,6 +140,8 @@ Serve's command line parameters on the TGI repository are these: │ --logger-level TEXT [default: INFO] │ │ --json-output --no-json-output [default: no-json-output] │ │ --otlp-endpoint TEXT [default: None] │ +│ --otlp-service-name TEXT [default: │ +│ text-generation-inference...│ │ --help Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md index b49c59c9..ef3a1db7 100644 --- a/docs/source/basic_tutorials/gated_model_access.md +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -2,13 +2,13 @@ If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens) -If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable. For example: +If you're using the CLI, set the `HF_TOKEN` environment variable. For example: ``` -export HUGGING_FACE_HUB_TOKEN= +export HF_TOKEN= ``` -If you would like to do it through Docker, you can provide your token by specifying `HUGGING_FACE_HUB_TOKEN` as shown below. +If you would like to do it through Docker, you can provide your token by specifying `HF_TOKEN` as shown below. 
```bash model=meta-llama/Llama-2-7b-chat-hf @@ -17,7 +17,7 @@ token= docker run --gpus all \ --shm-size 1g \ - -e HUGGING_FACE_HUB_TOKEN=$token \ + -e HF_TOKEN=$token \ -p 8080:80 \ -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \ --model-id $model diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index a77f25a5..5e40146f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -336,6 +336,13 @@ Options: --otlp-endpoint [env: OTLP_ENDPOINT=] +``` +## OTLP_SERVICE_NAME +```shell + --otlp-service-name + [env: OTLP_SERVICE_NAME=] + [default: text-generation-inference.router] + ``` ## CORS_ALLOW_ORIGIN ```shell diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 0b239484..f5f38ac6 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -1,42 +1,62 @@ -import sys -import subprocess -import contextlib -import pytest import asyncio -import os -import docker +import contextlib import json import math +import os +import random +import re import shutil +import subprocess +import sys import tempfile import time -import random +from typing import Dict, List, Optional -from docker.errors import NotFound -from typing import Optional, List, Dict -from syrupy.extensions.json import JSONSnapshotExtension +import docker +import pytest from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError - +from docker.errors import NotFound +from syrupy.extensions.json import JSONSnapshotExtension from text_generation import AsyncClient from text_generation.types import ( - Response, - Details, - InputToken, - Token, BestOfSequence, - Grammar, ChatComplete, ChatCompletionChunk, ChatCompletionComplete, Completion, + Details, + Grammar, + InputToken, + Response, + Token, ) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) -HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) +HF_TOKEN = os.getenv("HF_TOKEN", None) DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") DOCKER_DEVICES = os.getenv("DOCKER_DEVICES") +def pytest_addoption(parser): + parser.addoption( + "--release", action="store_true", default=False, help="run release tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "release: mark test as a release-only test") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--release"): + # --release given in cli: do not skip release tests + return + skip_release = pytest.mark.skip(reason="need --release option to run") + for item in items: + if "release" in item.keywords: + item.add_marker(skip_release) + + class ResponseComparator(JSONSnapshotExtension): rtol = 0.2 ignore_logprob = False @@ -447,8 +467,8 @@ def launcher(event_loop): if not use_flash_attention: env["USE_FLASH_ATTENTION"] = "false" - if HUGGING_FACE_HUB_TOKEN is not None: - env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN + if HF_TOKEN is not None: + env["HF_TOKEN"] = HF_TOKEN volumes = [] if DOCKER_VOLUME: diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index bdcbdc78..d413519e 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle): return bloom_560_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m(bloom_560, response_snapshot): response = await bloom_560.generate( @@ -27,6 +28,7 @@ async def 
test_bloom_560m(bloom_560, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_all_params(bloom_560, response_snapshot): response = await bloom_560.generate( @@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index 3995f9e5..f9e8ed9c 100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle): return bloom_560m_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): response = await bloom_560m_sharded.generate( @@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m_sharded_load( bloom_560m_sharded, generate_load, response_snapshot diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index cafa8ea6..0efb6693 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle): # method for it. Instead, we use the `requests` library to make the HTTP request directly. +@pytest.mark.release def test_flash_llama_completion_single_prompt( flash_llama_completion, response_snapshot ): @@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt( assert response == response_snapshot +@pytest.mark.release def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): response = requests.post( f"{flash_llama_completion.base_url}/v1/completions", @@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn assert response == response_snapshot +@pytest.mark.release async def test_flash_llama_completion_many_prompts_stream( flash_llama_completion, response_snapshot ): diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py index ead918c3..b500b15d 100644 --- a/integration-tests/models/test_flash_awq.py +++ b/integration-tests/models/test_flash_awq.py @@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle): return flash_llama_awq_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( @@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( @@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): responses = await generate_load( diff --git 
a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py index a83614ac..4cf9b171 100644 --- a/integration-tests/models/test_flash_awq_sharded.py +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): return flash_llama_awq_handle_sharded.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): response = await flash_llama_awq_sharded.generate( @@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_llama_awq_load_sharded( flash_llama_awq_sharded, generate_load, response_snapshot diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py index eac91984..0fb40fe7 100644 --- a/integration-tests/models/test_flash_falcon.py +++ b/integration-tests/models/test_flash_falcon.py @@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle): return flash_falcon_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_falcon(flash_falcon, response_snapshot): @@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_falcon_all_params(flash_falcon, response_snapshot): @@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py index 7ab43111..7bee8dea 100644 --- a/integration-tests/models/test_flash_gemma.py +++ b/integration-tests/models/test_flash_gemma.py @@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle): return flash_gemma_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma(flash_gemma, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_all_params(flash_gemma, response_snapshot): @@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py index 8ac5f5a1..79d4cf24 100644 --- a/integration-tests/models/test_flash_gemma_gptq.py +++ b/integration-tests/models/test_flash_gemma_gptq.py @@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle): return flash_gemma_gptq_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def 
test_flash_gemma_gptq_all_params( @@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params( assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_gptq_load( diff --git a/integration-tests/models/test_flash_gpt2.py b/integration-tests/models/test_flash_gpt2.py index 0c7977d0..cd73d0a3 100644 --- a/integration-tests/models/test_flash_gpt2.py +++ b/integration-tests/models/test_flash_gpt2.py @@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle): return flash_gpt2_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_gpt2(flash_gpt2, response_snapshot): response = await flash_gpt2.generate( @@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py index 18319f60..7169c999 100644 --- a/integration-tests/models/test_flash_llama_exl2.py +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle): return flash_llama_exl2_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): @@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_all_params( @@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params( assert response == ignore_logprob_response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_load( diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py index b87f054b..135f4b05 100644 --- a/integration-tests/models/test_flash_llama_gptq.py +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle): return flash_llama_gptq_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): @@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_load( diff --git a/integration-tests/models/test_flash_llama_gptq_marlin.py b/integration-tests/models/test_flash_llama_gptq_marlin.py index 9c37a644..2274abce 100644 --- a/integration-tests/models/test_flash_llama_gptq_marlin.py +++ b/integration-tests/models/test_flash_llama_gptq_marlin.py @@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): return flash_llama_gptq_marlin_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): @@ 
-26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_marlin_all_params( @@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params( assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_gptq_marlin_load( diff --git a/integration-tests/models/test_flash_llama_marlin.py b/integration-tests/models/test_flash_llama_marlin.py index e7c5ccbd..a89a1e41 100644 --- a/integration-tests/models/test_flash_llama_marlin.py +++ b/integration-tests/models/test_flash_llama_marlin.py @@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle): return flash_llama_marlin_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): @@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot): @@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_marlin_load( diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index 0289c61d..31848dae 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle): return flash_neox_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_flash_neox(flash_neox, response_snapshot): @@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_neox_sharded.py b/integration-tests/models/test_flash_neox_sharded.py index 8a491915..1f1e7225 100644 --- a/integration-tests/models/test_flash_neox_sharded.py +++ b/integration-tests/models/test_flash_neox_sharded.py @@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle): return flash_neox_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_neox(flash_neox_sharded, response_snapshot): response = await flash_neox_sharded.generate( @@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index 6be1750c..3ead3150 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -34,6 +34,7 @@ def get_cow_beach(): return f"data:image/png;base64,{encoded_string.decode('utf-8')}" +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): @@ -45,6 +46,7 @@ async def 
test_flash_pali_gemma(flash_pali_gemma, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py index 9d6ca566..73bb5edc 100644 --- a/integration-tests/models/test_flash_phi.py +++ b/integration-tests/models/test_flash_phi.py @@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle): return flash_phi_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_phi(flash_phi, response_snapshot): response = await flash_phi.generate( @@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_phi_all_params(flash_phi, response_snapshot): response = await flash_phi.generate( @@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) diff --git a/integration-tests/models/test_flash_qwen2.py b/integration-tests/models/test_flash_qwen2.py index 2963aeb4..c64f8732 100644 --- a/integration-tests/models/test_flash_qwen2.py +++ b/integration-tests/models/test_flash_qwen2.py @@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle): return flash_qwen2_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_qwen2(flash_qwen2, response_snapshot): response = await flash_qwen2.generate( @@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): response = await flash_qwen2.generate( @@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot): responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4) diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py index 0f005f15..96a36aba 100644 --- a/integration-tests/models/test_flash_santacoder.py +++ b/integration-tests/models/test_flash_santacoder.py @@ -13,6 +13,7 @@ async def flash_santacoder(flash_santacoder_handle): return flash_santacoder_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_santacoder(flash_santacoder, response_snapshot): response = await flash_santacoder.generate( @@ -23,6 +24,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_santacoder_load( flash_santacoder, generate_load, response_snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 64e8b27c..dc5a8a53 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle): return flash_starcoder_handle.client +@pytest.mark.release @pytest.mark.asyncio 
@pytest.mark.private async def test_flash_starcoder(flash_starcoder, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot): @@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot): diff --git a/integration-tests/models/test_flash_starcoder2.py b/integration-tests/models/test_flash_starcoder2.py index ea665b6c..88341cfe 100644 --- a/integration-tests/models/test_flash_starcoder2.py +++ b/integration-tests/models/test_flash_starcoder2.py @@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle): return flash_starcoder2_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder2(flash_starcoder2, response_snapshot): @@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): @@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_starcoder2_load( diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py index 329158b7..f1007d6e 100644 --- a/integration-tests/models/test_flash_starcoder_gptq.py +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -13,6 +13,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle): return flash_starcoder_gptq_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot): response = await flash_starcoder_gptq.generate( @@ -24,6 +25,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap assert response == generous_response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_starcoder_gptq_default_params( flash_starcoder_gptq, generous_response_snapshot @@ -40,6 +42,7 @@ async def test_flash_starcoder_gptq_default_params( assert response == generous_response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_flash_starcoder_gptq_load( flash_starcoder_gptq, generate_load, generous_response_snapshot diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py index ce5da8a9..4face9e1 100644 --- a/integration-tests/models/test_grammar_llama.py +++ b/integration-tests/models/test_grammar_llama.py @@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle): return non_flash_llama_grammar_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot): diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index 9c4c048e..ea25fa1c 100644 --- 
a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -22,6 +22,7 @@ async def llama_grammar(llama_grammar_handle): return llama_grammar_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot): @@ -62,6 +63,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh assert chat_completion == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_grammar_response_format_llama_error_if_tools_not_installed( llama_grammar, diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index ac807b76..b7725f0b 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -45,6 +45,7 @@ async def test_idefics(idefics, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_idefics_two_images(idefics, response_snapshot): @@ -60,6 +61,7 @@ async def test_idefics_two_images(idefics, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_idefics_load(idefics, generate_load, response_snapshot): chicken = get_chicken() diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py index f5b290b1..ea277d71 100644 --- a/integration-tests/models/test_llava_next.py +++ b/integration-tests/models/test_llava_next.py @@ -26,6 +26,7 @@ async def flash_llava_next(flash_llava_next_handle): return flash_llava_next_handle.client +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): @@ -41,6 +42,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): @@ -64,6 +66,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_load( diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index bf3701b4..bc946de8 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -13,6 +13,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle): return fused_kernel_mamba_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_mamba(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( @@ -24,6 +25,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( @@ -50,6 +52,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mamba_load( fused_kernel_mamba, generate_load, generous_response_snapshot diff --git a/integration-tests/models/test_mpt.py b/integration-tests/models/test_mpt.py index d58a8c5a..1832910a 100644 --- a/integration-tests/models/test_mpt.py +++ 
b/integration-tests/models/test_mpt.py @@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle): return mpt_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_mpt(mpt_sharded, response_snapshot): response = await mpt_sharded.generate( @@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mpt_load(mpt_sharded, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py index c877056a..e53d8ed4 100644 --- a/integration-tests/models/test_mt0_base.py +++ b/integration-tests/models/test_mt0_base.py @@ -13,6 +13,7 @@ async def mt0_base(mt0_base_handle): return mt0_base_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_mt0_base(mt0_base, response_snapshot): response = await mt0_base.generate( @@ -27,6 +28,7 @@ async def test_mt0_base(mt0_base, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mt0_base_all_params(mt0_base, response_snapshot): response = await mt0_base.generate( @@ -49,6 +51,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_mt0_base_load(mt0_base, generate_load, response_snapshot): responses = await generate_load( diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py index 7b88f86a..ee60441d 100644 --- a/integration-tests/models/test_neox.py +++ b/integration-tests/models/test_neox.py @@ -15,6 +15,7 @@ async def neox(neox_handle): return neox_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox(neox, response_snapshot): @@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox_load(neox, generate_load, response_snapshot): diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py index 8cee8765..a69227c9 100644 --- a/integration-tests/models/test_neox_sharded.py +++ b/integration-tests/models/test_neox_sharded.py @@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle): return neox_sharded_handle.client +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox(neox_sharded, response_snapshot): @@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.skip @pytest.mark.asyncio async def test_neox_load(neox_sharded, generate_load, response_snapshot): diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py index 4b4cfd98..24003024 100644 --- a/integration-tests/models/test_t5_sharded.py +++ b/integration-tests/models/test_t5_sharded.py @@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle): return t5_sharded_handle.client +@pytest.mark.release @pytest.mark.asyncio async def test_t5_sharded(t5_sharded, response_snapshot): response = await t5_sharded.generate( @@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot): assert response == response_snapshot +@pytest.mark.release @pytest.mark.asyncio async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot): responses = await generate_load( diff --git 
a/launcher/src/main.rs b/launcher/src/main.rs index 45868096..816fa5f3 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -413,6 +413,9 @@ struct Args { #[clap(long, env)] otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] cors_allow_origin: Vec, #[clap(long, env)] @@ -489,6 +492,7 @@ fn shard_manager( max_input_tokens: usize, lora_adapters: Option, otlp_endpoint: Option, + otlp_service_name: String, log_level: LevelFilter, status_sender: mpsc::Sender, shutdown: Arc, @@ -554,12 +558,16 @@ fn shard_manager( (None, Some(factor)) => Some((RopeScaling::Linear, factor)), }; - // OpenTelemetry + // OpenTelemetry Endpoint if let Some(otlp_endpoint) = otlp_endpoint { shard_args.push("--otlp-endpoint".to_string()); shard_args.push(otlp_endpoint); } + // OpenTelemetry Service Name + shard_args.push("--otlp-service-name".to_string()); + shard_args.push(otlp_service_name); + // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter. shard_args.push("--max-input-tokens".to_string()); shard_args.push(max_input_tokens.to_string()); @@ -598,7 +606,7 @@ fn shard_manager( // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HF_TOKEN".into(), api_token.into())) }; // Detect rope scaling @@ -762,7 +770,10 @@ fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver fn num_cuda_devices() -> Option { let devices = match env::var("CUDA_VISIBLE_DEVICES") { Ok(devices) => devices, - Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, + Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") { + Ok(devices) => devices, + Err(_) => env::var("ZE_AFFINITY_MASK").ok()?, + }, }; let n_devices = devices.split(',').count(); Some(n_devices) @@ -835,9 +846,9 @@ fn find_num_shards( let num_shard = match (sharded, num_shard) { (Some(true), None) => { // try to default to the number of available GPUs - tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK"); let n_devices = num_cuda_devices() - .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); + .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set"); if n_devices <= 1 { return Err(LauncherError::NotEnoughCUDADevices(format!( "`sharded` is true but only found {n_devices} CUDA devices" @@ -936,7 +947,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HF_TOKEN".into(), api_token.into())) }; // If args.weights_cache_override is some, pass it to the download process @@ -1046,6 +1057,7 @@ fn spawn_shards( let shutdown = shutdown.clone(); let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); + let otlp_service_name = args.otlp_service_name.clone(); let quantize = args.quantize; let speculate = args.speculate; let dtype = args.dtype; @@ -1087,6 +1099,7 @@ fn spawn_shards( max_input_tokens, lora_adapters, otlp_endpoint, + otlp_service_name, max_log_level, status_sender, shutdown, @@ -1220,6 +1233,11 @@ fn spawn_webserver( router_args.push(otlp_endpoint); } + // OpenTelemetry + let 
otlp_service_name = args.otlp_service_name; + router_args.push("--otlp-service-name".to_string()); + router_args.push(otlp_service_name); + // CORS origins for origin in args.cors_allow_origin.into_iter() { router_args.push("--cors-allow-origin".to_string()); @@ -1240,7 +1258,7 @@ fn spawn_webserver( // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HF_TOKEN".into(), api_token.into())) }; // Parse Compute type diff --git a/router/src/lib.rs b/router/src/lib.rs index bb407c5f..126726c6 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -576,7 +576,7 @@ impl ChatCompletion { }; Self { id: String::new(), - object: "text_completion".into(), + object: "chat.completion".into(), created, model, system_fingerprint, @@ -688,7 +688,7 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: "text_completion".to_string(), + object: "chat.completion.chunk".to_string(), created, model, system_fingerprint, diff --git a/router/src/main.rs b/router/src/main.rs index c4203dbc..a7caec2e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -65,6 +65,8 @@ struct Args { json_output: bool, #[clap(long, env)] otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] @@ -107,6 +109,7 @@ async fn main() -> Result<(), RouterError> { validation_workers, json_output, otlp_endpoint, + otlp_service_name, cors_allow_origin, ngrok, ngrok_authtoken, @@ -117,7 +120,7 @@ async fn main() -> Result<(), RouterError> { } = args; // Launch Tokio runtime - init_logging(otlp_endpoint, json_output); + init_logging(otlp_endpoint, otlp_service_name, json_output); // Validate args if max_input_tokens >= max_total_tokens { @@ -156,7 +159,9 @@ async fn main() -> Result<(), RouterError> { }); // Parse Huggingface hub token - let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + let authorization_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); // Tokenizer instance // This will only be used to validate payloads @@ -367,10 +372,11 @@ async fn main() -> Result<(), RouterError> { /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: /// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - otlp_service_name service name to appear in APM /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) /// - LOG_FORMAT may be TEXT or JSON (default to TEXT) /// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) -fn init_logging(otlp_endpoint: Option, json_output: bool) { +fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { let mut layers = Vec::new(); // STDOUT/STDERR layer @@ -401,7 +407,7 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { trace::config() .with_resource(Resource::new(vec![KeyValue::new( "service.name", - "text-generation-inference.router", + otlp_service_name, )])) .with_sampler(Sampler::AlwaysOn), ) diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py new file mode 100644 index 00000000..8f88b1f8 --- /dev/null +++ b/server/tests/utils/test_weights.py @@ -0,0 +1,1152 @@ +import pytest +import torch +from text_generation_server.utils.weights import Weights +from text_generation_server.layers.gptq import GPTQWeight +from 
text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.layers.marlin import MarlinWeight +from types import SimpleNamespace +from typing import List, Optional, Dict, Union +from pathlib import Path + +dummy_file_system = { + "test_weights": { + "layer.0.weight": torch.tensor( + [ + [1, 2], + [3, 4], + ], + dtype=torch.float32, + ), + }, + "test_weights_2": { + "layer.1337.weight": torch.tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_weights_col_packed": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_multi_weights_col": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_multi_weights_row": { + "weight.weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + }, + "test_get_weights_col_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + }, + "test_get_multi_weights_row_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_multi_weights_col_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_packed_gptq": { + "weight.qweight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32), + "weight.qzeros": torch.tensor( + [ + [0, 1], + [1, 0], + ], + dtype=torch.int32, + ), + "weight.scales": torch.tensor( + [ + [100.0, 100.0], + [100.0, 100.0], + ], + dtype=torch.float16, + ), + "gptq_bits": torch.tensor([8], dtype=torch.float32), + "gptq_groupsize": torch.tensor([2], dtype=torch.float32), + }, + "test_get_weights_col_packed_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + 
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_multi_weights_row_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_multi_weights_col_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_weights_col_exl2": { + "weight.q_weight": torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.int32, + ), + "weight.q_scale": torch.tensor([8], dtype=torch.int32), + "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), + "weight.q_groups": torch.tensor([4], dtype=torch.int16), + }, + "test_get_multi_weights_row_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, + "test_get_multi_weights_col_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, + "test_get_weights_col_packed_marlin": { + "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), + }, +} + + +class MockSlice: + def __init__(self, tensor): + self.tensor = tensor + + def get_shape(self): + return self.tensor.shape + + def __getitem__(self, idx): + return self.tensor[idx] + + +def mock_get_slice(tensor_name, filename): + tensor = dummy_file_system[filename][tensor_name] + return MockSlice(tensor) + + +def mock_handle(filename, device, dtype): + return SimpleNamespace( + get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename) + ) + + +class MockSafeOpen: + def __init__(self, filename, framework, dummy_fs): + self.filename = filename + self.framework = framework + self.dummy_fs = dummy_fs + + def keys(self): + return list(self.dummy_fs[self.filename].keys()) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class MockWeights(Weights): + def __init__( + self, + filenames: List[Union[Path, str]], + device, + dtype, + process_group, + dummy_fs, + aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None, + ): + routing = {} + self.dummy_fs = dummy_fs + for filename in filenames: + with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self.prefix = prefix + self._handles = {} + + def _get_handle(self, filename: Union[Path, str]): + if filename in 
self._handles: + return self._handles[filename] + else: + handle = mock_handle(filename, self.device, self.dtype) + self._handles[filename] = handle + return handle + + def get_shape(self, tensor_name: str): + filename, _ = self.get_filename(tensor_name) + handle = self._get_handle(filename) + return handle.get_slice(tensor_name).get_shape() + + def get_tensor(self, tensor_name: str): + filename, _ = self.get_filename(tensor_name) + handle = self._get_handle(filename) + return handle.get_slice(tensor_name).tensor + + +dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1) + + +def test_weights(): + weights = MockWeights( + [ + "test_weights", + "test_weights_2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + assert weights.get_shape("layer.0.weight") == (2, 2) + assert weights.get_tensor("layer.1337.weight").shape == (2, 4) + + +def test_get_tensor(): + weights = MockWeights( + [ + "test_weights", + "test_weights_2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + assert torch.allclose( + weights.get_tensor("layer.0.weight"), + torch.tensor( + [ + [1, 2], + [3, 4], + ], + dtype=torch.float32, + ), + ) + assert torch.allclose( + weights.get_tensor("layer.1337.weight"), + torch.tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed_block_size(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + block_sizes = 2 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_weights_col_packed_block_size_arr(): + + weights = MockWeights( + [ + "test_get_weights_col_packed", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + block_sizes = [1, 1] + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def test_get_multi_weights_col(): + weights = MockWeights( + [ + "test_get_multi_weights_col", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight", "weight"] + quantize = None + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + assert torch.allclose( + w, + torch.tensor( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [1, 2], + [3, 4], + [5, 6], + [7, 8], + ], + dtype=torch.float32, + ), + ) + + +def 
test_get_multi_weights_row(): + weights = MockWeights( + [ + "test_get_multi_weights_row", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = None + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + assert torch.allclose( + w, + torch.tensor( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + dtype=torch.float32, + ), + ) + + +# test_get_weights_col + + +def test_get_weights_col_awq(): + weights = MockWeights( + [ + "test_get_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "awq" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor( + [[100.0, 100.0], [100.0, 100.0]], + dtype=torch.float16, + ), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_gtpq(): + weights = MockWeights( + [ + "test_get_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "gptq" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_exl2(): + weights = MockWeights( + [ + "test_get_weights_col_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + 
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_weights_col_marlin(): + weights = MockWeights( + [ + "test_get_weights_col_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_weights_col( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_weights_col_packed + + +def test_get_weights_col_packed_awq(): + weights = MockWeights( + [ + "test_get_weights_col_packed_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "awq" + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +@pytest.mark.skip(reason="Review expected functionality") +def test_get_weights_col_packed_exl2(): + weights = MockWeights( + [ + "test_get_weights_col_packed_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + block_sizes = 1 + + w = weights.get_weights_col_packed( + prefix=prefix, + quantize=quantize, + block_sizes=block_sizes, + ) + + scaled_scale_max = 0.3906 * 256 + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + 
+def test_get_weights_col_packed_gptq(): + weights = MockWeights( + [ + "test_get_weights_col_packed_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight"] + quantize = "gptq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_weights_col_packed_marlin(): + weights = MockWeights( + [ + "test_get_weights_col_packed_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + print(expected_weight) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_multi_weights_col + + +def test_get_multi_weights_col_awq(): + weights = MockWeights( + [ + "test_get_multi_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight"] + quantize = "awq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_col_exl2(): + weights = MockWeights( + [ + "test_get_multi_weights_col_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + + try: + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + except ValueError as e: + assert e.args[0] == 
"get_multi_weights_col is not supported for exl2" + + +def test_get_multi_weights_col_gptq(): + weights = MockWeights( + [ + "test_get_multi_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight"] + quantize = "gptq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_col_marlin(): + weights = MockWeights( + [ + "test_get_multi_weights_col_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" + + +# test_get_multi_weights_row + + +def test_get_multi_weights_row_awq(): + weights = MockWeights( + [ + "test_get_multi_weights_row_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "awq" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=None, + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_row_exl2(): + weights = MockWeights( + [ + "test_get_multi_weights_row_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + print(w) + + scaled_scale_max = 0.3906 * 256 + expected_weight = 
Exl2Weight( + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" + assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + + +def test_get_multi_weights_row_gptq(): + weights = MockWeights( + [ + "test_get_multi_weights_row_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "gptq" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16), + g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), + bits=8.0, + groupsize=2.0, + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_row_marlin(): + weights = MockWeights( + [ + "test_get_multi_weights_row_marlin", + ], + device="cpu", + dtype=torch.float16, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "marlin" + + w = weights.get_multi_weights_row( + prefix=prefix, + quantize=quantize, + ) + + expected_weight = MarlinWeight( + B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), + ) + + assert torch.allclose(w.B, expected_weight.B), "B mismatch" + assert torch.allclose(w.s, expected_weight.s), "s mismatch" diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 41744a4d..68ae95dd 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -42,6 +42,7 @@ def serve( logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, + otlp_service_name: str = "text-generation-inference.server", max_input_tokens: Optional[int] = None, ): if sharded: @@ -76,7 +77,7 @@ def serve( # Setup OpenTelemetry distributed tracing if otlp_endpoint is not None: - setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) + setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index e6cb4edf..e74180e7 100644 --- 
a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -7,7 +7,7 @@ if SYSTEM == "cuda": from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING -elif SYSTEM == "xpu": - from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +elif SYSTEM == "ipex": + from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/ipex.py similarity index 94% rename from server/text_generation_server/layers/attention/xpu.py rename to server/text_generation_server/layers/attention/ipex.py index 8b6cb87b..bfab0119 100644 --- a/server/text_generation_server/layers/attention/xpu.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,5 +1,6 @@ import intel_extension_for_pytorch as ipex import torch +from text_generation_server.models.flash_causal_lm import BLOCK_SIZE SUPPORTS_WINDOWING = False @@ -56,8 +57,6 @@ def paged_attention( input_lengths: torch.Tensor, max_s: int, ): - query = query.contiguous() - block_size = value_cache.shape[3] return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, @@ -67,7 +66,7 @@ def paged_attention( softmax_scale, block_tables, input_lengths, - block_size, + BLOCK_SIZE, max_s, None, ) diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py index c4aa6c7d..ce5289f9 100644 --- a/server/text_generation_server/layers/layernorm.py +++ b/server/text_generation_server/layers/layernorm.py @@ -82,18 +82,20 @@ elif SYSTEM == "rocm": return super().forward(hidden_states), residual -elif SYSTEM == "xpu": +elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - res_out = hidden_states out = ipex.llm.functional.add_layer_norm( - residual, hidden_states, self.weight, self.bias, self.eps, True + residual, + hidden_states, + self.weight, + self.bias, + self.eps, + residual is not None, ) - if residual is not None: - res_out = residual - return out, res_out + return out, residual if residual is not None else hidden_states class FastRMSNorm(nn.Module): @@ -109,19 +111,16 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if SYSTEM == "xpu": - residual_out = hidden_states + if SYSTEM == "ipex": out = ipex.llm.functional.add_rms_norm( residual, hidden_states, self.weight, None, self.variance_epsilon, - True, + residual is not None, ) - if residual is not None: - residual_out = residual - return out, residual_out + return out, residual if residual is not None else hidden_states elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index c2f12189..b14005e6 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -9,7 +9,7 @@ if SYSTEM == "cuda": import rotary_emb elif SYSTEM == "rocm": from vllm._C import ops -elif SYSTEM == "xpu": +elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex @@ -69,7 
+69,7 @@ class PositionRotaryEmbedding(nn.Module): # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif SYSTEM == "xpu": + elif SYSTEM == "ipex": ipex.llm.functional.rotary_embedding( query, key, sin, cos, query.size(-1), True ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 6005f737..038de258 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -3,6 +3,10 @@ from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex class LayerConcat(torch.nn.Module): @@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) + if SYSTEM == "ipex": + ipex.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + else: + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) if input.shape[0] == 1: return world_out @@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - torch.distributed.all_gather(world_output, output, group=self.process_group) + if SYSTEM == "ipex": + ipex.distributed.all_gather(world_output, output, group=self.process_group) + else: + torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - torch.distributed.all_reduce(out, group=self.process_group) + if SYSTEM == "ipex": + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - torch.distributed.all_reduce(out, group=self.process_group) + if SYSTEM == "ipex": + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 56a112e1..9d56e4ef 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -22,7 +22,7 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from text_generation_server.utils.import_utils import SYSTEM -if SYSTEM != "xpu": +if SYSTEM != "ipex": from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.layers.attention import ( diff --git 
a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index a2361b85..2e839d15 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -26,7 +26,7 @@ import numpy as np from torch import nn from text_generation_server.utils.import_utils import SYSTEM -if SYSTEM != "xpu": +if SYSTEM != "ipex": from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 315b6831..f7678762 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -847,26 +847,43 @@ class FlashCausalLM(Model): empty_cache() element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "xpu": + if SYSTEM == "ipex" and device.type == "xpu": x = 1 else: x = BLOCK_SIZE // element_size - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] + if SYSTEM == "ipex" and device == torch.device("cpu"): + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, BLOCK_SIZE), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index e9fc471e..323fcafa 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -34,9 +34,13 @@ class FlashGPT2(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGPT2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index e5820391..d996b9c3 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -48,9 +48,13 @@ class FlashLlama(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == 
"xpu": - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index e9499781..209eca83 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -50,9 +50,13 @@ class BaseFlashMistral(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashMistral is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 4518febd..ac1fd573 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -33,9 +33,13 @@ class FlashNeoXSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 1930a55c..b1f75adc 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -34,9 +34,13 @@ class FlashRWSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 71686d48..e1a7b36e 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -37,9 +37,13 @@ class FlashSantacoderSharded(FlashCausalLM): if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype 
is None else dtype - elif SYSTEM == "xpu": - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") diff --git a/server/text_generation_server/tracing.py b/server/text_generation_server/tracing.py index bf03c379..bc7a04ee 100644 --- a/server/text_generation_server/tracing.py +++ b/server/text_generation_server/tracing.py @@ -54,10 +54,8 @@ class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor): ) -def setup_tracing(shard: int, otlp_endpoint: str): - resource = Resource.create( - attributes={"service.name": f"text-generation-inference.server-{shard}"} - ) +def setup_tracing(otlp_service_name: str, otlp_endpoint: str): + resource = Resource.create(attributes={"service.name": otlp_service_name}) span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) span_processor = BatchSpanProcessor(span_exporter) diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 3625e6f2..36d63e86 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,6 +3,7 @@ import torch from datetime import timedelta from loguru import logger +from text_generation_server.utils.import_utils import SYSTEM # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) @@ -57,14 +58,7 @@ def initialize_torch_distributed(): options.is_high_priority_stream = True options._timeout = timedelta(seconds=60) else: - try: - import oneccl_bindings_for_pytorch - - backend = "ccl" - if os.getenv("CCL_WORKER_COUNT", None) is None: - os.environ["CCL_WORKER_COUNT"] = str(1) - except ImportError: - backend = "gloo" + backend = "gloo" options = None if WORLD_SIZE == 1: @@ -75,13 +69,24 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. 
- torch.distributed.init_process_group( - backend=backend, - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, - ) + if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.distributed.init_process_group( + backend="ccl", + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + else: + torch.distributed.init_process_group( + backend=backend, + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) else: logger.warning("torch.distributed is already initialized.") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index d79e36c2..6d921721 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,14 +1,14 @@ import torch from loguru import logger +import subprocess -def is_xpu_available(): +def is_ipex_available(): try: import intel_extension_for_pytorch except ImportError: return False - - return hasattr(torch, "xpu") and torch.xpu.is_available() + return True def get_cuda_free_memory(device, memory_fraction): @@ -19,11 +19,28 @@ def get_cuda_free_memory(device, memory_fraction): def get_xpu_free_memory(device, memory_fraction): - total_gpu_memory = torch.xpu.get_device_properties(device).total_memory - free_memory = int(total_gpu_memory * 0.5) + total_memory = torch.xpu.get_device_properties(device).total_memory + device_id = device.index + query = f"xpu-smi dump -d {device_id} -m 18 -n 1" + output = subprocess.check_output(query.split()).decode("utf-8").split("\n") + used_memory = float(output[1].split(",")[-1]) * 1024 * 1024 + free_memory = int(total_memory * 0.95 - used_memory) return free_memory +def get_cpu_free_memory(device, memory_fraction): + import psutil + from text_generation_server.utils.dist import WORLD_SIZE + + mem = psutil.virtual_memory() + free_memory = int(mem.available * 0.95 / WORLD_SIZE) + return free_memory + + +def noop(*args, **kwargs): + pass + + SYSTEM = None if torch.version.hip is not None: SYSTEM = "rocm" @@ -35,18 +52,20 @@ elif torch.version.cuda is not None and torch.cuda.is_available(): empty_cache = torch.cuda.empty_cache synchronize = torch.cuda.synchronize get_free_memory = get_cuda_free_memory -elif is_xpu_available(): - SYSTEM = "xpu" - empty_cache = torch.xpu.empty_cache - synchronize = torch.xpu.synchronize - get_free_memory = get_xpu_free_memory +elif is_ipex_available(): + SYSTEM = "ipex" + if hasattr(torch, "xpu") and torch.xpu.is_available(): + empty_cache = torch.xpu.empty_cache + synchronize = torch.xpu.synchronize + get_free_memory = get_xpu_free_memory + else: + empty_cache = noop + synchronize = noop + get_free_memory = get_cpu_free_memory else: SYSTEM = "cpu" - def noop(*args, **kwargs): - pass - empty_cache = noop synchronize = noop - get_free_memory = noop + get_free_memory = get_cpu_free_memory logger.info(f"Detected system {SYSTEM}")