Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Merge branch 'main' into sliding_window
Commit: f213012b08
.github/workflows/build.yaml (vendored, 32 changed lines)
@@ -21,9 +21,11 @@ jobs:
  build-and-push:
    outputs:
      docker_image: ${{ steps.final.outputs.docker_image }}
      docker_volume: ${{ steps.final.outputs.docker_volume }}
      docker_devices: ${{ steps.final.outputs.docker_devices }}
      runs_on: ${{ steps.final.outputs.runs_on }}
      label: ${{ steps.final.outputs.label }}
      extra_pytest: ${{ steps.final.outputs.extra_pytest }}
    concurrency:
      group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
@@ -44,32 +46,39 @@ jobs:
          cuda)
            export dockerfile="Dockerfile"
            export label_extension=""
            export docker_volume="/mnt/cache"
            export docker_devices=""
            export runs_on="aws-g6-12xl-plus-priv-cache"
            export platform=""
            export extra_pytest=""
            ;;
          rocm)
            export dockerfile="Dockerfile_amd"
            export label_extension="-rocm"
            export docker_devices="/dev/kfd,/dev/dri"
            # TODO Re-enable when they pass.
            # export runs_on="amd-gpu-tgi"
            export runs_on="ubuntu-latest"
            export docker_volume="/mnt"
            export runs_on="amd-gpu-runners"
            export platform=""
            export extra_pytest="-k test_flash_gemma_gptq_load"
            ;;
          intel-xpu)
            export dockerfile="Dockerfile_intel"
            export label_extension="-intel-xpu"
            export docker_devices=""
            export docker_volume="/mnt/cache"
            export runs_on="ubuntu-latest"
            export platform="xpu"
            export extra_pytest=""
            ;;
          intel-cpu)
            export dockerfile="Dockerfile_intel"
            export label_extension="-intel-cpu"
            export docker_devices=""
            export runs_on="ubuntu-latest"
            export docker_devices="none"
            export docker_volume="/mnt/cache"
            # export runs_on="ubuntu-latest"
            export runs_on="aws-highmemory-32-plus-priv"
            export platform="cpu"
            export extra_pytest="-k test_flash_gemma_simple"
            ;;
        esac
        echo $dockerfile
@@ -81,8 +90,10 @@ jobs:
        echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
        echo "LABEL=${label_extension}" >> $GITHUB_ENV
        echo "PLATFORM=${platform}" >> $GITHUB_ENV
        echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
        echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
        echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
        echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
        echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
      - name: Initialize Docker Buildx
        uses: docker/setup-buildx-action@v3
@@ -157,16 +168,18 @@ jobs:
        run: |
          echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
          echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
          echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
          echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
          echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
          echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
  integration_tests:
    concurrency:
      group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    needs: build-and-push
    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
    runs-on:
      group: ${{ needs.build-and-push.outputs.runs_on }}
    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
    env:
      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
    steps:
@@ -177,15 +190,16 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"
          python-version: "3.11"
      - name: Install
        run: |
          make install-integration-tests
      - name: Run tests
        run: |
          export DOCKER_VOLUME=/mnt/cache
          export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}
          export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
          export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
          export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}"
          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          echo $DOCKER_IMAGE
          pytest -s -vv integration-tests ${PYTEST_FLAGS}
          pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
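Read as a whole, the per-hardware `case` block above selects one build/test configuration per target. The Python mapping below is only an at-a-glance restatement; where the hunk lists both an old and a new assignment for the same variable, the last value shown is taken here, which is an assumption since the diff markers were lost in this view.

```python
# Illustrative restatement of the workflow's per-hardware configuration (not part of the CI).
HARDWARE_CONFIG = {
    "cuda": {
        "dockerfile": "Dockerfile", "label_extension": "", "docker_devices": "",
        "docker_volume": "/mnt/cache", "runs_on": "aws-g6-12xl-plus-priv-cache",
        "platform": "", "extra_pytest": "",
    },
    "rocm": {
        "dockerfile": "Dockerfile_amd", "label_extension": "-rocm", "docker_devices": "/dev/kfd,/dev/dri",
        "docker_volume": "/mnt", "runs_on": "amd-gpu-runners",
        "platform": "", "extra_pytest": "-k test_flash_gemma_gptq_load",
    },
    "intel-xpu": {
        "dockerfile": "Dockerfile_intel", "label_extension": "-intel-xpu", "docker_devices": "",
        "docker_volume": "/mnt/cache", "runs_on": "ubuntu-latest",
        "platform": "xpu", "extra_pytest": "",
    },
    "intel-cpu": {
        "dockerfile": "Dockerfile_intel", "label_extension": "-intel-cpu", "docker_devices": "none",
        "docker_volume": "/mnt/cache", "runs_on": "aws-highmemory-32-plus-priv",
        "platform": "cpu", "extra_pytest": "-k test_flash_gemma_simple",
    },
}
```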
Dockerfile:

@@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@@ -39,7 +40,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt --frozen

# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
Dockerfile_amd:

@@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -31,6 +31,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@@ -38,7 +39,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt --frozen

# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.2 AS base
Dockerfile_intel:

@@ -1,6 +1,6 @@
ARG PLATFORM=xpu

FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@@ -39,7 +40,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt --frozen


# Text Generation Inference base image for Intel
@@ -52,7 +53,7 @@ ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH

# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
README.md:

@@ -120,7 +120,7 @@ curl localhost:3000/v1/chat/completions \

**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.

**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1-rocm --model-id $model` instead of the command above.

To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@@ -150,7 +150,7 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>

docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
```
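With a container started by either command above, the server can also be queried from Python; a minimal sketch, assuming the `-p 8080:80` mapping used above and the third-party `requests` package (neither of which is part of this diff):

```python
import requests

# Query the /generate endpoint of a locally running TGI container.
response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "What is deep learning?",
        "parameters": {"max_new_tokens": 10},
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["generated_text"])
```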
### A note on Shared Memory (shm)

@@ -27,3 +27,6 @@ asyncio_mode = "auto"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"
docs/openapi.json:

@@ -2114,12 +2114,18 @@
    "ToolType": {
      "oneOf": [
        {
          "type": "object",
          "default": null,
          "nullable": true
          "type": "string",
          "description": "Means the model can pick between generating a message or calling one or more tools.",
          "enum": [
            "auto"
          ]
        },
        {
          "type": "string"
          "type": "string",
          "description": "Means the model will not call any tool and instead generates a message.",
          "enum": [
            "none"
          ]
        },
        {
          "type": "object",
@@ -2131,13 +2137,10 @@
              "$ref": "#/components/schemas/FunctionName"
            }
          }
        },
        {
          "type": "object",
          "default": null,
          "nullable": true
        }
      ]
      ],
      "description": "Controls which (if any) tool is called by the model.",
      "example": "auto"
    },
    "Url": {
      "type": "object",
@@ -19,6 +19,6 @@ docker run --gpus all \
    --shm-size 1g \
    -e HF_TOKEN=$token \
    -p 8080:80 \
    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 \
    --model-id $model
```

@@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes
```

4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes-nf4
```

You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize gptq
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize gptq
```

Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
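For readability, the layer-wise GPTQ objective quoted in the hunk header above is restated in cleaner notation (same formula, nothing new):

$$\hat{W}_{\ell}^{\,*} = \operatorname*{arg\,min}_{\hat{W}_{\ell}} \bigl\lVert W_{\ell}X - \hat{W}_{\ell}X \bigr\rVert_{2}^{2}$$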
@@ -97,7 +97,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

```bash
docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help
docker run ghcr.io/huggingface/text-generation-inference:2.3.1 --help
```

</Tip>
flake.lock:

@@ -978,11 +978,11 @@
      "nixpkgs": "nixpkgs_6"
    },
    "locked": {
      "lastModified": 1728029332,
      "narHash": "sha256-j0RX3a67lvi2PC5w6J5DHTxM+l96J/OV5sAf34IUfUo=",
      "lastModified": 1728381423,
      "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
      "owner": "huggingface",
      "repo": "text-generation-inference-nix",
      "rev": "98049f853346ca780b81fee730715c90d33ac2b4",
      "rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
      "type": "github"
    },
    "original": {
|
||||
try:
|
||||
container = client.containers.get(container_name)
|
||||
container.stop()
|
||||
container.remove()
|
||||
container.wait()
|
||||
except NotFound:
|
||||
pass
|
||||
@ -514,13 +515,28 @@ def launcher(event_loop):
|
||||
volumes = [f"{DOCKER_VOLUME}:/data"]
|
||||
|
||||
if DOCKER_DEVICES:
|
||||
devices = DOCKER_DEVICES.split(",")
|
||||
if DOCKER_DEVICES.lower() == "none":
|
||||
devices = []
|
||||
else:
|
||||
devices = DOCKER_DEVICES.strip().split(",")
|
||||
visible = os.getenv("ROCR_VISIBLE_DEVICES")
|
||||
if visible:
|
||||
env["ROCR_VISIBLE_DEVICES"] = visible
|
||||
device_requests = []
|
||||
if not devices:
|
||||
devices = None
|
||||
elif devices == ["nvidia.com/gpu=all"]:
|
||||
devices = None
|
||||
device_requests = [
|
||||
docker.types.DeviceRequest(
|
||||
driver="cdi",
|
||||
# count=gpu_count,
|
||||
device_ids=[f"nvidia.com/gpu={i}"],
|
||||
)
|
||||
for i in range(gpu_count)
|
||||
]
|
||||
else:
|
||||
devices = []
|
||||
devices = None
|
||||
device_requests = [
|
||||
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
|
||||
]
|
||||
@ -540,21 +556,26 @@ def launcher(event_loop):
|
||||
shm_size="1G",
|
||||
)
|
||||
|
||||
yield ContainerLauncherHandle(client, container.name, port)
|
||||
|
||||
if not use_flash_attention:
|
||||
del env["USE_FLASH_ATTENTION"]
|
||||
|
||||
try:
|
||||
container.stop()
|
||||
container.wait()
|
||||
except NotFound:
|
||||
pass
|
||||
yield ContainerLauncherHandle(client, container.name, port)
|
||||
|
||||
container_output = container.logs().decode("utf-8")
|
||||
print(container_output, file=sys.stderr)
|
||||
if not use_flash_attention:
|
||||
del env["USE_FLASH_ATTENTION"]
|
||||
|
||||
container.remove()
|
||||
try:
|
||||
container.stop()
|
||||
container.wait()
|
||||
except NotFound:
|
||||
pass
|
||||
|
||||
container_output = container.logs().decode("utf-8")
|
||||
print(container_output, file=sys.stderr)
|
||||
|
||||
finally:
|
||||
try:
|
||||
container.remove()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if DOCKER_IMAGE is not None:
|
||||
return docker_launcher
|
||||
|
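Taken out of the fixture, the device-selection logic added above amounts to the following standalone sketch; docker-py's `DeviceRequest` is real, but the helper name and signature are illustrative and not part of the diff:

```python
import docker


def resolve_devices(docker_devices: str | None, gpu_count: int):
    """Map the DOCKER_DEVICES value exported by build.yaml to docker-py run() arguments."""
    device_requests = []
    if docker_devices:
        if docker_devices.lower() == "none":
            devices = []  # intel-cpu: no device passthrough at all
        else:
            devices = docker_devices.strip().split(",")  # e.g. "/dev/kfd,/dev/dri" for ROCm
        if not devices:
            devices = None
        elif devices == ["nvidia.com/gpu=all"]:
            # CDI-style NVIDIA passthrough: one DeviceRequest per visible GPU.
            devices = None
            device_requests = [
                docker.types.DeviceRequest(driver="cdi", device_ids=[f"nvidia.com/gpu={i}"])
                for i in range(gpu_count)
            ]
    else:
        # Default: request gpu_count GPUs through the regular NVIDIA runtime.
        devices = None
        device_requests = [
            docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
        ]
    return devices, device_requests
```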
@@ -0,0 +1,104 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 1824, "logprob": -12.296875, "text": "What" },
      { "id": 349, "logprob": -0.97216797, "text": "is" },
      { "id": 3534, "logprob": -10.1796875, "text": "deep" },
      { "id": 5168, "logprob": -0.9658203, "text": "learning" },
      { "id": 28804, "logprob": -0.44384766, "text": "?" }
    ],
    "seed": null,
    "tokens": [
      { "id": 13, "logprob": -0.50878906, "special": false, "text": "\n" },
      { "id": 13, "logprob": -0.8876953, "special": false, "text": "\n" },
      { "id": 23229, "logprob": -0.15124512, "special": false, "text": "Deep" },
      { "id": 5168, "logprob": -0.030288696, "special": false, "text": " learning" },
      { "id": 349, "logprob": -0.16687012, "special": false, "text": " is" },
      { "id": 264, "logprob": -0.17858887, "special": false, "text": " a" },
      { "id": 19804, "logprob": -0.8046875, "special": false, "text": " subset" },
      { "id": 302, "logprob": -0.007205963, "special": false, "text": " of" },
      { "id": 5599, "logprob": -0.090026855, "special": false, "text": " machine" },
      { "id": 5168, "logprob": -0.0030670166, "special": false, "text": " learning" }
    ],
    "top_tokens": null
  },
  "generated_text": "\n\nDeep learning is a subset of machine learning"
}
@@ -0,0 +1,99 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 349, "logprob": -13.921875, "text": "is" },
      { "id": 3534, "logprob": -11.2265625, "text": "deep" },
      { "id": 5168, "logprob": -2.3886719, "text": "learning" },
      { "id": 28804, "logprob": -4.7109375, "text": "?" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 13, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 23229, "logprob": -0.5229492, "special": false, "text": "Deep" },
      { "id": 17504, "logprob": 0.0, "special": false, "text": " Learning" },
      { "id": 349, "logprob": -0.5151367, "special": false, "text": " is" },
      { "id": 264, "logprob": 0.0, "special": false, "text": " a" },
      { "id": 19804, "logprob": 0.0, "special": false, "text": " subset" },
      { "id": 302, "logprob": 0.0, "special": false, "text": " of" },
      { "id": 13253, "logprob": -1.3359375, "special": false, "text": " Machine" },
      { "id": 17504, "logprob": 0.0, "special": false, "text": " Learning" },
      { "id": 28725, "logprob": 0.0, "special": false, "text": "," }
    ],
    "top_tokens": null
  },
  "generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
}
@ -0,0 +1,418 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1824,
|
||||
"logprob": -12.296875,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.97216797,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 3534,
|
||||
"logprob": -10.1796875,
|
||||
"text": "deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.9658203,
|
||||
"text": "learning"
|
||||
},
|
||||
{
|
||||
"id": 28804,
|
||||
"logprob": -0.44384766,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.50878906,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8876953,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 23229,
|
||||
"logprob": -0.15136719,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.030273438,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.1665039,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.1776123,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 19804,
|
||||
"logprob": -0.8076172,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": -0.007183075,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5599,
|
||||
"logprob": -0.090148926,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.0030670166,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1824,
|
||||
"logprob": -12.34375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.96728516,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 3534,
|
||||
"logprob": -10.1796875,
|
||||
"text": "deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.97265625,
|
||||
"text": "learning"
|
||||
},
|
||||
{
|
||||
"id": 28804,
|
||||
"logprob": -0.44189453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.51220703,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.87402344,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 23229,
|
||||
"logprob": -0.15039062,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.030288696,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.1652832,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.17858887,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 19804,
|
||||
"logprob": -0.81103516,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": -0.007183075,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5599,
|
||||
"logprob": -0.08880615,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.0030612946,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1824,
|
||||
"logprob": -12.34375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.96728516,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 3534,
|
||||
"logprob": -10.1796875,
|
||||
"text": "deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.97265625,
|
||||
"text": "learning"
|
||||
},
|
||||
{
|
||||
"id": 28804,
|
||||
"logprob": -0.44189453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.51220703,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.87402344,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 23229,
|
||||
"logprob": -0.15039062,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.030288696,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.1652832,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.17858887,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 19804,
|
||||
"logprob": -0.81103516,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": -0.007183075,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5599,
|
||||
"logprob": -0.08880615,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.0030612946,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1824,
|
||||
"logprob": -12.34375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.96728516,
|
||||
"text": "is"
|
||||
},
|
||||
{
|
||||
"id": 3534,
|
||||
"logprob": -10.1796875,
|
||||
"text": "deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.97265625,
|
||||
"text": "learning"
|
||||
},
|
||||
{
|
||||
"id": 28804,
|
||||
"logprob": -0.44189453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.51220703,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.87402344,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 23229,
|
||||
"logprob": -0.15039062,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.030288696,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.1652832,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.17858887,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 19804,
|
||||
"logprob": -0.81103516,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": -0.007183075,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5599,
|
||||
"logprob": -0.08880615,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 5168,
|
||||
"logprob": -0.0030612946,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||
}
|
||||
]
|
@@ -1,38 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": null,
        "content": "I am an AI assistant",
        "name": null,
        "role": "assistant",
        "tool_calls": [
          {
            "function": {
              "arguments": {
                "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
              },
              "description": null,
              "name": "notify_error"
            },
            "id": 0,
            "type": "function"
          }
        ]
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1712852597,
  "created": 1728497062,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.5-native",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.3.2-dev0-native",
  "usage": {
    "completion_tokens": 39,
    "prompt_tokens": 496,
    "total_tokens": 535
    "completion_tokens": 23,
    "prompt_tokens": 604,
    "total_tokens": 627
  }
}
@@ -0,0 +1,20 @@
{
  "choices": [
    {
      "delta": {
        "content": " assistant",
        "role": "assistant",
        "tool_calls": null
      },
      "finish_reason": null,
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1728497531,
  "id": "",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "2.3.2-dev0-native",
  "usage": null
}
@@ -0,0 +1,20 @@
{
  "choices": [
    {
      "delta": {
        "content": " fans",
        "role": "assistant",
        "tool_calls": null
      },
      "finish_reason": null,
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1728497461,
  "id": "",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "2.3.2-dev0-native",
  "usage": null
}
@@ -16,7 +16,7 @@ async def flash_gemma(flash_gemma_handle):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
async def test_flash_gemma_simple(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
@@ -15,7 +15,7 @@ async def flash_llama(flash_llama_handle):

@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama(flash_llama, response_snapshot):
async def test_flash_llama_simple(flash_llama, response_snapshot):
    response = await flash_llama.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
integration-tests/models/test_flash_mixtral_awq.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import pytest


@pytest.fixture(scope="module")
def flash_mixtral_awq_handle(launcher):
    with launcher("casperhansen/mixtral-instruct-awq", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_mixtral_awq(flash_mixtral_awq_handle):
    await flash_mixtral_awq_handle.health(300)
    return flash_mixtral_awq_handle.client


@pytest.mark.asyncio
async def test_flash_mixtral_awq(flash_mixtral_awq, response_snapshot):
    response = await flash_mixtral_awq.generate(
        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text == "\n\nDeep learning is a subset of machine learning"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_awq_all_params(flash_mixtral_awq, response_snapshot):
    response = await flash_mixtral_awq.generate(
        "What is deep learning?",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_awq_load(
    flash_mixtral_awq, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_mixtral_awq, "What is deep learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert responses[0].details.generated_tokens == 10
    assert (
        responses[0].generated_text
        == "\n\nDeep learning is a subset of machine learning"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot
@@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
    )

    count = 0
    tool_calls_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
        last_response = response
        assert response.choices[0].delta.content is None

    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
    )
    assert count == 28
    assert response == response_snapshot
    assert last_response == response_snapshot


@pytest.mark.asyncio
@@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
        messages=[
            {
                "role": "system",
                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=False,
    )

    assert responses.choices[0].message.tool_calls is None
    assert responses.choices[0].message.content == "I am an AI assistant"

    assert responses == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=True,
    )

    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        assert response.choices[0].delta.tool_calls is None

    assert count == 5
    assert content_generated == "I am an AI assistant"
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=False,
        stream=True,
    )

    assert responses.choices[0].message.content is None
    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        assert response.choices[0].delta.tool_calls is None

    assert count == 62
    assert (
        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
        content_generated
        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
    )
    assert responses == response_snapshot
    assert last_response == response_snapshot
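All of the streaming tests above follow the same accumulation pattern over the response chunks; as a standalone sketch (the `choices[0].delta` attribute names match the tests, while the helper function itself is illustrative and not part of the diff):

```python
async def collect_stream(responses):
    """Accumulate a streamed chat completion the way the tests above do."""
    content_generated = ""
    tool_calls_generated = ""
    last_response = None
    async for response in responses:
        delta = response.choices[0].delta
        if delta.content is not None:
            content_generated += delta.content
        if delta.tool_calls is not None:
            tool_calls_generated += delta.tool_calls.function.arguments
        last_response = response
    return content_generated, tool_calls_generated, last_response
```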
@@ -13,3 +13,6 @@ pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
docker = "^7"
numpy = "^1.20"

[tool.isort]
profile = "black"
@@ -1,5 +1,7 @@
{
  mkShell,
  black,
  isort,
  openssl,
  pkg-config,
  protobuf,
@@ -14,6 +16,8 @@
mkShell {
  buildInputs =
    [
      black
      isort
      openssl.dev
      pkg-config
      (rust-bin.stable.latest.default.override {
@@ -355,6 +355,8 @@ pub enum InferError {
    MissingTemplateVariable(String),
    #[error("Tool error: {0}")]
    ToolError(String),
    #[error("Stream event serialization error")]
    StreamSerializationError(String),
}

impl InferError {
@@ -368,6 +370,7 @@ impl InferError {
            InferError::TemplateError(_) => "template_error",
            InferError::MissingTemplateVariable(_) => "missing_template_variable",
            InferError::ToolError(_) => "tool_error",
            InferError::StreamSerializationError(_) => "stream_serialization_error",
        }
    }
}
@@ -31,32 +31,29 @@ impl ToolGrammar {

        let mut tools = tools.clone();

        // add the notify_error function to the tools
        let notify_error = Tool {
        // add the no_tool function to the tools
        let no_tool = Tool {
            r#type: "function".to_string(),
            function: FunctionDefinition {
                name: "notify_error".to_string(),
                description: Some("Notify an error or issue".to_string()),
                name: "no_tool".to_string(),
                description: Some("Open ened response with no specific tool selected".to_string()),
                arguments: json!({
                    "type": "object",
                    "properties": {
                        "error": {
                        "content": {
                            "type": "string",
                            "description": "The error or issue to notify"
                            "description": "The response content",
                        }
                    },
                    "required": ["error"]
                    "required": ["content"]
                }),
            },
        };
        tools.push(notify_error);
        tools.push(no_tool);

        // if tools are provided and no tool_choice we default to the OneOf
        let tools_to_use = match tool_choice {
            ToolType::FunctionName(name) => {
                vec![Self::find_tool_by_name(&tools, &name)?]
            }
            ToolType::Function { function } => {
            ToolType::Function(function) => {
                vec![Self::find_tool_by_name(&tools, &function.name)?]
            }
            ToolType::OneOf => tools.clone(),
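To make the injected `no_tool` function concrete: when the grammar-constrained model selects it, the raw generation carries the reply under a `content` argument, and the router surfaces it as a plain chat message instead of a tool call. A hedged sketch of that shape (the example string is illustrative; the `_name` key matches the regex the router uses further down in this diff):

```python
import json

# Illustrative raw generation when the model picks the injected no_tool function.
raw_generation = '{"function": {"_name": "no_tool", "content": "I am an AI assistant"}}'

function = json.loads(raw_generation)["function"]
assert function.pop("_name") == "no_tool"
# The remaining "content" value is what ends up as message.content rather than a tool call.
print(function["content"])
```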
@@ -957,12 +957,18 @@ pub fn default_tool_prompt() -> String {
}

#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
#[serde(untagged)]
#[schema(example = "auto")]
/// Controls which (if any) tool is called by the model.
pub enum ToolType {
    /// Means the model can pick between generating a message or calling one or more tools.
    #[schema(rename = "auto")]
    OneOf,
    FunctionName(String),
    Function { function: FunctionName },
    /// Means the model will not call any tool and instead generates a message.
    #[schema(rename = "none")]
    NoTool,
    /// Forces the model to call a specific tool.
    #[schema(rename = "function")]
    Function(FunctionName),
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
@@ -977,6 +983,7 @@ pub struct ToolChoice(pub Option<ToolType>);
#[derive(Deserialize)]
#[serde(untagged)]
enum ToolTypeDeserializer {
    Null,
    String(String),
    ToolType(ToolType),
}
@@ -984,10 +991,11 @@ enum ToolTypeDeserializer {
impl From<ToolTypeDeserializer> for ToolChoice {
    fn from(value: ToolTypeDeserializer) -> Self {
        match value {
            ToolTypeDeserializer::Null => ToolChoice(None),
            ToolTypeDeserializer::String(s) => match s.as_str() {
                "none" => ToolChoice(Some(ToolType::NoTool)),
                "auto" => ToolChoice(Some(ToolType::OneOf)),
                _ => ToolChoice(Some(ToolType::FunctionName(s))),
                _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
            },
            ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
        }
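From the client side, the deserializer above means `tool_choice` may be omitted, given as a string, or given as a structured object. A sketch of that mapping in Python, mirroring the `match` arms above (the returned dictionaries are purely illustrative, not an API):

```python
def resolve_tool_choice(tool_choice):
    """Mirror of the ToolTypeDeserializer -> ToolChoice mapping above (sketch only)."""
    if tool_choice is None:
        return None                                          # ToolChoice(None): server default
    if tool_choice == "none":
        return {"kind": "no_tool"}                           # ToolType::NoTool
    if tool_choice == "auto":
        return {"kind": "one_of"}                            # ToolType::OneOf
    if isinstance(tool_choice, str):
        return {"kind": "function", "name": tool_choice}     # ToolType::Function(FunctionName)
    return {"kind": "tool_type", "value": tool_choice}       # already-structured ToolType payload
```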
@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
|
||||
use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use pyo3::types::IntoPyDict;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::File;
|
||||
@ -452,12 +453,20 @@ async fn generate_stream(
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
) {
|
||||
let span = tracing::Span::current();
|
||||
let on_message_callback = |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
event.json_data(stream_token).unwrap()
|
||||
};
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
|
||||
generate_stream_internal(infer, compute_type, Json(req), span).await;
|
||||
|
||||
let response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
while let Some(raw_event) = response_stream.next().await {
|
||||
yield Ok(raw_event.map_or_else(Event::from, |token| {
|
||||
Event::default()
|
||||
.json_data(token)
|
||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
(headers, sse)
|
||||
}
|
||||
@ -466,9 +475,11 @@ async fn generate_stream_internal(
|
||||
infer: Infer,
|
||||
ComputeType(compute_type): ComputeType,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||
span: tracing::Span,
|
||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||
) -> (
|
||||
HeaderMap,
|
||||
impl Stream<Item = Result<StreamResponse, InferError>>,
|
||||
) {
|
||||
let start_time = Instant::now();
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
@ -500,12 +511,12 @@ async fn generate_stream_internal(
|
||||
let err = InferError::from(ValidationError::BestOfStream);
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
} else if req.parameters.decoder_input_details {
|
||||
let err = InferError::from(ValidationError::PrefillDetailsStream);
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
} else {
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
// Keep permit as long as generate_stream lives
|
||||
@ -535,8 +546,7 @@ async fn generate_stream_internal(
|
||||
generated_text: None,
|
||||
details: None,
|
||||
};
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
yield Ok(stream_token);
|
||||
}
|
||||
// Yield event for last token and compute timings
|
||||
InferStreamResponse::End {
|
||||
@ -600,9 +610,7 @@ async fn generate_stream_internal(
|
||||
details
|
||||
};
|
||||
|
||||
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
yield Ok(stream_token);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -610,7 +618,7 @@ async fn generate_stream_internal(
|
||||
// yield error
|
||||
Err(err) => {
|
||||
error = true;
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -619,7 +627,7 @@ async fn generate_stream_internal(
|
||||
// yield error
|
||||
Err(err) => {
|
||||
error = true;
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
}
|
||||
}
|
||||
// Check if generation reached the end
|
||||
@ -628,7 +636,7 @@ async fn generate_stream_internal(
|
||||
let err = InferError::IncompleteGenerationStream;
|
||||
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
yield Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -771,75 +779,85 @@ async fn completions(
|
||||
|
||||
// Create a future for each generate_stream_internal call.
|
||||
let generate_future = async move {
|
||||
let on_message_callback = move |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let message = match stream_token.details {
|
||||
Some(details) => {
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
|
||||
Completion::Final(CompletionFinal {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
None => Completion::Chunk(Chunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: String::new(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
event
|
||||
.json_data(message)
|
||||
.unwrap_or_else(|_e| Event::default())
|
||||
};
|
||||
|
||||
let (header_tx, header_rx) = oneshot::channel();
|
||||
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (header_map, sse) = generate_stream_internal(
|
||||
let (headers, response_stream) = generate_stream_internal(
|
||||
infer_clone.clone(),
|
||||
compute_type_clone.clone(),
|
||||
Json(generate_request),
|
||||
on_message_callback,
|
||||
span_clone.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
|
||||
while let Some(stream_token) = response_stream.next().await {
|
||||
match stream_token {
|
||||
Ok(stream_token) => {
|
||||
let event = Event::default();
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let message = match stream_token.details {
|
||||
Some(details) => {
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
|
||||
Completion::Final(CompletionFinal {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
None => Completion::Chunk(Chunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: String::new(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
let event = event
|
||||
.json_data(message)
|
||||
.unwrap_or_else(|_e| Event::default());
|
||||
|
||||
yield Ok(event);
|
||||
}
|
||||
Err(err) => yield Ok(Event::from(err)),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// send and dont wait for response
|
||||
let _ = header_tx.send(header_map);
|
||||
let _ = header_tx.send(headers);
|
||||
|
||||
// pin an emit messages to the sse_tx
|
||||
let mut sse = Box::pin(sse);
|
||||
let mut sse = Box::pin(response_stream);
|
||||
while let Some(event) = sse.next().await {
|
||||
if sse_tx.send(event).is_err() {
|
||||
tracing::error!("Failed to send event. Receiver dropped.");
|
||||
@ -1072,6 +1090,84 @@ async fn completions(
|
||||
}
|
||||
}
|
||||
|
||||
enum StreamState {
|
||||
Buffering,
|
||||
BufferTrailing,
|
||||
Content { skip_close_quote: bool },
|
||||
}
|
||||
|
||||
/// Convert a StreamResponse into an Event to be sent over SSE
|
||||
fn create_event_from_stream_token(
|
||||
stream_token: &StreamResponse,
|
||||
logprobs: bool,
|
||||
stream_options: Option<StreamOptions>,
|
||||
inner_using_tools: bool,
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
) -> Event {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let logprobs = logprobs.then(|| {
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(content, None)
|
||||
};
|
||||
|
||||
let (usage, finish_reason) = match &stream_token.details {
|
||||
Some(details) => {
|
||||
let usage = if stream_options
|
||||
.as_ref()
|
||||
.map(|s| s.include_usage)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
Some(Usage {
|
||||
completion_tokens,
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(usage, Some(details.finish_reason.format(true)))
|
||||
}
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
finish_reason,
|
||||
usage,
|
||||
));
|
||||
|
||||
event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
Event::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
#[utoipa::path(
|
||||
post,
|
||||
@ -1128,88 +1224,135 @@ async fn chat_completions(
// static values that will be returned in all cases
let model_id = info.model_id.clone();
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));

// switch on stream
if stream {
// pass this callback to the stream generation and build the required event structure
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;

let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();

let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
});

// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if using_tools {
(None, Some(vec![stream_token.token.text]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text)
} else {
None
};

(content, None)
};

let (usage, finish_reason) = match stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
event
.json_data(CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
logprobs,
finish_reason,
usage,
),
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
}),
))
.unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
}
};

let (headers, response_stream) = generate_stream_internal(
infer,
compute_type,
Json(generate_request),
on_message_callback,
span,
)
.await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
} else {
StreamState::Content {
skip_close_quote: false,
}
};
let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await {
if let Ok(stream_token) = result {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}

let response_stream = response_stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
}));
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);

yield Ok::<Event, Infallible>(event);
}
}
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
};

let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
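
The async block above is effectively a three-state machine: Buffering accumulates tokens and sniffs the grammar-constrained JSON for the function name, BufferTrailing (reached only for the reserved "no_tool" function) discards everything up to the "content":" key, and Content re-emits tokens either as tool-call deltas or as plain text. The Python sketch below restates that control flow over a list of already-decoded token strings; it is illustrative only, and the (kind, text) tuples it yields stand in for the ChatCompletionChunk events built on the Rust side.

import re
from enum import Enum, auto

FUNCTION_RE = re.compile(r'\{"function":\{"_name":"([^"]+)"')
CONTENT_KEY = '"content":"'


class State(Enum):
    BUFFERING = auto()        # accumulate tokens until the function name is known
    BUFFER_TRAILING = auto()  # "no_tool": drop everything up to the content value
    CONTENT = auto()          # pass tokens through


def route_tokens(tokens, using_tools=True):
    """Yield (kind, text) pairs, where kind is 'tool' or 'content'."""
    state = State.BUFFERING if using_tools else State.CONTENT
    as_tool = using_tools
    buffer, json_buffer, skip_close_quote = [], "", False

    for tok in tokens:
        if state is State.BUFFERING:
            json_buffer += tok.replace(" ", "")
            buffer.append(tok)
            if m := FUNCTION_RE.search(json_buffer):
                if m.group(1) == "no_tool":
                    # Not a real tool call: strip the JSON scaffolding and
                    # stream the "content" value as ordinary text instead.
                    state, as_tool = State.BUFFER_TRAILING, False
                    buffer, json_buffer = [], ""
                else:
                    # Real tool call: replay everything buffered so far.
                    state = State.CONTENT
                    for buffered in buffer:
                        yield "tool", buffered
        elif state is State.BUFFER_TRAILING:
            json_buffer += tok.replace(" ", "")
            idx = json_buffer.find(CONTENT_KEY)
            if idx == -1:
                continue  # keep buffering until the "content" key shows up
            leftover = json_buffer[idx + len(CONTENT_KEY):]
            if leftover:
                yield "content", leftover
            buffer, json_buffer = [], ""
            state, skip_close_quote = State.CONTENT, True
        else:  # State.CONTENT
            if skip_close_quote and '"' in tok:
                break  # closing quote of the JSON wrapper: stop streaming
            yield ("tool" if as_tool else "content", tok)

Compared with emitting tokens eagerly, buffering until the function name is known is what allows a "no_tool" generation to be surfaced as ordinary streamed content rather than as a malformed tool call.
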
@ -1246,17 +1389,33 @@ async fn chat_completions(
if let Value::Object(ref mut props) = arguments {
props.remove("_name");
}

let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name,
arguments,
},
}];
(Some(tool_calls), None)
match name.as_str() {
"no_tool" => {
// parse the content message
let content_message = arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError(
"No `content` found in generated text".to_string(),
)
})?
.to_string();
(None, Some(content_message))
}
_ => {
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name,
arguments,
},
}];
(Some(tool_calls), None)
}
}
} else {
(None, Some(generation.generated_text))
};
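
For the non-streaming path above, the generated text is parsed as a single function call, the internal "_name" discriminator is stripped from the arguments, and the reserved "no_tool" function is unwrapped into a plain text reply instead of a ToolCall. A rough Python equivalent is sketched below; it assumes the generated JSON has the {"function": {"_name": ..., ...}} shape implied by the regex used in the streaming branch, and the function name postprocess is hypothetical.

import json


def postprocess(gen_text: str, using_tools: bool):
    """Return (tool_calls, content) from grammar-constrained output."""
    if not using_tools:
        return None, gen_text

    call = json.loads(gen_text)["function"]
    name = call.pop("_name")  # internal discriminator, never exposed to clients

    if name == "no_tool":
        # The model declined to call a tool; surface its "content" field as text.
        content = call.get("content")
        if content is None:
            raise ValueError("No `content` found in generated text")
        return None, content

    tool_call = {
        "id": "0",
        "type": "function",
        "function": {"description": None, "name": name, "arguments": call},
    }
    return [tool_call], None
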
@ -2323,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
};

(
@ -2500,8 +2660,8 @@ mod tests {
);

assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap();
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}
@ -1,5 +1,5 @@
[toolchain]
# Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/
channel = "1.80.0"
channel = "1.80.1"
components = ["rustfmt", "clippy"]
29 server/poetry.lock generated
@ -1269,12 +1269,12 @@ files = [

[[package]]
name = "moe-kernels"
version = "0.4.0"
version = "0.6.0"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"},
{file = "moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:f28fd2a56c3ac7bfe74bc44cc7c8c0791a2644ad689b084ea4ed6decb7f41c25"},
]

[package.dependencies]
@ -1284,16 +1284,16 @@ triton = "*"

[package.source]
type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"

[[package]]
name = "moe-kernels"
version = "0.4.0"
version = "0.6.0"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"},
{file = "moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:db475948fd9f7a8647aa3f73256ff4d3bb111425305bcd0b0d3559ccc75b8937"},
]

[package.dependencies]
@ -1303,16 +1303,16 @@ triton = "*"

[package.source]
type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"

[[package]]
name = "moe-kernels"
version = "0.4.0"
version = "0.6.0"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"},
{file = "moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:364be07c06aafbab1f51d9e26d9a4ff658defe1462a4c645abaf7b895ed163a8"},
]

[package.dependencies]
@ -1322,16 +1322,16 @@ triton = "*"

[package.source]
type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"

[[package]]
name = "moe-kernels"
version = "0.4.0"
version = "0.6.0"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"},
{file = "moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:81e7fa25fb5ed5336f5151994f5e3f600df7e166fe013576968c59415e442894"},
]

[package.dependencies]
@ -1341,7 +1341,7 @@ triton = "*"

[package.source]
type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"

[[package]]
name = "mpmath"
@ -3402,11 +3402,6 @@ files = [
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
]

[package.dependencies]
@ -47,10 +47,10 @@ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.7.1"

@ -82,3 +82,6 @@ requires = [
"poetry-core>=1.0.0",
]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"
@ -24,10 +24,8 @@ class KVCache:
):
"""Construct the key-value cache for a layer."""

if (
dtype == torch.float8_e5m2
and (ATTENTION != "flashinfer"
or SYSTEM != "cuda")
if dtype == torch.float8_e5m2 and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
@ -43,7 +43,7 @@ def can_use_gptq_marlin(
and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
# We only suppord asymmetric quantization for AWQ.
# We only support asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)

@ -210,11 +210,17 @@ class SparseMoELayer(nn.Module):
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader):
cls = UnquantizedSparseMoELayer
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
elif isinstance(
weights.loader, GPTQMarlinWeightsLoader
) and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
):
cls = GPTQMarlinSparseMoELayer
else:
raise ValueError(
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
)

log_once(
@ -34,9 +34,10 @@ def can_use_marlin_moe_gemm(
SYSTEM == "cuda"
and fused_marlin_moe is not None
and has_sm_8_0
and quantize == "gptq"
and quant_method == "gptq"
and sym
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
# We only support asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)

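Together with the SparseMoELayer change above, this widens the marlin fused MoE path from symmetric GPTQ only to both GPTQ and AWQ, with asymmetric checkpoints accepted only for AWQ. The predicate can be mirrored and exercised in isolation as below; the keyword flags standing in for SYSTEM, has_sm_8_0 and the kernel import are assumptions made to keep the sketch self-contained.

def can_use_marlin_moe_gemm_sketch(*, quant_method: str, quantize: str, sym: bool,
                                   on_cuda_sm80: bool = True,
                                   kernels_available: bool = True) -> bool:
    """Simplified mirror of the predicate above, for quick experimentation."""
    return (
        on_cuda_sm80
        and kernels_available
        and quantize in {"awq", "gptq"}
        and quant_method in {"awq", "gptq"}
        # Asymmetric quantization is only accepted for AWQ checkpoints.
        and (sym or quant_method == "awq")
    )


if __name__ == "__main__":
    # Newly enabled case: asymmetric AWQ now takes the marlin MoE path.
    assert can_use_marlin_moe_gemm_sketch(quant_method="awq", quantize="awq", sym=False)
    # Still rejected: asymmetric GPTQ.
    assert not can_use_marlin_moe_gemm_sketch(quant_method="gptq", quantize="gptq", sym=False)

The last clause is the interesting one: `sym or quant_method == "awq"` rejects asymmetric GPTQ while letting asymmetric AWQ through.
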
@ -72,10 +73,15 @@ class GPTQMarlinSparseMoELayer(nn.Module):
super().__init__()

if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
)
):
raise ValueError(
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported"
)

assert (n_expert_group is None) == (
@ -102,17 +108,24 @@ class GPTQMarlinSparseMoELayer(nn.Module):

def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_marlin_moe(
x,
hidden_states=x,
w1=self.gate_up_proj.qweight,
w2=self.down_proj.qweight,
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
perm1=self.gate_up_proj.perm,
perm2=self.down_proj.perm,
w1_scale=self.gate_up_proj.scales,
w2_scale=self.down_proj.scales,
is_full_k1=self.gate_up_proj.is_full_k,
is_full_k2=self.down_proj.is_full_k,
w1_zeros=(
self.gate_up_proj.qzeros
if self.gate_up_proj.qzeros.numel() > 0
else None
),
w2_zeros=(
self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None
),
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
sort_indices1=self.gate_up_proj.perm,
sort_indices2=self.down_proj.perm,
is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,