diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index ce1cdc33..c563fa27 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -21,9 +21,11 @@ jobs: build-and-push: outputs: docker_image: ${{ steps.final.outputs.docker_image }} + docker_volume: ${{ steps.final.outputs.docker_volume }} docker_devices: ${{ steps.final.outputs.docker_devices }} runs_on: ${{ steps.final.outputs.runs_on }} label: ${{ steps.final.outputs.label }} + extra_pytest: ${{ steps.final.outputs.extra_pytest }} concurrency: group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true @@ -44,32 +46,39 @@ jobs: cuda) export dockerfile="Dockerfile" export label_extension="" + export docker_volume="/mnt/cache" export docker_devices="" export runs_on="aws-g6-12xl-plus-priv-cache" export platform="" + export extra_pytest="" ;; rocm) export dockerfile="Dockerfile_amd" export label_extension="-rocm" export docker_devices="/dev/kfd,/dev/dri" - # TODO Re-enable when they pass. - # export runs_on="amd-gpu-tgi" - export runs_on="ubuntu-latest" + export docker_volume="/mnt" + export runs_on="amd-gpu-runners" export platform="" + export extra_pytest="-k test_flash_gemma_gptq_load" ;; intel-xpu) export dockerfile="Dockerfile_intel" export label_extension="-intel-xpu" export docker_devices="" + export docker_volume="/mnt/cache" export runs_on="ubuntu-latest" export platform="xpu" + export extra_pytest="" ;; intel-cpu) export dockerfile="Dockerfile_intel" export label_extension="-intel-cpu" - export docker_devices="" - export runs_on="ubuntu-latest" + export docker_devices="none" + export docker_volume="/mnt/cache" + # export runs_on="ubuntu-latest" + export runs_on="aws-highmemory-32-plus-priv" export platform="cpu" + export extra_pytest="-k test_flash_gemma_simple" ;; esac echo $dockerfile @@ -81,8 +90,10 @@ jobs: echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV echo "LABEL=${label_extension}" >> $GITHUB_ENV echo "PLATFORM=${platform}" >> $GITHUB_ENV + echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV + echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v3 @@ -157,16 +168,18 @@ jobs: run: | echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT" + echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT" echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT" echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT" + echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT" integration_tests: concurrency: group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true needs: build-and-push + if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' runs-on: group: ${{ needs.build-and-push.outputs.runs_on }} - if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }} steps: @@ -177,15 +190,16 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - 
python-version: "3.10" + python-version: "3.11" - name: Install run: | make install-integration-tests - name: Run tests run: | - export DOCKER_VOLUME=/mnt/cache + export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }} export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }} export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }} + export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}" export HF_TOKEN=${{ secrets.HF_TOKEN }} echo $DOCKER_IMAGE - pytest -s -vv integration-tests ${PYTEST_FLAGS} + pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST} diff --git a/Dockerfile b/Dockerfile index 80e5b681..daeb9309 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json ARG GIT_SHA ARG DOCKER_LABEL +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto @@ -39,7 +40,7 @@ COPY benchmark benchmark COPY router router COPY backends backends COPY launcher launcher -RUN cargo build --profile release-opt +RUN cargo build --profile release-opt --frozen # Python builder # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile diff --git a/Dockerfile_amd b/Dockerfile_amd index 0b059f8c..4bb6407a 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -31,6 +31,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json ARG GIT_SHA ARG DOCKER_LABEL +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto @@ -38,7 +39,7 @@ COPY benchmark benchmark COPY router router COPY backends backends COPY launcher launcher -RUN cargo build --profile release-opt +RUN cargo build --profile release-opt --frozen # Text Generation Inference base image for RoCm FROM rocm/dev-ubuntu-22.04:6.2 AS base diff --git a/Dockerfile_intel b/Dockerfile_intel index 72588b7e..4ad2db5f 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,6 +1,6 @@ ARG PLATFORM=xpu -FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json ARG GIT_SHA ARG DOCKER_LABEL +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto @@ -39,7 +40,7 @@ COPY benchmark benchmark COPY router router COPY backends backends COPY launcher launcher -RUN cargo build --profile release-opt +RUN cargo build --profile release-opt --frozen # Text Generation Inference base image for Intel @@ -52,7 +53,7 @@ ARG MAMBA_VERSION=23.1.0-1 ARG PYTHON_VERSION='3.11.10' # Automatically set by buildx ARG TARGETPLATFORM -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/conda/bin:$PATH # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. 
# Install mamba diff --git a/README.md b/README.md index df6912bf..25dbbd43 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ curl localhost:3000/v1/chat/completions \ **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. -**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above. +**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1-rocm --model-id $model` instead of the command above. To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): ``` @@ -150,7 +150,7 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= -docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model +docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model ``` ### A note on Shared Memory (shm) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2925085b..47ef9d71 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -27,3 +27,6 @@ asyncio_mode = "auto" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.isort] +profile = "black" diff --git a/docs/openapi.json b/docs/openapi.json index 67394a14..d1b60f4d 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2114,12 +2114,18 @@ "ToolType": { "oneOf": [ { - "type": "object", - "default": null, - "nullable": true + "type": "string", + "description": "Means the model can pick between generating a message or calling one or more tools.", + "enum": [ + "auto" + ] }, { - "type": "string" + "type": "string", + "description": "Means the model will not call any tool and instead generates a message.", + "enum": [ + "none" + ] }, { "type": "object", @@ -2131,13 +2137,10 @@ "$ref": "#/components/schemas/FunctionName" } } - }, - { - "type": "object", - "default": null, - "nullable": true } - ] + ], + "description": "Controls which (if any) tool is called by the model.", + "example": "auto" }, "Url": { "type": "object", diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md index ef3a1db7..cf198dbe 
100644 --- a/docs/source/basic_tutorials/gated_model_access.md +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -19,6 +19,6 @@ docker run --gpus all \ --shm-size 1g \ -e HF_TOKEN=$token \ -p 8080:80 \ - -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \ + -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 \ --model-id $model ``` diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index b7672a9f..1898b10c 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes ``` 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. @@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4 +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes-nf4 ``` You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). @@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$ TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize gptq +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize gptq ``` Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 655e6f9e..a52baedb 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -97,7 +97,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. 
```bash -docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help +docker run ghcr.io/huggingface/text-generation-inference:2.3.1 --help ``` diff --git a/flake.lock b/flake.lock index 04d386b3..aacdd30e 100644 --- a/flake.lock +++ b/flake.lock @@ -978,11 +978,11 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1728029332, - "narHash": "sha256-j0RX3a67lvi2PC5w6J5DHTxM+l96J/OV5sAf34IUfUo=", + "lastModified": 1728381423, + "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "98049f853346ca780b81fee730715c90d33ac2b4", + "rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e", "type": "github" }, "original": { diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 4c8c929f..f24fc079 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -492,6 +492,7 @@ def launcher(event_loop): try: container = client.containers.get(container_name) container.stop() + container.remove() container.wait() except NotFound: pass @@ -514,13 +515,28 @@ def launcher(event_loop): volumes = [f"{DOCKER_VOLUME}:/data"] if DOCKER_DEVICES: - devices = DOCKER_DEVICES.split(",") + if DOCKER_DEVICES.lower() == "none": + devices = [] + else: + devices = DOCKER_DEVICES.strip().split(",") visible = os.getenv("ROCR_VISIBLE_DEVICES") if visible: env["ROCR_VISIBLE_DEVICES"] = visible device_requests = [] + if not devices: + devices = None + elif devices == ["nvidia.com/gpu=all"]: + devices = None + device_requests = [ + docker.types.DeviceRequest( + driver="cdi", + # count=gpu_count, + device_ids=[f"nvidia.com/gpu={i}"], + ) + for i in range(gpu_count) + ] else: - devices = [] + devices = None device_requests = [ docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) ] @@ -540,21 +556,26 @@ def launcher(event_loop): shm_size="1G", ) - yield ContainerLauncherHandle(client, container.name, port) - - if not use_flash_attention: - del env["USE_FLASH_ATTENTION"] - try: - container.stop() - container.wait() - except NotFound: - pass + yield ContainerLauncherHandle(client, container.name, port) - container_output = container.logs().decode("utf-8") - print(container_output, file=sys.stderr) + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] - container.remove() + try: + container.stop() + container.wait() + except NotFound: + pass + + container_output = container.logs().decode("utf-8") + print(container_output, file=sys.stderr) + + finally: + try: + container.remove() + except Exception: + pass if DOCKER_IMAGE is not None: return docker_launcher diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_simple.json similarity index 100% rename from integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json rename to integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_simple.json diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_simple.json similarity index 100% rename from integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json rename to integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_simple.json diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json 
b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json new file mode 100644 index 00000000..9ca22e10 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json @@ -0,0 +1,104 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.296875, + "text": "What" + }, + { + "id": 349, + "logprob": -0.97216797, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.9658203, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44384766, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.50878906, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.8876953, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15124512, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.16687012, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.8046875, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007205963, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.090026855, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030670166, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json new file mode 100644 index 00000000..38ab7263 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 349, + "logprob": -13.921875, + "text": "is" + }, + { + "id": 3534, + "logprob": -11.2265625, + "text": "deep" + }, + { + "id": 5168, + "logprob": -2.3886719, + "text": "learning" + }, + { + "id": 28804, + "logprob": -4.7109375, + "text": "?" 
+ } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.5229492, + "special": false, + "text": "Deep" + }, + { + "id": 17504, + "logprob": 0.0, + "special": false, + "text": " Learning" + }, + { + "id": 349, + "logprob": -0.5151367, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": 0.0, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 13253, + "logprob": -1.3359375, + "special": false, + "text": " Machine" + }, + { + "id": 17504, + "logprob": 0.0, + "special": false, + "text": " Learning" + }, + { + "id": 28725, + "logprob": 0.0, + "special": false, + "text": "," + } + ], + "top_tokens": null + }, + "generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning," +} diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json new file mode 100644 index 00000000..329d73ee --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json @@ -0,0 +1,418 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.296875, + "text": "What" + }, + { + "id": 349, + "logprob": -0.97216797, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.9658203, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44384766, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.50878906, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.8876953, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15136719, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030273438, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1665039, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.1776123, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.8076172, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.090148926, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030670166, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.34375, + "text": "What" + }, + { + "id": 349, + "logprob": -0.96728516, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.97265625, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44189453, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.51220703, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.87402344, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15039062, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1652832, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.81103516, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.08880615, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030612946, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.34375, + "text": "What" + }, + { + "id": 349, + "logprob": -0.96728516, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.97265625, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44189453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.51220703, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.87402344, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15039062, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1652832, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.81103516, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.08880615, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030612946, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.34375, + "text": "What" + }, + { + "id": 349, + "logprob": -0.96728516, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.97265625, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44189453, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.51220703, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.87402344, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15039062, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1652832, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.81103516, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.08880615, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030612946, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + } +] diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json index 0cd3c67f..70b20362 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json @@ -1,38 +1,26 @@ { "choices": [ { - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 0, "logprobs": null, "message": { - "content": null, + "content": "I am an AI assistant", "name": null, "role": "assistant", - "tool_calls": [ - { - "function": { - "arguments": { - "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." 
- }, - "description": null, - "name": "notify_error" - }, - "id": 0, - "type": "function" - } - ] + "tool_calls": null }, "usage": null } ], - "created": 1712852597, + "created": 1728497062, "id": "", - "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "1.4.5-native", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion", + "system_fingerprint": "2.3.2-dev0-native", "usage": { - "completion_tokens": 39, - "prompt_tokens": 496, - "total_tokens": 535 + "completion_tokens": 23, + "prompt_tokens": 604, + "total_tokens": 627 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json new file mode 100644 index 00000000..fa208c54 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": " assistant", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1728497531, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json new file mode 100644 index 00000000..72232e17 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": " fans", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1728497461, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py index 7bee8dea..4bd7bd14 100644 --- a/integration-tests/models/test_flash_gemma.py +++ b/integration-tests/models/test_flash_gemma.py @@ -16,7 +16,7 @@ async def flash_gemma(flash_gemma_handle): @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_flash_gemma(flash_gemma, response_snapshot): +async def test_flash_gemma_simple(flash_gemma, response_snapshot): response = await flash_gemma.generate( "Test request", max_new_tokens=10, decoder_input_details=True ) diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py index c69314ff..bf49dc0b 100644 --- a/integration-tests/models/test_flash_llama.py +++ b/integration-tests/models/test_flash_llama.py @@ -15,7 +15,7 @@ async def flash_llama(flash_llama_handle): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llama(flash_llama, response_snapshot): +async def test_flash_llama_simple(flash_llama, response_snapshot): response = await flash_llama.generate( "Test request", max_new_tokens=10, decoder_input_details=True ) diff --git a/integration-tests/models/test_flash_mixtral_awq.py b/integration-tests/models/test_flash_mixtral_awq.py new file mode 100644 index 
00000000..ab1e0f00 --- /dev/null +++ b/integration-tests/models/test_flash_mixtral_awq.py @@ -0,0 +1,73 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_mixtral_awq_handle(launcher): + with launcher("casperhansen/mixtral-instruct-awq", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_mixtral_awq(flash_mixtral_awq_handle): + await flash_mixtral_awq_handle.health(300) + return flash_mixtral_awq_handle.client + + +@pytest.mark.asyncio +async def test_flash_mixtral_awq(flash_mixtral_awq, response_snapshot): + response = await flash_mixtral_awq.generate( + "What is deep learning?", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text == "\n\nDeep learning is a subset of machine learning" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mixtral_awq_all_params(flash_mixtral_awq, response_snapshot): + response = await flash_mixtral_awq.generate( + "What is deep learning?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "What is deep learning?\nDeep Learning is a subset of Machine Learning," + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mixtral_awq_load( + flash_mixtral_awq, generate_load, response_snapshot +): + responses = await generate_load( + flash_mixtral_awq, "What is deep learning?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert responses[0].details.generated_tokens == 10 + assert ( + responses[0].generated_text + == "\n\nDeep learning is a subset of machine learning" + ) + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + + assert responses == response_snapshot diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index c337afa1..98e75bb4 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream( ) count = 0 + tool_calls_generated = "" + last_response = None async for response in responses: count += 1 + tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments + last_response = response + assert response.choices[0].delta.content is None + assert ( + tool_calls_generated + == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>' + ) assert count == 28 - assert response == response_snapshot + assert last_response == response_snapshot @pytest.mark.asyncio @@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information( messages=[ { "role": "system", - "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + "content": "You're a helpful assistant! 
Answer the users question best you can.", + }, + { + "role": "user", + "content": "Who are you?", + }, + ], + stream=False, + ) + + assert responses.choices[0].message.tool_calls is None + assert responses.choices[0].message.content == "I am an AI assistant" + + assert responses == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_insufficient_information_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="auto", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "Who are you?", + }, + ], + stream=True, + ) + + count = 0 + content_generated = "" + last_response = None + async for response in responses: + count += 1 + content_generated += response.choices[0].delta.content + last_response = response + assert response.choices[0].delta.tool_calls is None + + assert count == 5 + assert content_generated == "I am an AI assistant" + assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="auto", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", }, { "role": "user", "content": "Tell me a story about 3 sea creatures", }, ], - stream=False, + stream=True, ) - assert responses.choices[0].message.content is None + count = 0 + content_generated = "" + last_response = None + async for response in responses: + count += 1 + content_generated += response.choices[0].delta.content + last_response = response + assert response.choices[0].delta.tool_calls is None + + assert count == 62 assert ( - responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error" + content_generated + == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. 
They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans" ) - assert responses == response_snapshot + assert last_response == response_snapshot diff --git a/integration-tests/pyproject.toml b/integration-tests/pyproject.toml index afd57ea7..e49b98fc 100644 --- a/integration-tests/pyproject.toml +++ b/integration-tests/pyproject.toml @@ -13,3 +13,6 @@ pytest = "^7.4.0" pytest-asyncio = "^0.21.1" docker = "^7" numpy = "^1.20" + +[tool.isort] +profile = "black" diff --git a/nix/impure-shell.nix b/nix/impure-shell.nix index a4dad4ba..abed544a 100644 --- a/nix/impure-shell.nix +++ b/nix/impure-shell.nix @@ -1,5 +1,7 @@ { mkShell, + black, + isort, openssl, pkg-config, protobuf, @@ -14,6 +16,8 @@ mkShell { buildInputs = [ + black + isort openssl.dev pkg-config (rust-bin.stable.latest.default.override { diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 1c9d5620..896f4f43 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -355,6 +355,8 @@ pub enum InferError { MissingTemplateVariable(String), #[error("Tool error: {0}")] ToolError(String), + #[error("Stream event serialization error")] + StreamSerializationError(String), } impl InferError { @@ -368,6 +370,7 @@ impl InferError { InferError::TemplateError(_) => "template_error", InferError::MissingTemplateVariable(_) => "missing_template_variable", InferError::ToolError(_) => "tool_error", + InferError::StreamSerializationError(_) => "stream_serialization_error", } } } diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 4fe15720..f86205fb 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -31,32 +31,29 @@ impl ToolGrammar { let mut tools = tools.clone(); - // add the notify_error function to the tools - let notify_error = Tool { + // add the no_tool function to the tools + let no_tool = Tool { r#type: "function".to_string(), function: FunctionDefinition { - name: "notify_error".to_string(), - description: Some("Notify an error or issue".to_string()), + name: "no_tool".to_string(), + description: Some("Open ened response with no specific tool selected".to_string()), arguments: json!({ "type": "object", "properties": { - "error": { + "content": { "type": "string", - "description": "The error or issue to notify" + "description": "The response content", } }, - "required": ["error"] + "required": ["content"] }), }, }; - tools.push(notify_error); + tools.push(no_tool); // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![Self::find_tool_by_name(&tools, &name)?] - } - ToolType::Function { function } => { + ToolType::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] } ToolType::OneOf => tools.clone(), diff --git a/router/src/lib.rs b/router/src/lib.rs index 0901bafa..b29c9395 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -957,12 +957,18 @@ pub fn default_tool_prompt() -> String { } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] -#[serde(untagged)] +#[schema(example = "auto")] +/// Controls which (if any) tool is called by the model. pub enum ToolType { + /// Means the model can pick between generating a message or calling one or more tools. + #[schema(rename = "auto")] OneOf, - FunctionName(String), - Function { function: FunctionName }, + /// Means the model will not call any tool and instead generates a message. 
+ #[schema(rename = "none")] NoTool, + /// Forces the model to call a specific tool. + #[schema(rename = "function")] + Function(FunctionName), } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] @@ -977,6 +983,7 @@ pub struct ToolChoice(pub Option); #[derive(Deserialize)] #[serde(untagged)] enum ToolTypeDeserializer { + Null, String(String), ToolType(ToolType), } @@ -984,10 +991,11 @@ enum ToolTypeDeserializer { impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { + ToolTypeDeserializer::Null => ToolChoice(None), ToolTypeDeserializer::String(s) => match s.as_str() { "none" => ToolChoice(Some(ToolType::NoTool)), "auto" => ToolChoice(Some(ToolType::OneOf)), - _ => ToolChoice(Some(ToolType::FunctionName(s))), + _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), }, ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } diff --git a/router/src/server.rs b/router/src/server.rs index 73b54321..5e6e6960 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType}; use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use pyo3::types::IntoPyDict; +use regex::Regex; use serde_json::Value; use std::convert::Infallible; use std::fs::File; @@ -452,12 +453,20 @@ async fn generate_stream( Sse>>, ) { let span = tracing::Span::current(); - let on_message_callback = |stream_token: StreamResponse| { - let event = Event::default(); - event.json_data(stream_token).unwrap() - }; let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await; + generate_stream_internal(infer, compute_type, Json(req), span).await; + + let response_stream = async_stream::stream! 
{ + let mut response_stream = Box::pin(response_stream); + while let Some(raw_event) = response_stream.next().await { + yield Ok(raw_event.map_or_else(Event::from, |token| { + Event::default() + .json_data(token) + .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()) + })); + } + }; + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); (headers, sse) } @@ -466,9 +475,11 @@ async fn generate_stream_internal( infer: Infer, ComputeType(compute_type): ComputeType, Json(req): Json, - on_message_callback: impl Fn(StreamResponse) -> Event, span: tracing::Span, -) -> (HeaderMap, impl Stream>) { +) -> ( + HeaderMap, + impl Stream>, +) { let start_time = Instant::now(); metrics::counter!("tgi_request_count").increment(1); @@ -500,12 +511,12 @@ async fn generate_stream_internal( let err = InferError::from(ValidationError::BestOfStream); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); - yield Ok(Event::from(err)); + yield Err(err); } else if req.parameters.decoder_input_details { let err = InferError::from(ValidationError::PrefillDetailsStream); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); - yield Ok(Event::from(err)); + yield Err(err); } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives @@ -535,8 +546,7 @@ async fn generate_stream_internal( generated_text: None, details: None, }; - let event = on_message_callback(stream_token); - yield Ok(event); + yield Ok(stream_token); } // Yield event for last token and compute timings InferStreamResponse::End { @@ -600,9 +610,7 @@ async fn generate_stream_internal( details }; - - let event = on_message_callback(stream_token); - yield Ok(event); + yield Ok(stream_token); break; } } @@ -610,7 +618,7 @@ async fn generate_stream_internal( // yield error Err(err) => { error = true; - yield Ok(Event::from(err)); + yield Err(err); break; } } @@ -619,7 +627,7 @@ async fn generate_stream_internal( // yield error Err(err) => { error = true; - yield Ok(Event::from(err)); + yield Err(err); } } // Check if generation reached the end @@ -628,7 +636,7 @@ async fn generate_stream_internal( let err = InferError::IncompleteGenerationStream; metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); - yield Ok(Event::from(err)); + yield Err(err); } } }; @@ -771,75 +779,85 @@ async fn completions( // Create a future for each generate_stream_internal call. 
let generate_future = async move { - let on_message_callback = move |stream_token: StreamResponse| { - let event = Event::default(); - - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - - let message = match stream_token.details { - Some(details) => { - let completion_tokens = details.generated_tokens; - let prompt_tokens = details.input_length; - let total_tokens = prompt_tokens + completion_tokens; - - Completion::Final(CompletionFinal { - id: String::new(), - created: current_time, - model: model_id.clone(), - system_fingerprint: system_fingerprint.clone(), - choices: vec![CompletionComplete { - finish_reason: details.finish_reason.to_string(), - index: index as u32, - logprobs: None, - text: stream_token.token.text, - }], - usage: Usage { - prompt_tokens, - completion_tokens, - total_tokens, - }, - }) - } - None => Completion::Chunk(Chunk { - id: String::new(), - created: current_time, - choices: vec![CompletionComplete { - finish_reason: String::new(), - index: index as u32, - logprobs: None, - text: stream_token.token.text, - }], - model: model_id.clone(), - system_fingerprint: system_fingerprint.clone(), - }), - }; - - event - .json_data(message) - .unwrap_or_else(|_e| Event::default()) - }; - let (header_tx, header_rx) = oneshot::channel(); let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel(); tokio::spawn(async move { - let (header_map, sse) = generate_stream_internal( + let (headers, response_stream) = generate_stream_internal( infer_clone.clone(), compute_type_clone.clone(), Json(generate_request), - on_message_callback, span_clone.clone(), ) .await; + let response_stream = async_stream::stream! { + let mut response_stream = Box::pin(response_stream); + + while let Some(stream_token) = response_stream.next().await { + match stream_token { + Ok(stream_token) => { + let event = Event::default(); + + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + let message = match stream_token.details { + Some(details) => { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + + Completion::Final(CompletionFinal { + id: String::new(), + created: current_time, + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + choices: vec![CompletionComplete { + finish_reason: details.finish_reason.to_string(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens, + }, + }) + } + None => Completion::Chunk(Chunk { + id: String::new(), + created: current_time, + choices: vec![CompletionComplete { + finish_reason: String::new(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + }), + }; + + let event = event + .json_data(message) + .unwrap_or_else(|_e| Event::default()); + + yield Ok(event); + } + Err(err) => yield Ok(Event::from(err)), + } + } + }; + // send and dont wait for response - let _ = header_tx.send(header_map); + let _ = header_tx.send(headers); // pin an emit messages to the sse_tx - let mut sse = Box::pin(sse); + let mut sse = Box::pin(response_stream); while let Some(event) = sse.next().await { if sse_tx.send(event).is_err() { 
tracing::error!("Failed to send event. Receiver dropped."); @@ -1072,6 +1090,84 @@ async fn completions( } } +enum StreamState { + Buffering, + BufferTrailing, + Content { skip_close_quote: bool }, +} + +/// Convert a StreamResponse into an Event to be sent over SSE +fn create_event_from_stream_token( + stream_token: &StreamResponse, + logprobs: bool, + stream_options: Option, + inner_using_tools: bool, + system_fingerprint: String, + model_id: String, +) -> Event { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + let logprobs = logprobs.then(|| { + ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone())) + }); + + // replace the content with the tool calls if grammar is present + let (content, tool_calls) = if inner_using_tools { + (None, Some(vec![stream_token.token.text.clone()])) + } else { + let content = if !stream_token.token.special { + Some(stream_token.token.text.clone()) + } else { + None + }; + + (content, None) + }; + + let (usage, finish_reason) = match &stream_token.details { + Some(details) => { + let usage = if stream_options + .as_ref() + .map(|s| s.include_usage) + .unwrap_or(false) + { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Some(Usage { + completion_tokens, + prompt_tokens, + total_tokens, + }) + } else { + None + }; + (usage, Some(details.finish_reason.format(true))) + } + None => (None, None), + }; + + let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + finish_reason, + usage, + )); + + event.json_data(chat_complete).unwrap_or_else(|e| { + println!("Failed to serialize ChatCompletionChunk: {:?}", e); + Event::default() + }) +} + /// Generate tokens #[utoipa::path( post, @@ -1128,88 +1224,135 @@ async fn chat_completions( // static values that will be returned in all cases let model_id = info.model_id.clone(); let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); - // switch on stream if stream { - // pass this callback to the stream generation and build the required event structure - let on_message_callback = move |stream_token: StreamResponse| { - let event = Event::default(); + let (headers, response_stream) = + generate_stream_internal(infer, compute_type, Json(generate_request), span).await; - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - - let logprobs = logprobs.then(|| { - ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) - }); - - // replace the content with the tool calls if grammar is present - let (content, tool_calls) = if using_tools { - (None, Some(vec![stream_token.token.text])) - } else { - let content = if !stream_token.token.special { - Some(stream_token.token.text) - } else { - None - }; - - (content, None) - }; - - let (usage, finish_reason) = match stream_token.details { - Some(details) => { - let usage = if stream_options - .as_ref() - .map(|s| s.include_usage) - .unwrap_or(false) - { - let completion_tokens = details.generated_tokens; - let prompt_tokens = details.input_length; - let total_tokens = prompt_tokens + 
completion_tokens; - Some(Usage { - completion_tokens, - prompt_tokens, - total_tokens, - }) - } else { - None - }; - (usage, Some(details.finish_reason.format(true))) - } - None => (None, None), - }; - event - .json_data(CompletionType::ChatCompletionChunk( - ChatCompletionChunk::new( - model_id.clone(), - system_fingerprint.clone(), - content, - tool_calls, - current_time, - logprobs, - finish_reason, - usage, - ), + // regex to match any function name + let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) { + Ok(regex) => regex, + Err(e) => { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("Failed to compile regex: {}", e), + error_type: "regex".to_string(), + }), )) - .unwrap_or_else(|e| { - println!("Failed to serialize ChatCompletionChunk: {:?}", e); - Event::default() - }) + } }; - let (headers, response_stream) = generate_stream_internal( - infer, - compute_type, - Json(generate_request), - on_message_callback, - span, - ) - .await; + let response_stream = async_stream::stream! { + let mut response_stream = Box::pin(response_stream); + let mut buffer = Vec::new(); + let mut json_buffer = String::new(); + let mut state = if using_tools { + StreamState::Buffering + } else { + StreamState::Content { + skip_close_quote: false, + } + }; + let mut response_as_tool = using_tools; + while let Some(result) = response_stream.next().await { + if let Ok(stream_token) = result { + let token_text = &stream_token.token.text.clone(); + match state { + StreamState::Buffering => { + json_buffer.push_str(&token_text.replace(" ", "")); + buffer.push(stream_token); + if let Some(captures) = function_regex.captures(&json_buffer) { + let function_name = captures[1].to_string(); + if function_name == "no_tool" { + state = StreamState::BufferTrailing; + response_as_tool = false; + buffer.clear(); + json_buffer.clear(); + } else { + state = StreamState::Content { + skip_close_quote: false, + }; + // send all the buffered messages + for stream_token in &buffer { + let event = create_event_from_stream_token( + stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(event); + } + } + } + } + // if we skipped sending the buffer we need to avoid sending the following json key and quotes + StreamState::BufferTrailing => { + let infix_text = "\"content\":\""; + json_buffer.push_str(&token_text.replace(" ", "")); + // keep capturing until we find the infix text + match json_buffer.find(infix_text) { + Some(content_key_index) => { + json_buffer = + json_buffer[content_key_index + infix_text.len()..].to_string(); + } + None => { + continue; + } + } + // if there is leftover text after removing the infix text, we need to send it + if !json_buffer.is_empty() { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + let chat_complete = + CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + Some(json_buffer.clone()), + None, + current_time, + None, + None, + None, + )); + yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + })); + } + // cleanup the buffers + buffer.clear(); + json_buffer.clear(); + state = StreamState::Content { + skip_close_quote: true, + }; + } + StreamState::Content { skip_close_quote 
} => {
+                        if skip_close_quote && token_text.contains('"') {
+                            break;
+                        }
-    let response_stream = response_stream.chain(futures::stream::once(async {
-        Ok(Event::default().data("[DONE]"))
-    }));
+                        // send the content
+                        let event = create_event_from_stream_token(
+                            &stream_token,
+                            logprobs,
+                            stream_options.clone(),
+                            response_as_tool,
+                            system_fingerprint.clone(),
+                            model_id.clone(),
+                        );
+
+                        yield Ok::(event);
+                    }
+                }
+            }
+        }
+        yield Ok::(Event::default().data("[DONE]"));
+    };

     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     Ok((headers, sse).into_response())
@@ -1246,17 +1389,33 @@ async fn chat_completions(
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }
-
-            let tool_calls = vec![ToolCall {
-                id: "0".to_string(),
-                r#type: "function".to_string(),
-                function: FunctionDefinition {
-                    description: None,
-                    name,
-                    arguments,
-                },
-            }];
-            (Some(tool_calls), None)
+            match name.as_str() {
+                "no_tool" => {
+                    // parse the content message
+                    let content_message = arguments
+                        .get("content")
+                        .and_then(Value::as_str)
+                        .ok_or_else(|| {
+                            InferError::ToolError(
+                                "No `content` found in generated text".to_string(),
+                            )
+                        })?
+                        .to_string();
+                    (None, Some(content_message))
+                }
+                _ => {
+                    let tool_calls = vec![ToolCall {
+                        id: "0".to_string(),
+                        r#type: "function".to_string(),
+                        function: FunctionDefinition {
+                            description: None,
+                            name,
+                            arguments,
+                        },
+                    }];
+                    (Some(tool_calls), None)
+                }
+            }
         } else {
             (None, Some(generation.generated_text))
         };
@@ -2323,6 +2482,7 @@ impl From for (StatusCode, Json) {
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
         };

         (
@@ -2500,8 +2660,8 @@ mod tests {
         );
         assert!(result.is_ok());
-        let (inputs, _grammar, using_tools) = result.unwrap();
+        let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
         assert_eq!(using_tools, true);
-        assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+        assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
     }
 }
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index f392b161..12d58532 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,5 +1,5 @@
 [toolchain]
 # Released on: June 13, 2024
 # https://releases.rs/docs/1.79.0/
-channel = "1.80.0"
+channel = "1.80.1"
 components = ["rustfmt", "clippy"]
diff --git a/server/poetry.lock b/server/poetry.lock
index 64e45765..08f74999 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1269,12 +1269,12 @@ files = [

 [[package]]
 name = "moe-kernels"
-version = "0.4.0"
+version = "0.6.0"
 description = "MoE kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"},
+    {file = "moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:f28fd2a56c3ac7bfe74bc44cc7c8c0791a2644ad689b084ea4ed6decb7f41c25"},
 ]

 [package.dependencies]
@@ -1284,16 +1284,16 @@ triton = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"

 [[package]]
 name = "moe-kernels"
-version = "0.4.0"
+version = "0.6.0"
 description = "MoE kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"},
+    {file = "moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:db475948fd9f7a8647aa3f73256ff4d3bb111425305bcd0b0d3559ccc75b8937"},
 ]

 [package.dependencies]
@@ -1303,16 +1303,16 @@ triton = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"

 [[package]]
 name = "moe-kernels"
-version = "0.4.0"
+version = "0.6.0"
 description = "MoE kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"},
+    {file = "moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:364be07c06aafbab1f51d9e26d9a4ff658defe1462a4c645abaf7b895ed163a8"},
 ]

 [package.dependencies]
@@ -1322,16 +1322,16 @@ triton = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"

 [[package]]
 name = "moe-kernels"
-version = "0.4.0"
+version = "0.6.0"
 description = "MoE kernels"
 optional = true
 python-versions = ">=3.7"
 files = [
-    {file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"},
+    {file = "moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:81e7fa25fb5ed5336f5151994f5e3f600df7e166fe013576968c59415e442894"},
 ]

 [package.dependencies]
@@ -1341,7 +1341,7 @@ triton = "*"

 [package.source]
 type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"

 [[package]]
 name = "mpmath"
@@ -3402,11 +3402,6 @@ files = [
     {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
     {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
     {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
-    {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
-    {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
-    {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
-    {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
-    {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
 ]

 [package.dependencies]
diff --git a/server/pyproject.toml b/server/pyproject.toml
index ef67deb1..6ea4718d 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -47,10 +47,10 @@ marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 moe-kernels = [
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 rich = "^13.7.1"
@@ -82,3 +82,6 @@ requires = [
     "poetry-core>=1.0.0",
 ]
 build-backend = "poetry.core.masonry.api"
+
+[tool.isort]
+profile = "black"
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index ced4b5b4..3960c954 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -24,10 +24,8 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""

-        if (
-            dtype == torch.float8_e5m2
-            and (ATTENTION != "flashinfer"
-            or SYSTEM != "cuda")
+        if dtype == torch.float8_e5m2 and (
+            ATTENTION != "flashinfer" or SYSTEM != "cuda"
         ):
             raise ValueError(
                 "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py
index 0a785d94..7245431f 100644
--- a/server/text_generation_server/layers/marlin/gptq.py
+++ b/server/text_generation_server/layers/marlin/gptq.py
@@ -43,7 +43,7 @@ def can_use_gptq_marlin(
         and quant_method in {"awq", "gptq"}
         and bits in GPTQ_MARLIN_BITS
         and groupsize in GPTQ_MARLIN_GROUP_SIZES
-        # We only suppord asymmetric quantization for AWQ.
+        # We only support asymmetric quantization for AWQ.
        and (sym or quant_method == "awq")
     )
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index 2c46ca02..558d9ed9 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -210,11 +210,17 @@ class SparseMoELayer(nn.Module):
             and isinstance(weights.loader.weight_class, UnquantizedWeight)
         ) or isinstance(weights.loader, HybridFP8UnquantLoader):
             cls = UnquantizedSparseMoELayer
-        elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
+        elif isinstance(
+            weights.loader, GPTQMarlinWeightsLoader
+        ) and can_use_marlin_moe_gemm(
+            quant_method=weights.loader.quant_method,
+            quantize=weights.loader.quantize,
+            sym=weights.loader.sym,
+        ):
             cls = GPTQMarlinSparseMoELayer
         else:
             raise ValueError(
-                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
+                f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
             )

         log_once(
diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py
index 3217cdc2..3d4ca9d8 100644
--- a/server/text_generation_server/layers/moe/gptq_marlin.py
+++ b/server/text_generation_server/layers/moe/gptq_marlin.py
@@ -34,9 +34,10 @@ def can_use_marlin_moe_gemm(
         SYSTEM == "cuda"
         and fused_marlin_moe is not None
         and has_sm_8_0
-        and quantize == "gptq"
-        and quant_method == "gptq"
-        and sym
+        and quantize in {"awq", "gptq"}
+        and quant_method in {"awq", "gptq"}
+        # We only support asymmetric quantization for AWQ.
+        and (sym or quant_method == "awq")
     )


@@ -72,10 +73,15 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         super().__init__()

         if not (
-            isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
+            isinstance(weights.loader, GPTQMarlinWeightsLoader)
+            and can_use_marlin_moe_gemm(
+                quant_method=weights.loader.quant_method,
+                quantize=weights.loader.quantize,
+                sym=weights.loader.sym,
+            )
         ):
             raise ValueError(
-                f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
+                f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported"
             )

         assert (n_expert_group is None) == (
@@ -102,17 +108,24 @@ class GPTQMarlinSparseMoELayer(nn.Module):

     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
         return fused_marlin_moe(
-            x,
+            hidden_states=x,
             w1=self.gate_up_proj.qweight,
             w2=self.down_proj.qweight,
-            g_idx1=self.gate_up_proj.g_idx,
-            g_idx2=self.down_proj.g_idx,
-            perm1=self.gate_up_proj.perm,
-            perm2=self.down_proj.perm,
             w1_scale=self.gate_up_proj.scales,
             w2_scale=self.down_proj.scales,
-            is_full_k1=self.gate_up_proj.is_full_k,
-            is_full_k2=self.down_proj.is_full_k,
+            w1_zeros=(
+                self.gate_up_proj.qzeros
+                if self.gate_up_proj.qzeros.numel() > 0
+                else None
+            ),
+            w2_zeros=(
+                self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None
+            ),
+            g_idx1=self.gate_up_proj.g_idx,
+            g_idx2=self.down_proj.g_idx,
+            sort_indices1=self.gate_up_proj.perm,
+            sort_indices2=self.down_proj.perm,
+            is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k,
             gating_output=gating_output,
             topk=self.topk,
             renormalize=self.renormalize,