Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-24 08:22:07 +00:00

Merge branch 'main' into fp8_kvcache

This commit is contained in:
commit 084de9907c
.github/workflows/build.yaml (vendored, 371 lines changed)

@@ -18,51 +18,29 @@ on:
       - "Cargo.lock"
       - "rust-toolchain.toml"
       - "Dockerfile"
+      - "Dockerfile_amd"
+      - "Dockerfile_intel"
     branches:
       - 'main'
 
 jobs:
-  start-runner:
-    name: Start self-hosted EC2 runner
-    runs-on: ubuntu-latest
-    env:
-      AWS_REGION: us-east-1
-      EC2_AMI_ID: ami-0789b6925c11b1fb2
-      EC2_INSTANCE_TYPE: g5.12xlarge
-      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
-      EC2_SECURITY_GROUP: sg-030175c435ac141d6
-    outputs:
-      label: ${{ steps.start-ec2-runner.outputs.label }}
-      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
-    steps:
-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@v1
-        with:
-          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          aws-region: ${{ env.AWS_REGION }}
-      - name: Start EC2 runner
-        id: start-ec2-runner
-        uses: philschmid/philschmid-ec2-github-runner@main
-        with:
-          mode: start
-          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ${{ env.EC2_AMI_ID }}
-          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
-          subnet-id: ${{ env.EC2_SUBNET_ID }}
-          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
-          aws-resource-tags: > # optional, requires additional permissions
-            [
-              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
-              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
-            ]
-
   build-and-push-image:
     concurrency:
-      group: ${{ github.workflow }}-build-and-push-image-${{ github.head_ref || github.run_id }}
+      group: ${{ github.workflow }}-build-and-push-image-${{ matrix.name }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    needs: start-runner # required to start the main job when the runner is ready
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
+    strategy:
+      matrix:
+        include:
+          - name: "cuda"
+            label: ""
+            dockerfile: "Dockerfile"
+          - name: "amd"
+            label: "-rocm"
+            dockerfile: "Dockerfile_amd"
+          - name: "intel"
+            label: "-intel"
+            dockerfile: "Dockerfile_intel"
     permissions:
       contents: write
       packages: write

@@ -73,16 +51,19 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v3
+
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Tailscale
+        uses: huggingface/tailscale-action@main
+        with:
+          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
+          slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
+          slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
       - name: Initialize Docker Buildx
         uses: docker/setup-buildx-action@v2.0.0
         with:
           install: true
-      - name: Inject slug/short variables
-        uses: rlespinasse/github-slug-action@v4.4.1
-      - name: Tailscale
-        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
       - name: Login to GitHub Container Registry
         if: github.event_name != 'pull_request'
         uses: docker/login-action@v2

@@ -112,7 +93,7 @@ jobs:
           images: |
             registry.internal.huggingface.tech/api-inference/community/text-generation-inference
           tags: |
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ matrix.label }}
       # If main, release or tag
       - name: Extract metadata (tags, labels) for Docker
         if: ${{ github.event_name != 'pull_request' }}

@@ -126,308 +107,44 @@ jobs:
             ghcr.io/huggingface/text-generation-inference
             db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
           tags: |
-            type=semver,pattern={{version}}
-            type=semver,pattern={{major}}.{{minor}}
-            type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
+            type=semver,pattern={{version}}${{ matrix.label }}
+            type=semver,pattern={{major}}.{{minor}}${{ matrix.label }}
+            type=raw,value=latest${{ matrix.label }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ matrix.label }}
       - name: Build and push Docker image
         id: build-and-push
         uses: docker/build-push-action@v4
         with:
           context: .
-          file: Dockerfile
+          file: ${{ matrix.dockerfile }}
           push: true
           platforms: 'linux/amd64'
           build-args: |
             GIT_SHA=${{ env.GITHUB_SHA }}
-            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ matrix.label }}
           tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
-          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
+          network: host
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ matrix.label }},mode=min
+          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ matrix.label }},mode=min
-
-  integration-tests:
-    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image # Wait for the docker image to be built
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
-    env:
-      DOCKER_VOLUME: /cache
-    steps:
-      - uses: actions/checkout@v2
-      - name: Inject slug/short variables
-        uses: rlespinasse/github-slug-action@v4.4.1
       - name: Set up Python
+        if: matrix.name == 'cuda'
         uses: actions/setup-python@v4
         with:
           python-version: 3.9
-      - name: Tailscale
-        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-      - name: Prepare disks
-        run: |
-          sudo mkfs -t ext4 /dev/nvme1n1
-          sudo mkdir ${{ env.DOCKER_VOLUME }}
-          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
+        if: matrix.name == 'cuda'
         run: |
           make install-integration-tests
       - name: Run tests
+        if: matrix.name == 'cuda'
         run: |
+          export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv integration-tests
+      - name: Tailscale Wait
+        if: ${{ failure() || runner.debug == '1' }}
+        uses: huggingface/tailscale-action@main
-
-  build-and-push-image-rocm:
-    concurrency:
-      group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image # Wait for the main docker image to be built
-      - integration-tests # Wait for the main integration-tests
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
-    permissions:
-      contents: write
-      packages: write
-      # This is used to complete the identity challenge
-      # with sigstore/fulcio when running outside of PRs.
-      id-token: write
-      security-events: write
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v3
-      - name: Initialize Docker Buildx
-        uses: docker/setup-buildx-action@v2.0.0
         with:
-          install: true
+          waitForSSH: true
-      - name: Inject slug/short variables
-        uses: rlespinasse/github-slug-action@v4.4.1
-      - name: Tailscale
-        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-      - name: Login to GitHub Container Registry
-        if: github.event_name != 'pull_request'
-        uses: docker/login-action@v2
-        with:
-          registry: ghcr.io
-          username: ${{ github.actor }}
-          password: ${{ secrets.GITHUB_TOKEN }}
-      - name: Login to internal Container Registry
-        uses: docker/login-action@v2.1.0
-        with:
-          username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
-          password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
-          registry: registry.internal.huggingface.tech
-      - name: Login to Azure Container Registry
-        if: github.event_name != 'pull_request'
-        uses: docker/login-action@v2.1.0
-        with:
-          username: ${{ secrets.AZURE_DOCKER_USERNAME }}
-          password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
-          registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
-      # If pull request
-      - name: Extract metadata (tags, labels) for Docker
-        if: ${{ github.event_name == 'pull_request' }}
-        id: meta-pr
-        uses: docker/metadata-action@v4.3.0
-        with:
-          images: |
-            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
-          tags: |
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
-      # If main, release or tag
-      - name: Extract metadata (tags, labels) for Docker
-        if: ${{ github.event_name != 'pull_request' }}
-        id: meta
-        uses: docker/metadata-action@v4.3.0
-        with:
-          flavor: |
-            latest=false
-          images: |
-            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
-            ghcr.io/huggingface/text-generation-inference
-            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
-          tags: |
-            type=semver,pattern={{version}}-rocm
-            type=semver,pattern={{major}}.{{minor}}-rocm
-            type=raw,value=latest-rocm,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
-      - name: Build and push Docker image
-        id: build-and-push
-        uses: docker/build-push-action@v4
-        with:
-          context: .
-          file: Dockerfile_amd
-          push: true
-          platforms: 'linux/amd64'
-          build-args: |
-            GIT_SHA=${{ env.GITHUB_SHA }}
-            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
-          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
-          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
-
-  build-and-push-image-intel:
-    concurrency:
-      group: ${{ github.workflow }}-build-and-push-image-intel-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image # Wait for the main docker image to be built
-      - integration-tests # Wait for the main integration-tests
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
-    permissions:
-      contents: write
-      packages: write
-      # This is used to complete the identity challenge
-      # with sigstore/fulcio when running outside of PRs.
-      id-token: write
-      security-events: write
-    outputs:
-      # env is not available in the later `container:`, but previous job outputs are.
-      short_sha: ${{ env.GITHUB_SHA_SHORT }}
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v3
-      - name: Initialize Docker Buildx
-        uses: docker/setup-buildx-action@v2.0.0
-        with:
-          install: true
-      - name: Inject slug/short variables
-        uses: rlespinasse/github-slug-action@v4.4.1
-      - name: Tailscale
-        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-      - name: Login to GitHub Container Registry
-        if: github.event_name != 'pull_request'
-        uses: docker/login-action@v2
-        with:
-          registry: ghcr.io
-          username: ${{ github.actor }}
-          password: ${{ secrets.GITHUB_TOKEN }}
-      - name: Login to internal Container Registry
-        uses: docker/login-action@v2.1.0
-        with:
-          username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
-          password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
-          registry: registry.internal.huggingface.tech
-      - name: Login to Azure Container Registry
-        if: github.event_name != 'pull_request'
-        uses: docker/login-action@v2.1.0
-        with:
-          username: ${{ secrets.AZURE_DOCKER_USERNAME }}
-          password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
-          registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
-      # If pull request
-      - name: Extract metadata (tags, labels) for Docker
-        if: ${{ github.event_name == 'pull_request' }}
-        id: meta-pr
-        uses: docker/metadata-action@v4.3.0
-        with:
-          images: |
-            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
-          tags: |
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
-      # If main, release or tag
-      - name: Extract metadata (tags, labels) for Docker
-        if: ${{ github.event_name != 'pull_request' }}
-        id: meta
-        uses: docker/metadata-action@v4.3.0
-        with:
-          flavor: |
-            latest=false
-          images: |
-            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
-            ghcr.io/huggingface/text-generation-inference
-            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
-          tags: |
-            type=semver,pattern={{version}}-intel
-            type=semver,pattern={{major}}.{{minor}}-intel
-            type=raw,value=latest-intel,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
-      - name: Build and push Docker image
-        id: build-and-push
-        uses: docker/build-push-action@v4
-        with:
-          context: .
-          file: Dockerfile_intel
-          push: true
-          platforms: 'linux/amd64'
-          build-args: |
-            GIT_SHA=${{ env.GITHUB_SHA }}
-            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-intel
-          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
-          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
-
-  stop-runner:
-    name: Stop self-hosted EC2 runner
-    needs:
-      - start-runner
-      - build-and-push-image
-      - build-and-push-image-rocm
-      - build-and-push-image-intel
-      - integration-tests
-    runs-on: ubuntu-latest
-    env:
-      AWS_REGION: us-east-1
-    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
-    steps:
-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@v1
-        with:
-          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          aws-region: ${{ env.AWS_REGION }}
-      - name: Stop EC2 runner
-        uses: philschmid/philschmid-ec2-github-runner@main
-        with:
-          mode: stop
-          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          label: ${{ needs.start-runner.outputs.label }}
-          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
-
-# TODO: Move this to `build_amd.yml` (and `build_nvidia.yml`)
-
-# integration-tests-rocm:
-#   concurrency:
-#     group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-#     cancel-in-progress: true
-#   needs:
-#     - start-runner
-#     - build-and-push-image
-#     - integration-tests
-#     - build-and-push-image-rocm
-#     - stop-runner
-#   runs-on: [self-hosted, amd-gpu, multi-gpu, mi300]
-#   container:
-#     image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
-#     options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
-#   env:
-#     DOCKER_VOLUME: /cache
-#   steps:
-#     - name: ROCM-SMI
-#       run: |
-#         rocm-smi
-#     - name: ROCM-INFO
-#       run: |
-#         rocminfo | grep "Agent" -A 14
-#     - name: Show ROCR environment
-#       run: |
-#         echo "ROCR: $ROCR_VISIBLE_DEVICES"
-#     - name: Install
-#       run: |
-#         make install-integration-tests
-#     - name: Run tests
-#       run: |
-#         export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-#         pytest -s -vv integration-tests
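The net effect of this workflow change is easier to see in one piece than in the hunks above: the EC2 start/stop jobs and the separate ROCm/Intel copies of the build job are replaced by a single matrix job on a self-hosted runner, and the integration tests move into that job. A condensed sketch, reassembled from the hunks above (most steps elided, indentation approximate):

```yaml
jobs:
  build-and-push-image:
    runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
    strategy:
      matrix:
        include:
          - name: "cuda"
            label: ""
            dockerfile: "Dockerfile"
          - name: "amd"
            label: "-rocm"
            dockerfile: "Dockerfile_amd"
          - name: "intel"
            label: "-intel"
            dockerfile: "Dockerfile_intel"
    steps:
      # checkout, slug variables, Tailscale and registry logins elided
      - name: Build and push Docker image
        uses: docker/build-push-action@v4
        with:
          # the per-backend Dockerfile is selected by the matrix entry
          file: ${{ matrix.dockerfile }}
          build-args: |
            GIT_SHA=${{ env.GITHUB_SHA }}
            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ matrix.label }}
          # image tags and cache refs all carry the matrix label suffix ("", -rocm or -intel)
          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
      # integration tests only run for the CUDA flavour of the matrix
      - name: Run tests
        if: matrix.name == 'cuda'
        run: |
          pytest -s -vv integration-tests
```

One workflow run therefore produces all three image flavours (for example sha-&lt;short&gt;, sha-&lt;short&gt;-rocm, sha-&lt;short&gt;-intel) instead of three near-duplicate jobs.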
.github/workflows/tests.yaml (vendored, 10 lines changed)

@@ -33,9 +33,9 @@ jobs:
       - name: Install Rust
         uses: actions-rs/toolchain@v1
         with:
-          # Released on: 02 May, 2024
-          # https://releases.rs/docs/1.78.0/
-          toolchain: 1.78.0
+          # Released on: June 13, 2024
+          # https://releases.rs/docs/1.79.0/
+          toolchain: 1.79.0
           override: true
           components: rustfmt, clippy
       - name: Install Protoc

@@ -68,11 +68,11 @@ jobs:
             ~/.cargo/git
       - name: Install
         run: |
-          make install
+          make install-cpu
       - name: Run server tests
         run: |
           pip install pytest
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv server/tests
       - name: Pre-commit checks
         run: |
.github/workflows/trufflehog.yml (vendored, new file, 18 lines)

@@ -0,0 +1,18 @@
+on:
+  push:
+
+name: Secret Leaks
+
+permissions:
+  contents: read
+
+jobs:
+  trufflehog:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Secret Scanning
+        uses: trufflesecurity/trufflehog@main
CODE_OF_CONDUCT.md (new file, 133 lines)

@@ -0,0 +1,133 @@
+
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, caste, color, religion, or sexual
+identity and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+  and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the overall
+  community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or advances of
+  any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email address,
+  without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+feedback@huggingface.co.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series of
+actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or permanent
+ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within the
+community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.1, available at
+[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
+
+Community Impact Guidelines were inspired by
+[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
+
+For answers to common questions about this code of conduct, see the FAQ at
+[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
+[https://www.contributor-covenant.org/translations][translations].
+
+[homepage]: https://www.contributor-covenant.org
+[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
+[Mozilla CoC]: https://github.com/mozilla/diversity
+[FAQ]: https://www.contributor-covenant.org/faq
+[translations]: https://www.contributor-covenant.org/translations
CONTRIBUTING.md (new file, 120 lines)

@@ -0,0 +1,120 @@
+<!---
+Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Contribute to text-generation-inference
+
+Everyone is welcome to contribute, and we value everybody's contribution. Code
+contributions are not the only way to help the community. Answering questions, helping
+others, and improving the documentation are also immensely valuable.
+
+It also helps us if you spread the word! Reference the library in blog posts
+about the awesome projects it made possible, shout out on Twitter every time it has
+helped you, or simply ⭐️ the repository to say thank you.
+
+However you choose to contribute, please be mindful and respect our
+[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md).
+
+**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
+
+## Ways to contribute
+
+There are several ways you can contribute to text-generation-inference.
+
+* Fix outstanding issues with the existing code.
+* Submit issues related to bugs or desired new features.
+* Contribute to the examples or to the documentation.
+
+> All contributions are equally valuable to the community. 🥰
+
+## Fixing outstanding issues
+
+If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open
+a Pull Request!
+
+## Submitting a bug-related issue or feature request
+
+Do your best to follow these guidelines when submitting a bug-related issue or a feature
+request. It will make it easier for us to come back to you quickly and with good
+feedback.
+
+### Did you find a bug?
+
+The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter.
+
+Before you report an issue, we would really appreciate it if you could **make sure the bug was not
+already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the
+library itself, and not your code.
+
+Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so
+we can quickly resolve it:
+
+* Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies).
+* A short, self-contained, code snippet that allows us to reproduce the bug.
+* The *full* traceback if an exception is raised.
+* Attach any other additional information, like screenshots, you think may help.
+
+To get the OS and software versions automatically, you can re-run the launcher with the `--env` flag:
+
+```bash
+text-generation-launcher --env
+```
+
+This will precede the launch of the model with the information relative to your environment. We recommend pasting
+that in your issue report.
+
+### Do you want a new feature?
+
+If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe:
+
+1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it
+   a feature related to something you need for a project? Is it something you worked on and think it could benefit
+   the community?
+
+   Whatever it is, we'd love to hear about it!
+
+2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better
+   we'll be able to help you.
+3. Provide a *code snippet* that demonstrates the feature's usage.
+4. If the feature is related to a paper, please include a link.
+
+If your issue is well written we're already 80% of the way there by the time you create it.
+
+We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE)
+to help you get started with your issue.
+
+## Do you want to implement a new model?
+
+New models are constantly released and if you want to implement a new model, please provide the following information:
+
+* A short description of the model and a link to the paper.
+* Link to the implementation if it is open-sourced.
+* Link to the model weights if they are available.
+
+If you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference!
+
+## Do you want to add documentation?
+
+We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know
+how the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be
+happy to make the changes or help you make a contribution if you're interested!
+
+## I want to become a maintainer of the project. How do I get there?
+
+TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have
+motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference
+service.
+
+If you are such an individual (or organization), please reach out to us and let's collaborate.
Cargo.lock (generated, 781 lines changed)

File diff suppressed because it is too large.
Cargo.toml (14 lines changed)

@@ -9,19 +9,29 @@ members = [
 resolver = "2"
 
 [workspace.package]
-version = "2.0.2"
+version = "2.0.5-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
 
 [workspace.dependencies]
+base64 = "0.22.0"
 tokenizers = { version = "0.19.1", features = ["http"] }
 hf-hub = { version = "0.3.1", features = ["tokio"] }
 
 [profile.release]
+incremental = true
+
+[profile.release-binary]
+inherits = "release"
 debug = 1
 incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
 lto = "fat"
 opt-level = 3
 codegen-units = 1
-panic = "abort"
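Read together with the Dockerfile hunks below, the intent of the split is that the plain release profile stays light for local builds, while the container images are built with the new optimized profile (the Dockerfiles switch to `cargo build --profile release-opt`). A consolidated sketch of the resulting profile section, assembled from the added lines above:

```toml
[profile.release]
incremental = true

[profile.release-binary]
inherits = "release"
debug = 1
incremental = true
panic = "abort"

# Used by the Dockerfiles: `cargo build --profile release-opt`
[profile.release-opt]
inherits = "release"
debug = 0
incremental = false
lto = "fat"
opt-level = 3
codegen-units = 1
```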
Dockerfile (41 lines changed)

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

@@ -15,9 +15,6 @@ RUN cargo chef prepare --recipe-path recipe.json
 
 FROM chef AS builder
 
-ARG GIT_SHA
-ARG DOCKER_LABEL
-
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \

@@ -25,7 +22,10 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml

@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile

@@ -137,6 +137,13 @@ COPY server/Makefile-eetq Makefile
 # Build specific version of transformers
 RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
 
+# Build marlin kernels
+FROM kernel-builder as marlin-kernels-builder
+WORKDIR /usr/src
+COPY server/marlin/ .
+# Build specific version of transformers
+RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
+
 # Build Transformers CUDA kernels
 FROM kernel-builder as custom-kernels-builder
 WORKDIR /usr/src

@@ -193,7 +200,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
 
 # Copy build artifacts from custom kernels builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

@@ -205,6 +212,8 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from eetq kernels builder
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from marlin kernels builder
+COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

@@ -225,18 +234,22 @@ RUN cd server && \
     pip install -r requirements_cuda.txt && \
    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
 
-# Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
-# Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
-# Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
-
+# Deps before the binaries
+# The binaries change on every build given we burn the SHA into them
+# The deps change less often.
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
         build-essential \
         g++ \
         && rm -rf /var/lib/apt/lists/*
 
+# Install benchmarker
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
+
 
 # AWS Sagemaker compatible image
 FROM base as sagemaker
Dockerfile_amd

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

@@ -15,9 +15,6 @@ RUN cargo chef prepare --recipe-path recipe.json
 
 FROM chef AS builder
 
-ARG GIT_SHA
-ARG DOCKER_LABEL
-
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \

@@ -25,7 +22,10 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml

@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 
 # Text Generation Inference base image for RoCm
 FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base

@@ -193,11 +193,11 @@ RUN cd server && \
     pip install ".[accelerate, peft, outlines]" --no-cache-dir
 
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 # AWS Sagemaker compatible image
 FROM base as sagemaker
Dockerfile_intel

@@ -1,4 +1,4 @@
-FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

@@ -14,9 +14,6 @@ RUN cargo chef prepare --recipe-path recipe.json
 
 FROM chef AS builder
 
-ARG GIT_SHA
-ARG DOCKER_LABEL
-
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \

@@ -24,7 +21,10 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml

@@ -32,7 +32,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 
 
 # Text Generation Inference base image for Intel

@@ -43,11 +43,12 @@ USER root
 RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
     dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
 
+RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null
 
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
     | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
 
-RUN apt-get update && apt install -y intel-basekit xpu-smi
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
 
 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \

@@ -56,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
 
 
 WORKDIR /usr/src
-RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
+RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
 
 # Install server
 COPY proto proto

@@ -65,7 +66,7 @@ COPY server server
 COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_cuda.txt && \
+    pip install -r requirements_intel.txt && \
     pip install ".[accelerate, peft, outlines]" --no-cache-dir
 
 ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest

@@ -75,13 +76,17 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l
 ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
 ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
 
+RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
+
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 # Final image
 FROM base
Makefile (17 changed lines)
@@ -1,12 +1,8 @@
 install-server:
	cd server && make install

-install-custom-kernels:
-	if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
-
-install-integration-tests:
-	cd integration-tests && pip install -r requirements.txt
-	cd clients/python && pip install .
+install-server-cpu:
+	cd server && make install-server

 install-router:
	cd router && cargo install --path .
@@ -17,7 +13,10 @@ install-launcher:
 install-benchmark:
	cd benchmark && cargo install --path .

-install: install-server install-router install-launcher install-custom-kernels
+install: install-server install-router install-launcher
+
+install-cpu: install-server-cpu install-router install-launcher

 server-dev:
	cd server && make run-dev
@@ -28,6 +27,10 @@ router-dev:
 rust-tests: install-router install-launcher
	cargo test

+install-integration-tests:
+	cd integration-tests && pip install -r requirements.txt
+	cd clients/python && pip install .
+
 integration-tests: install-integration-tests
	pytest -s -vv -m "not private" integration-tests
@@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
             "Lowest: {:.2} {unit}",
             data.iter()
                 .min_by(|a, b| a.total_cmp(b))
-                .unwrap_or(&std::f64::NAN)
+                .unwrap_or(&f64::NAN)
         ),
         Style::default().fg(Color::Reset),
     )]),
@@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
             "Highest: {:.2} {unit}",
             data.iter()
                 .max_by(|a, b| a.total_cmp(b))
-                .unwrap_or(&std::f64::NAN)
+                .unwrap_or(&f64::NAN)
         ),
         Style::default().fg(Color::Reset),
     )]),
@@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>(
     let min_latency: f64 = *latency_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_latency: f64 = *latency_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let min_throughput: f64 = *throughput_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_throughput: f64 = *throughput_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);

     // Char min max values
     let min_x = if zoom {
@@ -1,8 +1,9 @@
 use std::time::{Duration, Instant};
-use text_generation_client::{
-    Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+use text_generation_client::v3::{
+    Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
     StoppingCriteriaParameters,
 };
+use text_generation_client::{Chunk, ClientError, Input};
 use tokenizers::{Tokenizer, TruncationDirection};
 use tokio::sync::{broadcast, mpsc};
@@ -142,6 +143,9 @@ async fn prefill(
         .map(|id| Request {
             id: id.into(),
             prefill_logprobs: false,
+            input_chunks: Some(Input {
+                chunks: vec![Chunk::Text(sequence.clone()).into()],
+            }),
             inputs: sequence.clone(),
             truncate: sequence_length,
             parameters: Some(parameters.clone()),
@@ -151,6 +155,8 @@ async fn prefill(
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
             top_n_tokens: top_n_tokens.unwrap_or(0),
+            blocks: vec![],
+            slots: vec![],
         })
         .collect();
@@ -159,6 +165,7 @@ async fn prefill(
         requests,
         size: batch_size,
         max_tokens: batch_size * (sequence_length + decode_length),
+        max_blocks: 0,
     };

     // Run prefill
@@ -8,7 +8,7 @@ use crate::app::App;
 use crate::event::Event;
 use crossterm::ExecutableCommand;
 use std::io;
-use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
+use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
 use tui::backend::CrosstermBackend;
@@ -4,7 +4,7 @@
 /// and: https://github.com/orhun/rust-tui-template
 use clap::Parser;
 use std::path::Path;
-use text_generation_client::ShardedClient;
+use text_generation_client::v3::ShardedClient;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
     let min = data
         .iter()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max = data
         .iter()
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     (average, *min, *max)
 }

 fn px(data: &[f64], p: u32) -> f64 {
     let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
-    *data.get(i).unwrap_or(&std::f64::NAN)
+    *data.get(i).unwrap_or(&f64::NAN)
 }

 fn format_value(value: f64, unit: &'static str) -> String {
@@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f
         .iter()
         .map(|&p| {
             let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
-            (format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
+            (format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
         })
         .collect()
 }
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.6.0"
+__version__ = "0.7.0"

 DEPRECATION_WARNING = (
     "`text_generation` clients are deprecated and will be removed in the near future. "
@@ -13,6 +13,9 @@ from text_generation.types import (
     Request,
     Parameters,
     Grammar,
+    CompletionRequest,
+    Completion,
+    CompletionComplete,
     ChatRequest,
     ChatCompletionChunk,
     ChatComplete,
@@ -70,6 +73,94 @@ class Client:
         self.cookies = cookies
         self.timeout = timeout

+    def completion(
+        self,
+        prompt: str,
+        frequency_penalty: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[List[str]] = None,
+    ):
+        """
+        Given a prompt, generate a response synchronously
+
+        Args:
+            prompt (`str`):
+                Prompt
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
+            max_tokens (`int`):
+                Maximum number of generated tokens
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            seed (`int`):
+                Random sampling seed
+            stream (`bool`):
+                Stream the response
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
+        """
+        request = CompletionRequest(
+            model="tgi",
+            prompt=prompt,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            repetition_penalty=repetition_penalty,
+            seed=seed,
+            stream=stream,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+        )
+        if not stream:
+            resp = requests.post(
+                f"{self.base_url}/v1/completions",
+                json=request.dict(),
+                headers=self.headers,
+                cookies=self.cookies,
+                timeout=self.timeout,
+            )
+            payload = resp.json()
+            if resp.status_code != 200:
+                raise parse_error(resp.status_code, payload)
+            return Completion(**payload)
+        else:
+            return self._completion_stream_response(request)
+
+    def _completion_stream_response(self, request):
+        resp = requests.post(
+            f"{self.base_url}/v1/completions",
+            json=request.dict(),
+            headers=self.headers,
+            cookies=self.cookies,
+            timeout=self.timeout,
+            stream=True,
+        )
+        # iterate over the SSE stream and yield parsed chunks
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+            payload = byte_payload.decode("utf-8")
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                try:
+                    response = CompletionComplete(**json_payload)
+                    yield response
+                except ValidationError:
+                    raise parse_error(resp.status_code, json_payload)
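For reference, a minimal usage sketch of the `completion()` method added above (not part of the commit); the server URL is an assumption, any running TGI instance exposing the OpenAI-compatible `/v1/completions` route will do:

```python
from text_generation import Client

# Assumed local TGI endpoint; adjust to your deployment.
client = Client("http://127.0.0.1:8080")

# Non-streaming call: parses the JSON payload into a Completion object.
completion = client.completion(prompt="What is Deep Learning?", max_tokens=32, stop=["\n\n"])
print(completion)

# Streaming call: yields CompletionComplete chunks as they arrive.
for chunk in client.completion(prompt="What is Deep Learning?", max_tokens=32, stream=True):
    print(chunk.text, end="", flush=True)
```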
     def chat(
         self,
         messages: List[Message],
@@ -88,6 +179,7 @@ class Client:
         tools: Optional[List[Tool]] = None,
         tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
+        stop: Optional[List[str]] = None,
     ):
         """
         Given a list of messages, generate a response asynchronously
@@ -130,6 +222,8 @@ class Client:
                 A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated

         """
         request = ChatRequest(
@@ -150,6 +244,7 @@ class Client:
             tools=tools,
             tool_prompt=tool_prompt,
             tool_choice=tool_choice,
+            stop=stop,
         )
         if not stream:
             resp = requests.post(
@@ -461,6 +556,93 @@ class AsyncClient:
         self.cookies = cookies
         self.timeout = ClientTimeout(timeout)

+    async def completion(
+        self,
+        prompt: str,
+        frequency_penalty: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Union[Completion, AsyncIterator[CompletionComplete]]:
+        """
+        Given a prompt, generate a response asynchronously
+
+        Args:
+            prompt (`str`):
+                Prompt
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
+            max_tokens (`int`):
+                Maximum number of generated tokens
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            seed (`int`):
+                Random sampling seed
+            stream (`bool`):
+                Stream the response
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
+        """
+        request = CompletionRequest(
+            model="tgi",
+            prompt=prompt,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            repetition_penalty=repetition_penalty,
+            seed=seed,
+            stream=stream,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+        )
+        if not stream:
+            return await self._completion_single_response(request)
+        else:
+            return self._completion_stream_response(request)
+
+    async def _completion_single_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/completions", json=request.dict()
+            ) as resp:
+                payload = await resp.json()
+                if resp.status != 200:
+                    raise parse_error(resp.status, payload)
+                return Completion(**payload)
+
+    async def _completion_stream_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/completions", json=request.dict()
+            ) as resp:
+                # iterate over the SSE stream and yield parsed chunks
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+                    payload = byte_payload.decode("utf-8")
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        try:
+                            response = CompletionComplete(**json_payload)
+                            yield response
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
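The asynchronous variant mirrors the synchronous sketch shown earlier; again, the endpoint is an assumption rather than part of the commit:

```python
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")  # assumed local TGI endpoint

    # Non-streaming: awaiting the call returns a single Completion.
    completion = await client.completion(prompt="What is Deep Learning?", max_tokens=32)
    print(completion)

    # Streaming: awaiting the call returns an async iterator of CompletionComplete chunks.
    async for chunk in await client.completion(
        prompt="What is Deep Learning?", max_tokens=32, stream=True
    ):
        print(chunk.text, end="", flush=True)


asyncio.run(main())
```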
     async def chat(
         self,
         messages: List[Message],
@@ -479,6 +661,7 @@ class AsyncClient:
         tools: Optional[List[Tool]] = None,
         tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
+        stop: Optional[List[str]] = None,
     ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
         """
         Given a list of messages, generate a response asynchronously
@@ -521,6 +704,8 @@ class AsyncClient:
                 A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated

         """
         request = ChatRequest(
@@ -541,6 +726,7 @@ class AsyncClient:
             tools=tools,
             tool_prompt=tool_prompt,
             tool_choice=tool_choice,
+            stop=stop,
         )
         if not stream:
             return await self._chat_single_response(request)
@@ -46,30 +46,6 @@ class Tool(BaseModel):
     function: dict


-class ChatCompletionComplete(BaseModel):
-    # Index of the chat completion
-    index: int
-    # Message associated with the chat completion
-    message: Message
-    # Log probabilities for the chat completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-    # Usage details of the chat completion
-    usage: Optional[Any] = None
-
-
-class CompletionComplete(BaseModel):
-    # Index of the completion
-    index: int
-    # Text generated for the completion
-    text: str
-    # Log probabilities for the completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-
-
 class Function(BaseModel):
     name: Optional[str]
     arguments: str
@@ -95,24 +71,41 @@ class Choice(BaseModel):
     finish_reason: Optional[str] = None


-class ChatCompletionChunk(BaseModel):
-    id: str
-    object: str
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: List[Choice]
+class CompletionRequest(BaseModel):
+    # Model identifier
+    model: str
+    # Prompt
+    prompt: str
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 0.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float] = None
+    # Maximum number of tokens to generate
+    max_tokens: Optional[int] = None
+    # Flag to indicate streaming response
+    stream: bool = False
+    # Random sampling seed
+    seed: Optional[int] = None
+    # Sampling temperature
+    temperature: Optional[float] = None
+    # Top-p value for nucleus sampling
+    top_p: Optional[float] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None


-class ChatComplete(BaseModel):
-    # Chat completion details
-    id: str
-    object: str
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: List[ChatCompletionComplete]
-    usage: Any
+class CompletionComplete(BaseModel):
+    # Index of the completion
+    index: int
+    # Text generated for the completion
+    text: str
+    # Log probabilities for the completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str


 class Completion(BaseModel):
@@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
     tool_prompt: Optional[str] = None
     # Choice of tool to be used
     tool_choice: Optional[str] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None
+
+
+class ChatCompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    message: Message
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+    # Usage details of the chat completion
+    usage: Optional[Any] = None
+
+
+class ChatComplete(BaseModel):
+    # Chat completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[ChatCompletionComplete]
+    usage: Any
+
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]


 class Parameters(BaseModel):
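As a quick orientation aid (not part of the commit), here is a minimal sketch of how the reshuffled models above serialize; the prompt and values are illustrative assumptions, and `model="tgi"` mirrors what `Client.completion()` sends:

```python
from text_generation.types import CompletionRequest

# Illustrative request; the clients POST request.dict() to /v1/completions.
request = CompletionRequest(
    model="tgi",
    prompt="What is Deep Learning?",
    max_tokens=32,
    stop=["\n\n"],
)
print(request.dict())
```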
docs/README.md (new file, 10 lines)
@@ -0,0 +1,10 @@
Documentation available at: https://huggingface.co/docs/text-generation-inference

## Release

When making a release, please update the latest version in the documentation with:
```
export OLD_VERSION="2\.0\.3"
export NEW_VERSION="2\.0\.4"
find . -name '*.md' -exec sed -i -e "s/$OLD_VERSION/$NEW_VERSION/g" {} \;
```
@@ -1121,6 +1121,15 @@
           "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
           "example": 0.95,
           "nullable": true
+        },
+        "stop": {
+          "type": "array",
+          "items": {
+            "type": "string"
+          },
+          "description": "Up to 4 sequences where the API will stop generating further tokens.",
+          "example": "null",
+          "nullable": true
         }
       }
     },
@@ -17,6 +17,8 @@
     title: Supported Models and Hardware
   - local: messages_api
     title: Messages API
+  - local: architecture
+    title: Internal Architecture
   title: Getting started
 - sections:
   - local: basic_tutorials/consuming_tgi
@@ -39,6 +41,8 @@
     title: Visual Language Models
   - local: basic_tutorials/monitoring
     title: Monitoring TGI with Prometheus and Grafana
+  - local: basic_tutorials/train_medusa
+    title: Train Medusa
   title: Tutorials
 - sections:
   - local: conceptual/streaming
docs/source/architecture.md (new file, 227 lines)
@@ -0,0 +1,227 @@
# Text Generation Inference Architecture

This document aims at describing the architecture of Text Generation Inference (TGI), by describing the call flow between the separate components.

A high-level architecture diagram can be seen here:



This diagram shows that there are three separate components:

- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
- **The model server**, responsible for receiving the gRPC requests and running the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
- **The launcher**, a helper that launches one or several model servers (if the model is sharded), and launches the router with the compatible arguments.

The router and the model server can run on two different machines; they do not need to be deployed together.

## The Router

This component is a rust web server binary that accepts HTTP requests using the custom [HTTP API](https://huggingface.github.io/text-generation-inference/), as well as OpenAI's [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api).
The router receives the API calls and handles the batching logic (an introduction to batching can be found [here](https://github.com/huggingface/text-generation-inference/blob/main/router/README.md)).
It uses different strategies to reduce latency between requests and responses, especially oriented to decoding latency. It uses queues, schedulers, and block allocators to achieve that and to produce batched requests that are then sent to the model server.

### Router's command line

The command line is the way to pass parameters to the router (it does not rely on a configuration file):

```
Text Generation Webserver

Usage: text-generation-router [OPTIONS]

Options:
      --max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
          [env: MAX_CONCURRENT_REQUESTS=] [default: 128]
      --max-best-of <MAX_BEST_OF>
          [env: MAX_BEST_OF=] [default: 2]
      --max-stop-sequences <MAX_STOP_SEQUENCES>
          [env: MAX_STOP_SEQUENCES=] [default: 4]
      --max-top-n-tokens <MAX_TOP_N_TOKENS>
          [env: MAX_TOP_N_TOKENS=] [default: 5]
      --max-input-tokens <MAX_INPUT_TOKENS>
          [env: MAX_INPUT_TOKENS=] [default: 1024]
      --max-total-tokens <MAX_TOTAL_TOKENS>
          [env: MAX_TOTAL_TOKENS=] [default: 2048]
      --waiting-served-ratio <WAITING_SERVED_RATIO>
          [env: WAITING_SERVED_RATIO=] [default: 1.2]
      --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
          [env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096]
      --max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>
          [env: MAX_BATCH_TOTAL_TOKENS=]
      --max-waiting-tokens <MAX_WAITING_TOKENS>
          [env: MAX_WAITING_TOKENS=] [default: 20]
      --max-batch-size <MAX_BATCH_SIZE>
          [env: MAX_BATCH_SIZE=]
      --hostname <HOSTNAME>
          [env: HOSTNAME=] [default: 0.0.0.0]
  -p, --port <PORT>
          [env: PORT=] [default: 3000]
      --master-shard-uds-path <MASTER_SHARD_UDS_PATH>
          [env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0]
      --tokenizer-name <TOKENIZER_NAME>
          [env: TOKENIZER_NAME=] [default: bigscience/bloom]
      --tokenizer-config-path <TOKENIZER_CONFIG_PATH>
          [env: TOKENIZER_CONFIG_PATH=]
      --revision <REVISION>
          [env: REVISION=]
      --validation-workers <VALIDATION_WORKERS>
          [env: VALIDATION_WORKERS=] [default: 2]
      --json-output
          [env: JSON_OUTPUT=]
      --otlp-endpoint <OTLP_ENDPOINT>
          [env: OTLP_ENDPOINT=]
      --cors-allow-origin <CORS_ALLOW_ORIGIN>
          [env: CORS_ALLOW_ORIGIN=]
      --ngrok
          [env: NGROK=]
      --ngrok-authtoken <NGROK_AUTHTOKEN>
          [env: NGROK_AUTHTOKEN=]
      --ngrok-edge <NGROK_EDGE>
          [env: NGROK_EDGE=]
      --messages-api-enabled
          [env: MESSAGES_API_ENABLED=]
      --disable-grammar-support
          [env: DISABLE_GRAMMAR_SUPPORT=]
      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
          [env: MAX_CLIENT_BATCH_SIZE=] [default: 4]
  -h, --help
          Print help
  -V, --version
          Print version
```

## The Model Server

The model server is a python server, capable of starting a server waiting for gRPC requests, loading a given model, performing sharding to provide [tensor parallelism](https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism), and staying alive while waiting for new requests.
The model server supports models instantiated using PyTorch and optimized for inference mainly on CUDA/ROCm.

### Model Server Variants

Several variants of the model server exist that are actively supported by Hugging Face:

- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).

Not all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations.

### Command Line Interface

The official command line interface (CLI) for the server supports three subcommands, `download-weights`, `quantize` and `serve`:

- `download-weights` will download weights from the hub and, in some variants, convert them to a format adapted to the given implementation;
- `quantize` will allow quantizing a model using the `gptq` package. This feature is not available or supported on all variants;
- `serve` will start the server that loads a model (or a model shard), receives gRPC calls from the router, performs the inference and provides a formatted response to the given request.

Serve's command line parameters on the TGI repository are:

```
Usage: cli.py serve [OPTIONS] MODEL_ID

╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮
│ * model_id      TEXT  [default: None] [required]                                                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮
│ --revision             TEXT                                              [default: None]                  │
│ --sharded              --no-sharded                                      [default: no-sharded]            │
│ --quantize             [bitsandbytes|bitsandbytes-nf4|bitsandbytes-fp4|gptq|awq|eetq|exl2|fp8]  [default: None] │
│ --speculate            INTEGER                                           [default: None]                  │
│ --dtype                [float16|bfloat16]                                [default: None]                  │
│ --trust-remote-code    --no-trust-remote-code                            [default: no-trust-remote-code]  │
│ --uds-path             PATH                                              [default: /tmp/text-generation-serve…] │
│ --logger-level         TEXT                                              [default: INFO]                  │
│ --json-output          --no-json-output                                  [default: no-json-output]        │
│ --otlp-endpoint        TEXT                                              [default: None]                  │
│ --help                 Show this message and exit.                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```

Note that some variants might support different parameters, and they could possibly accept more options that can be passed using environment variables.

## Call Flow

Once both components are initialized, the weights downloaded and the model server up and running, the router and the model server exchange data and info through gRPC calls. There are currently two supported schemas, [v2](https://github.com/huggingface/text-generation-inference/blob/main/proto/generate.proto) and [v3](https://github.com/huggingface/text-generation-inference/blob/main/proto/v3/generate.proto). These two versions are almost identical, except for:

- input chunks support, for text and image data,
- paged attention support

Here's a diagram that displays the exchanges that follow the router and model server startup.

```mermaid
sequenceDiagram

    Router->>Model Server: service discovery
    Model Server-->>Router: urls for other shards

    Router->>Model Server: get model info
    Model Server-->>Router: shard info

    Router->>Model Server: health check
    Model Server-->>Router: health OK

    Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size)
    Model Server-->>Router: warmup result
```

After these are done, the router is ready to receive generate calls from multiple clients. Here's an example.

```mermaid
sequenceDiagram
    participant Client 1
    participant Client 2
    participant Client 3
    participant Router
    participant Model Server

    Client 1->>Router: generate_stream
    Router->>Model Server: prefill(batch1)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 1

    Router->>Model Server: decode(cached_batch1)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 2

    Router->>Model Server: decode(cached_batch1)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 3

    Client 2->>Router: generate_stream
    Router->>Model Server: prefill(batch2)
    Note right of Model Server: This stops previous batch, that is restarted
    Model Server-->>Router: generations, cached_batch2, timings
    Router-->>Client 2: token 1'

    Router->>Model Server: decode(cached_batch1, cached_batch2)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 4
    Router-->>Client 2: token 2'

    Note left of Client 1: Client 1 leaves
    Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2)
    Model Server-->>Router: filtered batch

    Router->>Model Server: decode(cached_batch2)
    Model Server-->>Router: generations, cached_batch2, timings
    Router-->>Client 2: token 3'

    Client 3->>Router: generate_stream
    Note right of Model Server: This stops previous batch, that is restarted
    Router->>Model Server: prefill(batch3)
    Note left of Client 1: Client 3 leaves without receiving any batch
    Router->>Model Server: clear_cache(batch3)
    Note right of Model Server: This stops previous batch, that is restarted

    Router->>Model Server: decode(cached_batch3)
    Note right of Model Server: Last token (stopping criteria)
    Model Server-->>Router: generations, cached_batch3, timings
    Router-->>Client 2: token 4'
```
@@ -19,6 +19,6 @@ docker run --gpus all \
     --shm-size 1g \
     -e HUGGING_FACE_HUB_TOKEN=$token \
     -p 8080:80 \
-    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \
+    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
     --model-id $model
 ```
@@ -62,7 +62,9 @@ Options:
       Possible values:
       - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
       - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
+      - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
       - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
+      - marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
      - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
       - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
       - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
@@ -141,7 +143,7 @@ Options:
 ## MAX_TOP_N_TOKENS
 ```shell
       --max-top-n-tokens <MAX_TOP_N_TOKENS>
-          This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking
+          This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like classification or ranking

           [env: MAX_TOP_N_TOKENS=]
           [default: 5]
docs/source/basic_tutorials/train_medusa.md (new file, 208 lines)
@@ -0,0 +1,208 @@
# Train Medusa

This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation) for more information on how Medusa works and speculation in general.

## What are the benefits of training a Medusa model?

Training Medusa heads can greatly improve the speed of generation. Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training.

One of the most important things is to have a good dataset (with data similar to what will be used in production) because Medusa has a much higher hit-rate when the generation is in-domain.

If you train Medusa on a dataset that is very different from the one you will use in production, the model will not be able to predict the future tokens accurately and consequently the speedup will be minimal or non-existent.

## Self-distillation (Generating data for training)

There are many methods for preparing data for training, but one of the easiest and most effective ways is to "self-distill" the data. This means that you can use the same model to generate the data that you will use to train the model.

Essentially, you prompt the model with a similar input to what you will use in production and the model will generate the output.

We'll use this output to help train the medusa heads to predict the `n+1`, `n+2`, `n+3`, etc. tokens in the sequence.

## Training

The original implementation of Medusa is available at [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa) and we'll follow a very similar process to train the model as described on the original repository.

### Getting Started

There are two methods for training the model:

- `torchrun`, which is a wrapper around `torch.distributed.launch`
- a forked version of `axolotl` that supports Medusa

In this tutorial we'll use `torchrun` to train the model as it is the most straightforward way, but similar steps can be followed to train the model using `axolotl` if you prefer.

### Training with `torchrun`

```bash
mkdir medusa-training
cd medusa-training

pyenv install 3.10
pyenv local 3.10

uv venv -p 3.10
source .venv/bin/activate
```

Now let's clone the original `Medusa` repository and install the library.

```bash
git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .
```

Next we'll need some data to train on; we can use the `ShareGPT_Vicuna_unfiltered` dataset that is available on the Hugging Face Hub.

```bash
apt install git-lfs
git lfs install
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
```

Currently our directory structure looks like this:

```bash
.
├── assets
├── CITATION.cff
├── create_data.py
├── data_generation
├── deepspeed.json
├── last_run_prepared
├── LICENSE
├── llm_judge
├── medusa
├── medusa_llm.egg-info
├── mistral.json
├── notebooks
├── pyproject.toml
├── README.md
├── ROADMAP.md
├── scripts
├── ShareGPT_Vicuna_unfiltered
│   ├── README.md
│   ├── ShareGPT_2023.05.04v0_Wasteland_Edition.json
│   └── ShareGPT_V4.3_unfiltered_cleaned_split.json
├── simple_gradio_interface.py
├── tiny-llama.json
└── vicuna_7b_qlora_stage1
```

## Start Training

Now let's generate the data and start training the model. This process will take a while since we are generating data from the model.

First make sure you have an instance of TGI running with the model you want to use for self-distillation.

```bash
model=HuggingFaceH4/zephyr-7b-beta
volume=/home/ubuntu/.cache/huggingface/hub/

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model
```

Now we can generate the data using the `create_data.py` script.

```bash
python create_data.py \
    --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
    --output-filename zephyr_self_distill.json
```

At this point our terminal should look like this:

<div class="flex justify-center">
    <img
        src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-large.gif"
        width="550"
    />
</div>

> Note: In the screenshot above we are only using the first 500 examples from the dataset to speed up the process; you should have a much larger dataset for training.

Now we can finally get to the fun part and start training the model!

Using `torchrun` we can easily launch the `medusa` training script with the `zephyr_self_distill.json` configuration file.

> NOTE: If you just self-distilled you may still have the model running; make sure to stop it before starting the training in order to allow all of the resources to be used for training.

```bash
WANDB_MODE=offline torchrun --nproc_per_node=4 medusa/train/train_legacy.py \
    --model_name_or_path HuggingFaceH4/zephyr-7b-beta \
    --data_path zephyr_self_distill.json \
    --bf16 True \
    --output_dir zephyr_out \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --learning_rate 1e-3 \
    --weight_decay 0.0 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --lazy_preprocess True \
    --medusa_num_heads 3 \
    --medusa_num_layers 1 \
    --deepspeed deepspeed.json
```

<div class="flex justify-center">
    <img
        src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-heads-large.gif"
        width="550"
    />
</div>

If successful, you should see output similar to the one below:

```bash
wandb: Run history:
wandb: train/epoch ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb: train/global_step ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb: train/learning_rate ▅███▇▇▆▅▅▄▃▂▂▁▁▁
wandb: train/loss ██▆▄▄▃▃▂▂▃▁▁▂▁▁▁
wandb: train/medusa0_loss ▆▆▇▆▆▅▄▅▃▃▃▃▂▂▂▂▂▃▂▂▂▁▁▁▂▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb: train/medusa0_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▄▄▄▃▄▃▄▄▅▅▆▅▆▆▇▅▇▇▄▇█▇▅▇█▆▇▇
wandb: train/medusa1_loss ▇▇█▇▇▆▅▅▃▄▃▃▃▃▃▃▃▃▃▃▂▁▂▂▂▁▁▂▁▁▇▁▁▁▂▁▁▁▁▁
wandb: train/medusa1_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▃▄▄▃▃▂▃▃▅▅▆▄█▆▇▅▇▇▅█▇▇▅▇█▆▆▇
wandb: train/medusa2_loss ▃▃▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb: train/medusa2_top1 ▁▁▁▂▁▁▁▁▂▂▃▃▃▄▄▃▃▂▃▃▅▆▅▄█▆▆▅▆▆▄█▇▇▄▇█▆▆▇
wandb: train/total_flos ▁
wandb: train/train_loss ▁
wandb: train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb: train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb: train/epoch 2.0
wandb: train/global_step 16
wandb: train/learning_rate 0.0
wandb: train/loss 14.8906
wandb: train/medusa0_loss 4.25
wandb: train/medusa0_top1 0.28809
wandb: train/medusa1_loss 4.8125
wandb: train/medusa1_top1 0.22727
wandb: train/medusa2_loss 5.5
wandb: train/medusa2_top1 0.17293
wandb: train/total_flos 0.0
wandb: train/train_loss 23.98242
wandb: train/train_runtime 396.9266
wandb: train/train_samples_per_second 2.519
wandb: train/train_steps_per_second 0.04
```

Finally, and most importantly, don't forget to push this model to the Hugging Face Hub so you can use it in your projects.

```bash
python -m medusa.hf_utils \
    --folder zephyr_out_medusa_mlp_zephyr-7b-beta_medusa_3_lr_0.001_layers_1 \
    --repo drbh/zephyr_medusa_demo
```

Woo, we've successfully trained a Medusa model and pushed it to the Hugging Face Hub! 🎉
@@ -2,11 +2,11 @@

 ## What is Guidance?

-Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format.
+Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure, uses a specific set of words, or produces output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.
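To make the JSON case concrete, here is a hedged sketch of a grammar-constrained request against a running TGI instance; the endpoint URL, schema, and prompt are illustrative assumptions, not part of this change:

```python
import requests

# Hypothetical JSON Schema the output must conform to.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

resp = requests.post(
    "http://127.0.0.1:8080/generate",  # assumed local TGI endpoint
    json={
        "inputs": "Extract the person described: David is 25 years old.",
        "parameters": {
            "max_new_tokens": 64,
            # Guided decoding: force the output to be valid JSON matching the schema.
            "grammar": {"type": "json", "value": schema},
        },
    },
)
print(resp.json()["generated_text"])
```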
 ## How is it used?

-Guidance can be in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance:
+Guidance can be implemented in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance:

 Technically, guidance can be used to generate:
@ -27,7 +27,7 @@ You can check a few existing fine-tunes for popular models:

- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)

In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa)
In order to create your own medusa heads for your own finetune, you should check out the original medusa repo. [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md)

In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically.
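A hedged sketch of what "simply point to a medusa enabled model" looks like in practice: assuming a TGI server has been launched with `--model-id text-generation-inference/Mistral-7B-Instruct-v0.2-medusa` (using the same docker invocation shown elsewhere in these docs) and is reachable on `localhost:8080`, querying it is no different from querying any other model:

```python
from huggingface_hub import InferenceClient

# The endpoint is assumed to serve a medusa-enabled checkpoint; speculative
# decoding with the medusa heads happens entirely on the server side.
client = InferenceClient("http://localhost:8080")
output = client.text_generation("What is Deep Learning?", max_new_tokens=64)
print(output)
```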
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading

docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    --ipc=host --shm-size 256g --net host -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.0.3-rocm \
    ghcr.io/huggingface/text-generation-inference:2.0.4-rocm \
    --model-id $model
```
@ -27,7 +27,7 @@ TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you wo

## Flash attention implementation

Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/flash_attn_triton.py).
Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).

By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B

volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.0.3 \
    ghcr.io/huggingface/text-generation-inference:2.0.4 \
    --model-id $model
```
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B

volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.0.3 \
    ghcr.io/huggingface/text-generation-inference:2.0.4 \
    --model-id $model
```
@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \

To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

```bash
docker run ghcr.io/huggingface/text-generation-inference:2.0.3 --help
docker run ghcr.io/huggingface/text-generation-inference:2.0.4 --help
```

</Tip>
@ -20,7 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware

- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
- [Qwen 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
- [Opt](https://huggingface.co/facebook/opt-6.7b)
- [T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
@ -7,9 +7,10 @@ import os

import docker
import json
import math
import shutil
import tempfile
import time
import random
import re

from docker.errors import NotFound
from typing import Optional, List, Dict
@ -37,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")


class ResponseComparator(JSONSnapshotExtension):
    rtol = 0.2
    ignore_logprob = False

    def serialize(
        self,
@ -94,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension):
        return (
            token.id == other.id
            and token.text == other.text
            and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
            and (
                self.ignore_logprob
                or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
            )
            and token.special == other.special
        )
@ -104,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension):
            prefill_token.id == other.id
            and prefill_token.text == other.text
            and (
                math.isclose(
                    prefill_token.logprob, other.logprob, rel_tol=self.rtol
                self.ignore_logprob
                or math.isclose(
                    prefill_token.logprob,
                    other.logprob,
                    rel_tol=self.rtol,
                )
                if prefill_token.logprob is not None
                else prefill_token.logprob == other.logprob
@ -222,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator):
    rtol = 0.75


class IgnoreLogProbResponseComparator(ResponseComparator):
    ignore_logprob = True


class LauncherHandle:
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}")
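The intent of the `ignore_logprob` switch introduced above is easier to see in isolation. Below is a minimal, self-contained sketch of the same comparison rule, not the test suite's actual classes; `Token` here is a stand-in dataclass:

```python
import math
from dataclasses import dataclass


@dataclass
class Token:
    id: int
    text: str
    logprob: float
    special: bool


def tokens_match(a: Token, b: Token, rtol: float = 0.2, ignore_logprob: bool = False) -> bool:
    # Same rule as the comparator: ids, texts and the special flag must match
    # exactly; logprobs only need to be close, and are skipped entirely when
    # ignore_logprob is set (useful when numerics differ across kernels).
    return (
        a.id == b.id
        and a.text == b.text
        and (ignore_logprob or math.isclose(a.logprob, b.logprob, rel_tol=rtol))
        and a.special == b.special
    )


assert tokens_match(Token(604, " for", -2.42, False), Token(604, " for", -2.43, False))
assert tokens_match(Token(604, " for", -2.42, False), Token(604, " for", -9.99, False), ignore_logprob=True)
```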
@ -273,6 +285,11 @@ def generous_response_snapshot(snapshot):
    return snapshot.use_extension(GenerousResponseComparator)


@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
    return snapshot.use_extension(IgnoreLogProbResponseComparator)


@pytest.fixture(scope="module")
def event_loop():
    loop = asyncio.get_event_loop()
@ -347,19 +364,22 @@ def launcher(event_loop):
        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"

        with tempfile.TemporaryFile("w+") as tmp:
            # We'll output stdout/stderr to a temporary file. Using a pipe
            # cause the process to block until stdout is read.
            with subprocess.Popen(
                args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
                args,
                stdout=tmp,
                stderr=subprocess.STDOUT,
                env=env,
            ) as process:
                yield ProcessLauncherHandle(process, port)

                process.terminate()
                process.wait(60)

                launcher_output = process.stdout.read().decode("utf-8")
                tmp.seek(0)
                print(launcher_output, file=sys.stderr)
                shutil.copyfileobj(tmp, sys.stderr)

                process.stdout.close()
                process.stderr.close()

        if not use_flash_attention:
            del env["USE_FLASH_ATTENTION"]
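The change above swaps `subprocess.PIPE` for a temporary file so the launcher process can write output freely without the test runner having to drain the pipe. A minimal sketch of the same pattern outside the test harness, with the echoed command as a placeholder:

```python
import shutil
import subprocess
import sys
import tempfile

# Redirect the child's stdout/stderr into a temporary file instead of a pipe:
# a pipe has a bounded buffer, so a chatty child can block until someone reads
# from it, while a file lets the child run freely and be replayed afterwards.
with tempfile.TemporaryFile("w+") as tmp:
    with subprocess.Popen(
        ["python", "-c", "print('hello from the child process')"],
        stdout=tmp,
        stderr=subprocess.STDOUT,
    ) as process:
        process.wait(60)

    # Replay everything the child wrote once it has exited.
    tmp.seek(0)
    shutil.copyfileobj(tmp, sys.stderr)
```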
@ -5,7 +5,7 @@
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
        "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas",
        "name": null,
        "role": "assistant",
        "tool_calls": null
@ -13,14 +13,14 @@
      "usage": null
    }
  ],
  "created": 1712874856,
  "created": 1716553098,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "2.0.1-native",
  "system_fingerprint": "2.0.5-dev0-native",
  "usage": {
    "completion_tokens": 100,
    "prompt_tokens": 60,
    "prompt_tokens": 62,
    "total_tokens": 160
    "total_tokens": 162
  }
}
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.640625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4296875,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8632812,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1328125,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76660156,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3837891,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9746094,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4189453,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.34375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8852539,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.65625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.3671875,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -0.36938477,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.8046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235274,
|
||||||
|
"logprob": -0.46240234,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": -1.7460938,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235265,
|
||||||
|
"logprob": -1.9443359,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": -1.4550781,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235308,
|
||||||
|
"logprob": -1.0205078,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235290,
|
||||||
|
"logprob": -1.0283203,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235274,
|
||||||
|
"logprob": -1.2783203,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235284,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request for 12.25-12"
|
||||||
|
}
|
@ -0,0 +1,358 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.359375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4277344,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4394531,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8613281,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1523438,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76220703,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3642578,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -2.0175781,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4238281,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.328125,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8881836,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4238281,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.859375,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.7631836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3642578,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9960938,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4179688,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.640625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.3671875,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4257812,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1367188,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.76171875,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3515625,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9873047,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4169922,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3320312,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.8930664,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<bos>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2015,
|
||||||
|
"logprob": -9.6484375,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3853,
|
||||||
|
"logprob": -10.359375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -2.4179688,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -2.4492188,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2412,
|
||||||
|
"logprob": -2.8574219,
|
||||||
|
"special": false,
|
||||||
|
"text": " following"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235292,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 109,
|
||||||
|
"logprob": -0.7519531,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235287,
|
||||||
|
"logprob": -1.3623047,
|
||||||
|
"special": false,
|
||||||
|
"text": "*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 235248,
|
||||||
|
"logprob": -1.9707031,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 199,
|
||||||
|
"logprob": -1.4267578,
|
||||||
|
"special": false,
|
||||||
|
"text": "<strong>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1232,
|
||||||
|
"logprob": -4.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 208,
|
||||||
|
"logprob": -0.88427734,
|
||||||
|
"special": false,
|
||||||
|
"text": "</strong>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " for the following:\n\n* <strong>Name</strong>"
|
||||||
|
}
|
||||||
|
]
|
@ -0,0 +1,84 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.4375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 25,
|
||||||
|
"logprob": -2.9316406,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -3.5136719,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -0.7783203,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -1.2314453,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -2.0019531,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2990,
|
||||||
|
"logprob": -1.5009766,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\\"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 77,
|
||||||
|
"logprob": -0.057434082,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 702,
|
||||||
|
"logprob": -1.4912109,
|
||||||
|
"special": false,
|
||||||
|
"text": "\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.2636719,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 557,
|
||||||
|
"logprob": -2.4042969,
|
||||||
|
"special": false,
|
||||||
|
"text": " }\n\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||||
|
}
|
@ -0,0 +1,84 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.453125,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -1.9980469,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 578,
|
||||||
|
"logprob": -0.15795898,
|
||||||
|
"special": false,
|
||||||
|
"text": " The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3622,
|
||||||
|
"logprob": -1.0458984,
|
||||||
|
"special": false,
|
||||||
|
"text": " server"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 31680,
|
||||||
|
"logprob": -1.3623047,
|
||||||
|
"special": false,
|
||||||
|
"text": " responds"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 449,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " with"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -0.5678711,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1049,
|
||||||
|
"logprob": -0.12322998,
|
||||||
|
"special": false,
|
||||||
|
"text": "200"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10619,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " OK"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\""
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request. The server responds with a \"200 OK\""
|
||||||
|
}
|
@ -0,0 +1,338 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.453125,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 25,
|
||||||
|
"logprob": -2.9785156,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -3.4941406,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -0.79345703,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -1.2324219,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -1.9794922,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2990,
|
||||||
|
"logprob": -1.4892578,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\\"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 77,
|
||||||
|
"logprob": -0.058258057,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 702,
|
||||||
|
"logprob": -1.4892578,
|
||||||
|
"special": false,
|
||||||
|
"text": "\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.2783203,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 557,
|
||||||
|
"logprob": -2.3945312,
|
||||||
|
"special": false,
|
||||||
|
"text": " }\n\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.40625,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 25,
|
||||||
|
"logprob": -2.9433594,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -3.4726562,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -0.8022461,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -1.2509766,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -1.984375,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2990,
|
||||||
|
"logprob": -1.4677734,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\\"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 77,
|
||||||
|
"logprob": -0.059173584,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 702,
|
||||||
|
"logprob": -1.4990234,
|
||||||
|
"special": false,
|
||||||
|
"text": "\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.2822266,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 557,
|
||||||
|
"logprob": -2.3867188,
|
||||||
|
"special": false,
|
||||||
|
"text": " }\n\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.421875,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 25,
|
||||||
|
"logprob": -2.9511719,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -3.46875,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -0.77490234,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -1.2558594,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -1.984375,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2990,
|
||||||
|
"logprob": -1.4990234,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\\"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 77,
|
||||||
|
"logprob": -0.059143066,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 702,
|
||||||
|
"logprob": -1.4941406,
|
||||||
|
"special": false,
|
||||||
|
"text": "\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.2578125,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 557,
|
||||||
|
"logprob": -2.3964844,
|
||||||
|
"special": false,
|
||||||
|
"text": " }\n\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.4140625,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 25,
|
||||||
|
"logprob": -2.9101562,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -3.5039062,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -0.8076172,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -1.2236328,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 489,
|
||||||
|
"logprob": -1.9853516,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2990,
|
||||||
|
"logprob": -1.4892578,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\\"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 77,
|
||||||
|
"logprob": -0.056671143,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 702,
|
||||||
|
"logprob": -1.5107422,
|
||||||
|
"special": false,
|
||||||
|
"text": "\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.2597656,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 557,
|
||||||
|
"logprob": -2.4042969,
|
||||||
|
"special": false,
|
||||||
|
"text": " }\n\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
|
||||||
|
}
|
||||||
|
]
|
@ -0,0 +1,84 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 198,
|
||||||
|
"logprob": -2.5742188,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.6230469,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3270,
|
||||||
|
"logprob": -2.046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\"\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -0.015281677,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 422,
|
||||||
|
"logprob": -2.1425781,
|
||||||
|
"special": false,
|
||||||
|
"text": " if"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -0.9238281,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13204,
|
||||||
|
"logprob": -0.076660156,
|
||||||
|
"special": false,
|
||||||
|
"text": ".method"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 624,
|
||||||
|
"logprob": -0.021987915,
|
||||||
|
"special": false,
|
||||||
|
"text": " =="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 364,
|
||||||
|
"logprob": -0.39208984,
|
||||||
|
"special": false,
|
||||||
|
"text": " '"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3019,
|
||||||
|
"logprob": -0.10821533,
|
||||||
|
"special": false,
|
||||||
|
"text": "POST"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||||
|
}
|
@ -0,0 +1,84 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.2539062,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 578,
|
||||||
|
"logprob": -0.15563965,
|
||||||
|
"special": false,
|
||||||
|
"text": " The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3622,
|
||||||
|
"logprob": -0.8203125,
|
||||||
|
"special": false,
|
||||||
|
"text": " server"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 706,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " has"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 539,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " not"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3686,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " yet"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3288,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " sent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 904,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " any"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 828,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 382,
|
||||||
|
"logprob": -1.5517578,
|
||||||
|
"special": false,
|
||||||
|
"text": ".\n\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request. The server has not yet sent any data.\n\n"
|
||||||
|
}
|
@ -0,0 +1,338 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 198,
|
||||||
|
"logprob": -2.5742188,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.6220703,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3270,
|
||||||
|
"logprob": -2.0410156,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\"\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -0.015281677,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 422,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": " if"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -0.92333984,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13204,
|
||||||
|
"logprob": -0.07672119,
|
||||||
|
"special": false,
|
||||||
|
"text": ".method"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 624,
|
||||||
|
"logprob": -0.021987915,
|
||||||
|
"special": false,
|
||||||
|
"text": " =="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 364,
|
||||||
|
"logprob": -0.39208984,
|
||||||
|
"special": false,
|
||||||
|
"text": " '"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3019,
|
||||||
|
"logprob": -0.10638428,
|
||||||
|
"special": false,
|
||||||
|
"text": "POST"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 198,
|
||||||
|
"logprob": -2.5742188,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.6220703,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3270,
|
||||||
|
"logprob": -2.0410156,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\"\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -0.015281677,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 422,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": " if"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -0.92333984,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13204,
|
||||||
|
"logprob": -0.07672119,
|
||||||
|
"special": false,
|
||||||
|
"text": ".method"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 624,
|
||||||
|
"logprob": -0.021987915,
|
||||||
|
"special": false,
|
||||||
|
"text": " =="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 364,
|
||||||
|
"logprob": -0.39208984,
|
||||||
|
"special": false,
|
||||||
|
"text": " '"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3019,
|
||||||
|
"logprob": -0.10638428,
|
||||||
|
"special": false,
|
||||||
|
"text": "POST"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 198,
|
||||||
|
"logprob": -2.5742188,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.6220703,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3270,
|
||||||
|
"logprob": -2.0410156,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\"\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -0.015281677,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 422,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": " if"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -0.92333984,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13204,
|
||||||
|
"logprob": -0.07672119,
|
||||||
|
"special": false,
|
||||||
|
"text": ".method"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 624,
|
||||||
|
"logprob": -0.021987915,
|
||||||
|
"special": false,
|
||||||
|
"text": " =="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 364,
|
||||||
|
"logprob": -0.39208984,
|
||||||
|
"special": false,
|
||||||
|
"text": " '"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3019,
|
||||||
|
"logprob": -0.10638428,
|
||||||
|
"special": false,
|
||||||
|
"text": "POST"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2323,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -11.34375,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 198,
|
||||||
|
"logprob": -2.5742188,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -1.6220703,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3270,
|
||||||
|
"logprob": -2.0410156,
|
||||||
|
"special": false,
|
||||||
|
"text": " \"\"\"\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 262,
|
||||||
|
"logprob": -0.015281677,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 422,
|
||||||
|
"logprob": -2.1445312,
|
||||||
|
"special": false,
|
||||||
|
"text": " if"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1715,
|
||||||
|
"logprob": -0.92333984,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13204,
|
||||||
|
"logprob": -0.07672119,
|
||||||
|
"special": false,
|
||||||
|
"text": ".method"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 624,
|
||||||
|
"logprob": -0.021987915,
|
||||||
|
"special": false,
|
||||||
|
"text": " =="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 364,
|
||||||
|
"logprob": -0.39208984,
|
||||||
|
"special": false,
|
||||||
|
"text": " '"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3019,
|
||||||
|
"logprob": -0.10638428,
|
||||||
|
"special": false,
|
||||||
|
"text": "POST"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||||
|
}
|
||||||
|
]
|
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4321,
|
||||||
|
"logprob": -12.390625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -11.0625,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.0507812,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.3007812,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29902,
|
||||||
|
"logprob": -2.0449219,
|
||||||
|
"special": false,
|
||||||
|
"text": "I"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 505,
|
||||||
|
"logprob": -1.3242188,
|
||||||
|
"special": false,
|
||||||
|
"text": " have"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 263,
|
||||||
|
"logprob": -0.2076416,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1243,
|
||||||
|
"logprob": -2.0273438,
|
||||||
|
"special": false,
|
||||||
|
"text": " test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -0.6845703,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 515,
|
||||||
|
"logprob": -1.1748047,
|
||||||
|
"special": false,
|
||||||
|
"text": " from"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 263,
|
||||||
|
"logprob": -1.0644531,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1404,
|
||||||
|
"logprob": -1.5224609,
|
||||||
|
"special": false,
|
||||||
|
"text": " user"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nI have a test request from a user"
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4321,
|
||||||
|
"logprob": -12.390625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -11.0625,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 5229,
|
||||||
|
"logprob": -1.2607422,
|
||||||
|
"special": false,
|
||||||
|
"text": " failed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6527,
|
||||||
|
"logprob": -0.11450195,
|
||||||
|
"special": false,
|
||||||
|
"text": " Could"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 451,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " not"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4511,
|
||||||
|
"logprob": -0.2286377,
|
||||||
|
"special": false,
|
||||||
|
"text": " connect"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 304,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1923,
|
||||||
|
"logprob": -1.2568359,
|
||||||
|
"special": false,
|
||||||
|
"text": " server"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.15905762,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29902,
|
||||||
|
"logprob": -0.21618652,
|
||||||
|
"special": false,
|
||||||
|
"text": "I"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request failed: Could not connect to server\n\nI"
|
||||||
|
}
|
@ -0,0 +1,358 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4321,
|
||||||
|
"logprob": -12.390625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -11.0625,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.0507812,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.3007812,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29902,
|
      "logprob": -2.0449219, "special": false, "text": "I"},
      {"id": 505, "logprob": -1.3242188, "special": false, "text": " have"},
      {"id": 263, "logprob": -0.2076416, "special": false, "text": " a"},
      {"id": 1243, "logprob": -2.0273438, "special": false, "text": " test"},
      {"id": 2009, "logprob": -0.6845703, "special": false, "text": " request"},
      {"id": 515, "logprob": -1.1748047, "special": false, "text": " from"},
      {"id": 263, "logprob": -1.0595703, "special": false, "text": " a"},
      {"id": 1404, "logprob": -1.5224609, "special": false, "text": " user"}
    ],
    "top_tokens": null
  },
  "generated_text": "\n\nI have a test request from a user"
},
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 1, "logprob": null, "text": "<s>"},
      {"id": 4321, "logprob": -12.390625, "text": "Test"},
      {"id": 2009, "logprob": -11.0625, "text": "request"}
    ],
    "seed": null,
    "tokens": [
      {"id": 13, "logprob": -2.0507812, "special": false, "text": "\n"},
      {"id": 13, "logprob": -2.3007812, "special": false, "text": "\n"},
      {"id": 29902, "logprob": -2.0449219, "special": false, "text": "I"},
      {"id": 505, "logprob": -1.3242188, "special": false, "text": " have"},
      {"id": 263, "logprob": -0.2076416, "special": false, "text": " a"},
      {"id": 1243, "logprob": -2.0273438, "special": false, "text": " test"},
      {"id": 2009, "logprob": -0.6845703, "special": false, "text": " request"},
      {"id": 515, "logprob": -1.1748047, "special": false, "text": " from"},
      {"id": 263, "logprob": -1.0595703, "special": false, "text": " a"},
      {"id": 1404, "logprob": -1.5224609, "special": false, "text": " user"}
    ],
    "top_tokens": null
  },
  "generated_text": "\n\nI have a test request from a user"
},
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 1, "logprob": null, "text": "<s>"},
      {"id": 4321, "logprob": -12.390625, "text": "Test"},
      {"id": 2009, "logprob": -11.0625, "text": "request"}
    ],
    "seed": null,
    "tokens": [
      {"id": 13, "logprob": -2.0507812, "special": false, "text": "\n"},
      {"id": 13, "logprob": -2.3007812, "special": false, "text": "\n"},
      {"id": 29902, "logprob": -2.0449219, "special": false, "text": "I"},
      {"id": 505, "logprob": -1.3242188, "special": false, "text": " have"},
      {"id": 263, "logprob": -0.2076416, "special": false, "text": " a"},
      {"id": 1243, "logprob": -2.0273438, "special": false, "text": " test"},
      {"id": 2009, "logprob": -0.6845703, "special": false, "text": " request"},
      {"id": 515, "logprob": -1.1748047, "special": false, "text": " from"},
      {"id": 263, "logprob": -1.0595703, "special": false, "text": " a"},
      {"id": 1404, "logprob": -1.5224609, "special": false, "text": " user"}
    ],
    "top_tokens": null
  },
  "generated_text": "\n\nI have a test request from a user"
}
]
@@ -0,0 +1,61 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 8,
    "prefill": [],
    "seed": null,
    "tokens": [
      {"id": 2502, "logprob": -1.734375, "special": false, "text": "image"},
      {"id": 2196, "logprob": -0.5756836, "special": false, "text": " result"},
      {"id": 604, "logprob": -0.007843018, "special": false, "text": " for"},
      {"id": 12254, "logprob": -1.7167969, "special": false, "text": " chicken"},
      {"id": 611, "logprob": -0.17053223, "special": false, "text": " on"},
      {"id": 573, "logprob": -0.7626953, "special": false, "text": " the"},
      {"id": 8318, "logprob": -0.02709961, "special": false, "text": " beach"},
      {"id": 1, "logprob": -0.20739746, "special": true, "text": "<eos>"}
    ],
    "top_tokens": null
  },
  "generated_text": "image result for chicken on the beach"
}
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
        "role": "assistant"
      }
    }
  ],
  "created": 1718044128,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "2.0.5-dev0-native",
  "usage": {
    "completion_tokens": 39,
    "prompt_tokens": 136,
    "total_tokens": 175
  }
}
@@ -0,0 +1,85 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 12,
    "prefill": [],
    "seed": null,
    "tokens": [
      {"id": 450, "logprob": -0.26342773, "special": false, "text": " The"},
      {"id": 21282, "logprob": -0.01838684, "special": false, "text": " cow"},
      {"id": 322, "logprob": -0.18041992, "special": false, "text": " and"},
      {"id": 521, "logprob": -0.62841797, "special": false, "text": " ch"},
      {"id": 21475, "logprob": -0.0037956238, "special": false, "text": "icken"},
      {"id": 526, "logprob": -0.018737793, "special": false, "text": " are"},
      {"id": 373, "logprob": -1.0820312, "special": false, "text": " on"},
      {"id": 263, "logprob": -0.5083008, "special": false, "text": " a"},
      {"id": 25695, "logprob": -0.07128906, "special": false, "text": " beach"},
      {"id": 29889, "logprob": -0.12573242, "special": false, "text": "."},
      {"id": 32002, "logprob": -0.0029792786, "special": true, "text": "<end_of_utterance>"},
      {"id": 2, "logprob": -0.00024962425, "special": true, "text": "</s>"}
    ],
    "top_tokens": null
  },
  "generated_text": " The cow and chicken are on a beach."
}
@@ -0,0 +1,133 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 20,
    "prefill": [],
    "seed": null,
    "tokens": [
      {"id": 415, "logprob": -0.04421997, "special": false, "text": " The"},
      {"id": 12072, "logprob": -0.13500977, "special": false, "text": " cow"},
      {"id": 349, "logprob": -0.06750488, "special": false, "text": " is"},
      {"id": 6328, "logprob": -0.6352539, "special": false, "text": " standing"},
      {"id": 356, "logprob": -0.16186523, "special": false, "text": " on"},
      {"id": 272, "logprob": -0.5078125, "special": false, "text": " the"},
      {"id": 10305, "logprob": -0.017913818, "special": false, "text": " beach"},
      {"id": 304, "logprob": -1.5205078, "special": false, "text": " and"},
      {"id": 272, "logprob": -0.029174805, "special": false, "text": " the"},
      {"id": 13088, "logprob": -0.003479004, "special": false, "text": " chicken"},
      {"id": 349, "logprob": -0.0035095215, "special": false, "text": " is"},
      {"id": 6398, "logprob": -0.3088379, "special": false, "text": " sitting"},
      {"id": 356, "logprob": -0.027755737, "special": false, "text": " on"},
      {"id": 264, "logprob": -0.31884766, "special": false, "text": " a"},
      {"id": 17972, "logprob": -0.047943115, "special": false, "text": " pile"},
      {"id": 302, "logprob": -0.0002925396, "special": false, "text": " of"},
      {"id": 2445, "logprob": -0.02935791, "special": false, "text": " money"},
      {"id": 28723, "logprob": -0.031219482, "special": false, "text": "."},
      {"id": 32002, "logprob": -0.00034475327, "special": true, "text": "<end_of_utterance>"},
      {"id": 2, "logprob": -1.1920929e-07, "special": true, "text": "</s>"}
    ],
    "top_tokens": null
  },
  "generated_text": " The cow is standing on the beach and the chicken is sitting on a pile of money."
}
@@ -35,8 +35,9 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
         ],
     )

+    print(repr(response.choices[0].message.content))
     assert (
         response.choices[0].message.content
-        == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
+        == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas"
     )
     assert response == response_snapshot
@@ -3,7 +3,7 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_gemma_handle(launcher):
-    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
+    with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle


@@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
integration-tests/models/test_flash_gemma_gptq.py (new file, 64 lines added)
@@ -0,0 +1,64 @@
import pytest


@pytest.fixture(scope="module")
def flash_gemma_gptq_handle(launcher):
    with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma_gptq(flash_gemma_gptq_handle):
    await flash_gemma_gptq_handle.health(300)
    return flash_gemma_gptq_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
    response = await flash_gemma_gptq.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == ignore_logprob_response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(
    flash_gemma_gptq, ignore_logprob_response_snapshot
):
    response = await flash_gemma_gptq.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == ignore_logprob_response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(
    flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
):
    responses = await generate_load(
        flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == ignore_logprob_response_snapshot
integration-tests/models/test_flash_llama_exl2.py (new file, 73 lines added)
@@ -0,0 +1,73 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_exl2_handle(launcher):
    with launcher(
        "turboderp/Llama-3-8B-Instruct-exl2",
        revision="2.5bpw",
        # Set max input length to avoid OOM due to extremely large
        # scratch buffer.
        max_input_length=1024,
        num_shard=1,
        quantize="exl2",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_exl2(flash_llama_exl2_handle):
    await flash_llama_exl2_handle.health(300)
    return flash_llama_exl2_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
    response = await flash_llama_exl2.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == ignore_logprob_response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
    flash_llama_exl2, ignore_logprob_response_snapshot
):
    response = await flash_llama_exl2.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert (
        response.generated_text == 'Test request. The server responds with a "200 OK"'
    )
    assert response == ignore_logprob_response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(
    flash_llama_exl2, generate_load, ignore_logprob_response_snapshot
):
    responses = await generate_load(
        flash_llama_exl2, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == ignore_logprob_response_snapshot
integration-tests/models/test_flash_llama_gptq_marlin.py (new file, 65 lines added)
@@ -0,0 +1,65 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_gptq_marlin_handle(launcher):
    with launcher(
        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
    await flash_llama_gptq_marlin_handle.health(300)
    return flash_llama_gptq_marlin_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
    response = await flash_llama_gptq_marlin.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
    flash_llama_gptq_marlin, response_snapshot
):
    response = await flash_llama_gptq_marlin.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(
    flash_llama_gptq_marlin, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
integration-tests/models/test_flash_llama_marlin.py (new file, 63 lines added)
@@ -0,0 +1,63 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_marlin_handle(launcher):
    with launcher(
        "neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin_handle):
    await flash_llama_marlin_handle.health(300)
    return flash_llama_marlin_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(
    flash_llama_marlin, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_marlin, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
@@ -22,6 +22,12 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
     return flash_pali_gemma_handle.client


+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 def get_cow_beach():
     with open("integration-tests/images/cow_beach.png", "rb") as image_file:
         encoded_string = base64.b64encode(image_file.read())
@@ -37,3 +43,20 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):

     assert response.generated_text == "beach"
     assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_pali_gemma.generate(
+        f"caption![]({chicken})![]({cow_beach})\n",
+        max_new_tokens=20,
+    )
+    # Is PaliGemma not able to handle two separate images? At least we
+    # get output showing that both images are used.
+    assert (
+        response.generated_text == "image result for chicken on the beach"
+    ), f"{repr(response.generated_text)}"
+    assert response == response_snapshot
integration-tests/models/test_grammar_response_format_llama.py (new file, 101 lines added)
@@ -0,0 +1,101 @@
import pytest
import requests
from pydantic import BaseModel
from typing import List


@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
        max_batch_prefill_tokens=3000,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
    await llama_grammar_handle.health(300)
    return llama_grammar_handle.client


@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):

    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    chat_completion = response.json()
    called = chat_completion["choices"][0]["message"]["content"]

    assert response.status_code == 200
    assert (
        called
        == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
    )
    assert chat_completion == response_snapshot


@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,
):
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "tools": [],
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    # 422 means the server was unable to process the request because it contains invalid data.
    assert response.status_code == 422
    assert response.json() == {
        "error": "Grammar and tools are mutually exclusive",
        "error_type": "grammar and tools",
    }
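For reference, the snapshot further above corresponds to the kind of client-side call this test exercises. A minimal sketch of that call, assuming a TGI instance serving its OpenAI-compatible route; the hostname and port here are placeholders, not something this diff specifies:

```python
import requests
from pydantic import BaseModel
from typing import List


class Weather(BaseModel):
    unit: str
    temperature: List[int]


# Hypothetical endpoint; point this at wherever the router is actually listening.
resp = requests.post(
    "http://localhost:3000/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [
            {"role": "system", "content": f"Answer in the following format: {Weather.schema()}"},
            {"role": "user", "content": "What's the weather like the next 3 days in San Francisco, CA?"},
        ],
        "seed": 42,
        "max_tokens": 500,
        # Constrain decoding to the Pydantic schema, as in the test above.
        "response_format": {"type": "json_object", "value": Weather.schema()},
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```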
@@ -23,6 +23,12 @@ def get_chicken():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 @pytest.mark.asyncio
 async def test_idefics(idefics, response_snapshot):
     chicken = get_chicken()
@@ -39,6 +45,21 @@ async def test_idefics(idefics, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_idefics_two_images(idefics, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await idefics.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text == " The cow and chicken are on a beach."
+    ), f"{repr(response.generated_text)}"
+    assert response == response_snapshot
+
+
 @pytest.mark.asyncio
 async def test_idefics_load(idefics, generate_load, response_snapshot):
     chicken = get_chicken()
@@ -9,6 +9,12 @@ def get_chicken():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 @pytest.fixture(scope="module")
 def flash_idefics2_next_handle(launcher):
     with launcher(
@@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
     assert response == response_snapshot


+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_idefics2_next.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text
+        == " The cow is standing on the beach and the chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 20
+    assert response == response_snapshot
+
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
@@ -17,14 +17,32 @@ use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
 use thiserror::Error;
-use tracing_subscriber::EnvFilter;
+use tracing_subscriber::{filter::LevelFilter, EnvFilter};

 mod env_runtime;

+#[derive(Deserialize)]
+struct RawConfig {
+    max_position_embeddings: Option<usize>,
+    n_positions: Option<usize>,
+    max_seq_len: Option<usize>,
+}
+
 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
-    max_seq_len: Option<usize>,
 }

+impl From<RawConfig> for Config {
+    fn from(other: RawConfig) -> Self {
+        let max_position_embeddings = other
+            .max_position_embeddings
+            .or(other.max_seq_len)
+            .or(other.n_positions);
+        Config {
+            max_position_embeddings,
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
@@ -37,11 +55,17 @@ enum Quantization {
     /// Should be a drop-in replacement to bitsandbytes with much better performance.
     /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
     Eetq,
+    /// Variable bit quantization. Requires a specific EXL2 quantized model:
+    /// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
+    /// not support tensor parallelism (num_shard > 1).
+    Exl2,
     /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
     /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
     /// triton kernel (wider support) when it's not.
     /// AWQ has faster kernels.
     Gptq,
+    /// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
+    Marlin,
     /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
     /// but it is known that the model will be much slower to run than the native f16.
     #[deprecated(
@@ -77,9 +101,15 @@ impl std::fmt::Display for Quantization {
             Quantization::BitsandbytesFP4 => {
                 write!(f, "bitsandbytes-fp4")
             }
+            Quantization::Exl2 => {
+                write!(f, "exl2")
+            }
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Marlin => {
+                write!(f, "marlin")
+            }
             Quantization::Awq => {
                 write!(f, "awq")
             }
@@ -228,7 +258,7 @@ struct Args {
     max_stop_sequences: usize,

     /// This is the maximum allowed value for clients to set `top_n_tokens`.
-    /// `top_n_tokens is used to return information about the the `n` most likely
+    /// `top_n_tokens` is used to return information about the the `n` most likely
     /// tokens at each generation step, instead of just the sampled token. This
     /// information can be used for downstream tasks like for classification or
     /// ranking.
@@ -470,7 +500,9 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
+    max_input_tokens: usize,
     otlp_endpoint: Option<String>,
+    log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
     _shutdown_sender: mpsc::Sender<()>,
@@ -493,7 +525,7 @@ fn shard_manager(
         "--uds-path".to_string(),
         uds_path,
         "--logger-level".to_string(),
-        "INFO".to_string(),
+        log_level.to_string().to_uppercase(),
         "--json-output".to_string(),
     ];

@@ -551,6 +583,10 @@ fn shard_manager(
         shard_args.push(otlp_endpoint);
     }

+    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
+    shard_args.push("--max-input-tokens".to_string());
+    shard_args.push(max_input_tokens.to_string());
+
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

@@ -781,13 +817,13 @@ struct PythonLogMessage {
 impl PythonLogMessage {
     fn trace(&self) {
         match self.record.level.name {
-            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
-            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
-            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
-            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
-            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
-            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
-            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
+            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
         }
     }
 }
@@ -1007,6 +1043,8 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
+    max_input_tokens: usize,
+    max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
@@ -1067,7 +1105,9 @@ fn spawn_shards(
             rope_factor,
             max_total_tokens,
             max_batch_size,
+            max_input_tokens,
             otlp_endpoint,
+            max_log_level,
             status_sender,
             shutdown,
             shutdown_sender,
@@ -1298,8 +1338,22 @@ fn main() -> Result<(), LauncherError> {
     let args: Args = Args::parse();

     // Filter events with LOG_LEVEL
-    let env_filter =
-        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
+    let varname = "LOG_LEVEL";
+    let env_filter = if let Ok(log_level) = std::env::var(varname) {
+        // Override to avoid simple logs to be spammed with tokio level informations
+        let log_level = match &log_level[..] {
+            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
+            "info" => "text_generation_launcher=info,text_generation_router=info",
+            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
+            log_level => log_level,
+        };
+        EnvFilter::builder()
+            .with_default_directive(LevelFilter::INFO.into())
+            .parse_lossy(log_level)
+    } else {
+        EnvFilter::new("info")
+    };
+    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);

     if args.json_output {
         tracing_subscriber::fmt()
@@ -1342,13 +1396,13 @@ fn main() -> Result<(), LauncherError> {
     };

     let content = std::fs::read_to_string(filename)?;
-    let config: Config = serde_json::from_str(&content)?;
+    let config: RawConfig = serde_json::from_str(&content)?;
+    let config: Config = config.into();

     // Quantization usually means you're even more RAM constrained.
     let max_default = 4096;

-    let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
-        (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
+    if let Some(max_position_embeddings) = config.max_position_embeddings {
         if max_position_embeddings > max_default {
             let max = max_position_embeddings;
             if args.max_input_tokens.is_none()
@@ -1357,18 +1411,15 @@ fn main() -> Result<(), LauncherError> {
             {
                 tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
             }
-            max_default
+            Ok(max_default)
         } else {
-            max_position_embeddings
-        }
-        }
-        _ => {
-            return Err(Box::new(LauncherError::ArgumentValidation(
-                "no max defined".to_string(),
-            )));
-        }
-    };
-    Ok(max_position_embeddings)
+            Ok(max_position_embeddings)
+        }
+    } else {
+        Err(Box::new(LauncherError::ArgumentValidation(
+            "no max defined".to_string(),
+        )))
+    }
     };
     let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);

@@ -1462,6 +1513,11 @@ fn main() -> Result<(), LauncherError> {

     let num_shard = find_num_shards(args.sharded, args.num_shard)?;
     if num_shard > 1 {
+        if matches!(args.quantize, Some(Quantization::Exl2)) {
+            return Err(LauncherError::ArgumentValidation(
+                "Sharding is currently not supported with `exl2` quantization".into(),
+            ));
+        }
         tracing::info!("Sharding model on {num_shard} processes");
     }

@@ -1524,6 +1580,8 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
+        max_input_tokens,
+        max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
proto/v3/generate.proto (new file, 265 lines added)
@@ -0,0 +1,265 @@
syntax = "proto3";

package generate.v3;

service TextGenerationService {
    /// Model Info
    rpc Info (InfoRequest) returns (InfoResponse) {}
    /// Service discovery
    rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
    /// Empties batch cache
    rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
    /// Remove requests from a cached batch
    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
    /// Warmup the model and compute max cache size
    rpc Warmup (WarmupRequest) returns (WarmupResponse);
    /// Prefill batch and decode first token
    rpc Prefill (PrefillRequest) returns (PrefillResponse);
    /// Decode token for a list of prefilled batches
    rpc Decode (DecodeRequest) returns (DecodeResponse);
    /// Health check
    rpc Health (HealthRequest) returns (HealthResponse);
}

message HealthRequest {}
message HealthResponse {}

/// Empty request
message InfoRequest {}

message InfoResponse {
    bool requires_padding = 1;
    string dtype = 2;
    string device_type = 3;
    optional uint32 window_size = 4;
    uint32 speculate = 5;
}

/// Empty request
message ServiceDiscoveryRequest {}

message ServiceDiscoveryResponse {
    /// Other shards urls
    repeated string urls = 1;
}

message ClearCacheRequest {
    /// Optional batch id
    optional uint64 id = 1;
}

/// Empty response
message ClearCacheResponse {}

message Image {
    /// Binary image data.
    bytes data = 1;

    /// Image MIME type.
    string mimetype = 2;
}

message InputChunk {
    oneof chunk {
        /// Plain text data
        string text = 1;
        /// Image data
        Image image = 2;
    }
}

message Input {
    repeated InputChunk chunks = 1;
}

enum GrammarType {
    GRAMMAR_TYPE_NONE = 0;
    GRAMMAR_TYPE_JSON = 1;
    GRAMMAR_TYPE_REGEX = 2;
}

message NextTokenChooserParameters {
    /// exponential scaling output probability distribution
    float temperature = 1;
    /// restricting to the k highest probability elements
    uint32 top_k = 2;
    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
    float top_p = 3;
    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
    float typical_p = 4;
    /// apply sampling on the logits
    bool do_sample = 5;
    /// random seed for sampling
    uint64 seed = 6;
    /// repetition penalty
    float repetition_penalty = 7;
    /// frequency penalty
    float frequency_penalty = 9;
    /// token watermarking using "A Watermark for Large Language Models"
    bool watermark = 8;
    /// grammar (applied if not empty)
    string grammar = 10;
    /// grammar type
    GrammarType grammar_type = 11;
}

message StoppingCriteriaParameters {
    /// Maximum number of generated tokens
    uint32 max_new_tokens = 1;
    /// Optional stopping sequences
    repeated string stop_sequences = 2;
    /// Ignore end of sequence token
    /// used for benchmarking
    bool ignore_eos_token = 3;
}

message Request {
    /// Request ID
    uint64 id = 1;
    /// The generation context as chunks
    Input input_chunks = 8;
    /// The generation context, stringified input_chunks
    string inputs = 2;
    /// Context truncation
    uint32 truncate = 3;
    /// Next Token Chooser Parameters
    NextTokenChooserParameters parameters = 4;
    /// Stopping Criteria Parameters
    StoppingCriteriaParameters stopping_parameters = 5;
    /// Return prefill logprobs
    bool prefill_logprobs = 6;
    /// Return most likely n tokens
    uint32 top_n_tokens = 7;
    /// Paged attention blocks
    repeated uint32 blocks = 9;
    /// Paged attention slots
    repeated uint32 slots = 10;
}

message Batch {
    /// Batch ID
    uint64 id = 1;
    /// Individual requests
    repeated Request requests = 2;
    /// Batch size (==len(requests))
    uint32 size = 3;
    /// Maximum number of tokens this batch will grow to
    uint32 max_tokens = 4;
    /// Maximum number of Paged Attention blocks
    uint32 max_blocks = 5;
}

message CachedBatch {
    /// Batch ID
    uint64 id = 1;
    /// Individual requests ids
    repeated uint64 request_ids = 2;
    /// Batch size (==len(requests))
    uint32 size = 3;
    /// Maximum number of tokens this batch will grow to
    uint32 max_tokens = 4;
}

enum FinishReason {
    FINISH_REASON_LENGTH = 0;
    FINISH_REASON_EOS_TOKEN = 1;
    FINISH_REASON_STOP_SEQUENCE = 2;
}

message GeneratedText {
    /// Output
    string text = 1;
    /// Number of generated tokens
    uint32 generated_tokens = 2;
    /// Finish reason
    FinishReason finish_reason = 3;
    /// Seed
    optional uint64 seed = 4;
}

message Tokens {
    /// Token IDs
    repeated uint32 ids = 1;
    /// Logprobs
    repeated float logprobs = 2;
    /// tokens
    repeated string texts = 3;
    /// special
    repeated bool is_special = 4;
}

message Generation {
    /// Request ID
    uint64 request_id = 1;
    /// Prefill tokens (optional)
    Tokens prefill_tokens = 2;
    Tokens tokens = 3;
    /// Complete generated text
    optional GeneratedText generated_text = 4;
    /// Top tokens
    repeated Tokens top_tokens = 5;
}

message FilterBatchRequest {
    /// Batch ID
    uint64 batch_id = 1;
    /// Requests to keep
    repeated uint64 request_ids = 2;
}

message FilterBatchResponse {
    /// Filtered Batch (cached)
    CachedBatch batch = 1;
}


message PrefillRequest {
    /// Batch
    Batch batch = 1;
}

message PrefillResponse {
    /// Generation
    repeated Generation generations = 1;
    /// Next batch (cached)
    optional CachedBatch batch = 2;
    /// Forward elapsed time in nanoseconds
    uint64 forward_ns = 3;
    /// Decode elapsed time in nanoseconds
    uint64 decode_ns = 4;
    /// Total elapsed time in nanoseconds
    uint64 total_ns = 5;
}

message DecodeRequest {
    /// Cached batches
    repeated CachedBatch batches = 1;
}

message DecodeResponse {
    /// Decodes
    repeated Generation generations = 1;
    /// Next batch (cached)
    optional CachedBatch batch = 2;
    /// Forward elapsed time in nanoseconds
    uint64 forward_ns = 3;
    /// Decode elapsed time in nanoseconds
    uint64 decode_ns = 4;
    /// Total elapsed time in nanoseconds
    uint64 total_ns = 5;
    /// Concatenate elapsed time in nanoseconds
    optional uint64 concat_ns = 6;
}

message WarmupRequest {
    /// Batch to warmup on
    Batch batch = 1;
    uint32 max_input_length = 2;
    uint32 max_prefill_tokens = 3;
    uint32 max_total_tokens = 4;
}

message WarmupResponse {
    /// Maximum number of tokens supported by the model
    optional uint32 max_supported_total_tokens = 1;
}
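The service above is normally consumed by the Rust client crate touched below, but it can also be exercised directly once Python stubs are generated with `grpc_tools.protoc`. A minimal sketch only: the module names (`generate_pb2`, `generate_pb2_grpc`) follow standard protoc output, and the Unix-socket path is a guess based on the `--uds-path` flag seen in the launcher diff, neither is pinned down by this commit:

```python
# Sketch: assumes `python -m grpc_tools.protoc ... generate.proto` produced
# generate_pb2 / generate_pb2_grpc, and that a shard is listening on this socket.
import grpc
import generate_pb2
import generate_pb2_grpc

# Hypothetical shard socket path.
channel = grpc.insecure_channel("unix:///tmp/text-generation-server-0")
stub = generate_pb2_grpc.TextGenerationServiceStub(channel)

# Health check, then query basic model info (dtype, device, speculate).
stub.Health(generate_pb2.HealthRequest())
info = stub.Info(generate_pb2.InfoRequest())
print(info.dtype, info.device_type, info.speculate)
```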
@@ -16,8 +16,8 @@ path = "src/main.rs"

 [dependencies]
 async-stream = "0.3.5"
-axum = { version = "0.6.20", features = ["json"] }
-axum-tracing-opentelemetry = "0.14.1"
+axum = { version = "0.7", features = ["json"] }
+axum-tracing-opentelemetry = "0.16"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
@@ -36,20 +36,21 @@ thiserror = "1.0.48"
 tokenizers = { workspace = true}
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.14"
-tower-http = { version = "0.4.4", features = ["cors"] }
+tower-http = { version = "0.5.1", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
-utoipa = { version = "3.5.0", features = ["axum_extras"] }
-utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
+utoipa = { version = "4.2.0", features = ["axum_extras"] }
+utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
-minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
+minijinja = { version = "2.0.2" }
+minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
 futures-util = "0.3.30"
 regex = "1.10.3"
 once_cell = "1.19.0"
 image = "0.25.1"
-base64 = "0.22.0"
+base64 = { workspace = true }

 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
@@ -58,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
 google = []
+kserve = []
@@ -6,6 +6,8 @@ authors.workspace = true
 homepage.workspace = true

 [dependencies]
+async-trait = "^0.1"
+base64 = { workspace = true }
 futures = "^0.3"
 grpc-metadata = { path = "../grpc-metadata" }
 prost = "^0.12"
@@ -1,18 +1,34 @@
 use std::fs;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("cargo:rerun-if-changed=../../proto/generate.proto");
+    println!("cargo:rerun-if-changed=../../proto/");
-    fs::create_dir("src/pb").unwrap_or(());

+    fs::create_dir_all("src/v2/pb").unwrap_or(());
     let mut config = prost_build::Config::new();
     config.protoc_arg("--experimental_allow_proto3_optional");

     tonic_build::configure()
         .build_client(true)
         .build_server(false)
-        .out_dir("src/pb")
+        .out_dir("src/v2/pb")
         .include_file("mod.rs")
         .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
+        .map_err(|e| match e.kind(){
+            std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
+            std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
+            e => {e}
+        }).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
+
+    fs::create_dir_all("src/v3/pb").unwrap_or(());
+    let mut config = prost_build::Config::new();
+    config.protoc_arg("--experimental_allow_proto3_optional");
+
+    tonic_build::configure()
+        .build_client(true)
+        .build_server(false)
+        .out_dir("src/v3/pb")
+        .include_file("mod.rs")
+        .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
         .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));

     Ok(())
@@ -1,22 +1,35 @@
 //! Text Generation gRPC client library

-mod client;
+pub mod v2;
-#[allow(clippy::derive_partial_eq_without_eq)]
+pub mod v3;
-mod pb;
-mod sharded_client;

-pub use client::Client;
+use async_trait::async_trait;
-pub use pb::generate::v2::HealthResponse;
+use base64::{engine::general_purpose::STANDARD, Engine};
-pub use pb::generate::v2::InfoResponse as ShardInfo;
-pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
-};
-pub use sharded_client::ShardedClient;
 use thiserror::Error;
 use tonic::transport;
 use tonic::Status;

+pub use v3::{Chunk, Image, Input, InputChunk};
+
+#[async_trait]
+pub trait Health {
+    /// Check if a generate server is healthy by asking it to allocate a tensor on device
+    async fn device_health(&self) -> Result<()>;
+
+    /// Check if a generate server is healthy by doing a forward pass.
+    /// EXPENSIVE
+    async fn model_health(&self) -> Result<()>;
+}
+
+#[derive(Debug)]
+pub struct ShardInfo {
+    pub requires_padding: bool,
+    pub dtype: String,
+    pub device_type: String,
+    pub window_size: Option<u32>,
+    pub speculate: u32,
+}

 #[derive(Error, Debug, Clone)]
 pub enum ClientError {
     #[error("Could not connect to Text Generation server: {0}")]
@@ -43,4 +56,36 @@ impl From<transport::Error> for ClientError {
     }
 }
+
+// Small convenience re-wrapping of `Chunk`.
+impl From<Chunk> for InputChunk {
+    fn from(chunk: Chunk) -> Self {
+        InputChunk { chunk: Some(chunk) }
+    }
+}
+
+/// Convert input chunks to a stringly-typed input for backwards
+/// compat for backends that haven't implemented chunked inputs.
+pub trait ChunksToString {
+    /// Convert chunks to string.
+    fn chunks_to_string(&self) -> String;
+}
+
+impl ChunksToString for Vec<InputChunk> {
+    fn chunks_to_string(&self) -> String {
+        let mut output = String::new();
+        self.iter().for_each(|c| match &c.chunk {
+            Some(Chunk::Text(text)) => output.push_str(text),
+            Some(Chunk::Image(Image { data, mimetype })) => {
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
+            }
+            // We don't create empty chunks, so this should be unreachable.
+            None => unreachable!("Chunks should never be empty"),
+        });
+        output
+    }
+}
+
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

 pub type Result<T> = std::result::Result<T, ClientError>;
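Not part of the diff: a small usage sketch for the ChunksToString helper above, showing how a chunked input is flattened back into the legacy string form. The chunk contents are placeholders.

use text_generation_client::{Chunk, ChunksToString, Image, InputChunk};

fn flatten_example() -> String {
    let chunks: Vec<InputChunk> = vec![
        Chunk::Text("What is in this picture?\n".to_string()).into(),
        Chunk::Image(Image {
            data: vec![0u8; 16], // placeholder bytes, not a real image
            mimetype: "image/png".to_string(),
        })
        .into(),
    ];
    // Text chunks pass through as-is; image chunks are re-encoded inline as base64.
    chunks.chunks_to_string()
}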
1 router/client/src/pb/.gitignore vendored
@@ -1 +0,0 @@
-*.rs
@@ -1,8 +1,11 @@
 /// Single shard Client
-use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use crate::v2::pb;
-use crate::pb::generate::v2::*;
+use crate::{ClientError, Result};
-use crate::Result;
+use crate::WARMUP_IMAGE_BASE64;
 use grpc_metadata::InjectTelemetryContext;
+use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use pb::generate::v2::*;
 use std::cmp::min;
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
@@ -42,7 +45,9 @@ impl Client {
     #[instrument(skip(self))]
     pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
         let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
-        let response = self.stub.service_discovery(request).await?;
+        let response = self.stub.service_discovery(request).await.map_err(|_| {
+            ClientError::Connection("Server does not support v2 interface".to_string())
+        })?;
         let urls = response
             .into_inner()
             .urls
@@ -118,13 +123,15 @@ impl Client {
             if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
-                inputs.push_str("");
+                inputs.push_str(&format!(
+                    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
+                ));
             }

             requests.push(Request {
                 id: 0,
-                // We truncate the input on the server side to be sure that it has the correct size
                 inputs,
+                // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
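Not part of the diff: a hedged sketch of one way the new "does not support v2 interface" connection error could drive a v3-to-v2 protocol fallback; the actual router wiring may differ, and the enum name is an assumption.

use text_generation_client::{v2, v3, ClientError};

enum AnyShardedClient {
    V3(v3::ShardedClient),
    V2(v2::ShardedClient),
}

async fn connect_with_fallback(uds_path: String) -> Result<AnyShardedClient, ClientError> {
    // Try the newer protocol first; a Connection error signals an older server.
    match v3::ShardedClient::connect_uds(uds_path.clone()).await {
        Ok(client) => Ok(AnyShardedClient::V3(client)),
        Err(ClientError::Connection(_)) => {
            Ok(AnyShardedClient::V2(v2::ShardedClient::connect_uds(uds_path).await?))
        }
        Err(err) => Err(err),
    }
}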
13 router/client/src/v2/mod.rs Normal file
@@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;

mod client;
mod sharded_client;

pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::{
    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

1 router/client/src/v2/pb/.gitignore vendored Normal file
@@ -0,0 +1 @@
*
@ -1,10 +1,17 @@
|
|||||||
use crate::client::{DecodeTimings, PrefillTimings};
|
|
||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
|
use crate::{v2, Health, ShardInfo};
|
||||||
use crate::{ClientError, Result};
|
use crate::{ClientError, Result};
|
||||||
|
|
||||||
|
use crate::v2::InfoResponse;
|
||||||
|
use async_trait::async_trait;
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
use v2::client::{DecodeTimings, PrefillTimings};
|
||||||
|
use v2::{
|
||||||
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// Text Generation Inference gRPC multi client
|
/// Text Generation Inference gRPC multi client
|
||||||
@ -47,7 +54,7 @@ impl ShardedClient {
|
|||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.info())
|
.map(|client| client.info())
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.pop().unwrap()
|
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GRPC health check
|
/// GRPC health check
|
||||||
@ -185,3 +192,60 @@ impl ShardedClient {
|
|||||||
Ok((generations, next_batch, timings))
|
Ok((generations, next_batch, timings))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<InfoResponse> for ShardInfo {
|
||||||
|
fn from(value: InfoResponse) -> Self {
|
||||||
|
Self {
|
||||||
|
requires_padding: value.requires_padding,
|
||||||
|
dtype: value.dtype,
|
||||||
|
device_type: value.device_type,
|
||||||
|
window_size: value.window_size,
|
||||||
|
speculate: value.speculate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Health for ShardedClient {
|
||||||
|
async fn device_health(&self) -> Result<()> {
|
||||||
|
self.clone().health().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_health(&self) -> Result<()> {
|
||||||
|
// Dummy batch of 1 token and 1 generated token
|
||||||
|
let liveness_request = Request {
|
||||||
|
id: u64::MAX,
|
||||||
|
inputs: "liveness".to_string(),
|
||||||
|
truncate: 10,
|
||||||
|
prefill_logprobs: false,
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_k: 0,
|
||||||
|
top_p: 1.0,
|
||||||
|
typical_p: 1.0,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
watermark: false,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: 1,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: false,
|
||||||
|
}),
|
||||||
|
top_n_tokens: 0,
|
||||||
|
};
|
||||||
|
let batch = Batch {
|
||||||
|
id: u64::MAX,
|
||||||
|
requests: vec![liveness_request],
|
||||||
|
size: 1,
|
||||||
|
max_tokens: 2,
|
||||||
|
};
|
||||||
|
self.clone().prefill(batch).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
282
router/client/src/v3/client.rs
Normal file
282
router/client/src/v3/client.rs
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
use crate::v3::{pb, Chunk};
|
||||||
|
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||||
|
/// Single shard Client
|
||||||
|
use base64::engine::general_purpose::STANDARD;
|
||||||
|
use base64::Engine;
|
||||||
|
use grpc_metadata::InjectTelemetryContext;
|
||||||
|
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||||
|
use pb::generate::v3::*;
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tonic::transport::{Channel, Uri};
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
/// Text Generation Inference gRPC client
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Client {
|
||||||
|
stub: TextGenerationServiceClient<Channel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
/// Returns a client connected to the given url
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let channel = Channel::builder(uri).connect().await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stub: TextGenerationServiceClient::new(channel),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||||
|
.unwrap()
|
||||||
|
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||||
|
tokio::net::UnixStream::connect(path.clone())
|
||||||
|
}))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stub: TextGenerationServiceClient::new(channel),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a list of uris or unix sockets of all shards
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||||
|
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||||
|
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||||
|
ClientError::Connection("Server does not support v3 interface".to_string())
|
||||||
|
})?;
|
||||||
|
let urls = response
|
||||||
|
.into_inner()
|
||||||
|
.urls
|
||||||
|
.into_iter()
|
||||||
|
// Remove unix socket prefix
|
||||||
|
.map(|url| match url.strip_prefix("unix://") {
|
||||||
|
None => url,
|
||||||
|
Some(stripped_url) => stripped_url.to_string(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(urls)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||||
|
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||||
|
let response = self.stub.info(request).await?.into_inner();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get model health
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||||
|
let response = self.stub.health(request).await?.into_inner();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||||
|
self.stub.clear_cache(request).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let request = tonic::Request::new(FilterBatchRequest {
|
||||||
|
batch_id,
|
||||||
|
request_ids,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
|
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||||
|
Ok(filtered_batch.batch)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let mut n_tokens = 0;
|
||||||
|
let mut requests = Vec::new();
|
||||||
|
// Create requests
|
||||||
|
while n_tokens < max_prefill_tokens {
|
||||||
|
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
|
|
||||||
|
let mut input_chunks = Vec::new();
|
||||||
|
input_chunks
|
||||||
|
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||||
|
if n_tokens == 0 {
|
||||||
|
input_chunks.push(
|
||||||
|
Chunk::Image(Image {
|
||||||
|
// Safe unwrap, because we control the data.
|
||||||
|
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||||
|
mimetype: "image/jpeg;base64".to_string(),
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send stringly-typed inputs for compatibility for backends that haven't
|
||||||
|
// been updated to support chunks.
|
||||||
|
|
||||||
|
let mut inputs = String::new();
|
||||||
|
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||||
|
if n_tokens == 0 {
|
||||||
|
// 1 request is enough to test vision heads.
|
||||||
|
// Sending images on other queries messes up easily with truncation.
|
||||||
|
inputs.push_str(&format!(
|
||||||
|
"",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
requests.push(Request {
|
||||||
|
id: 0,
|
||||||
|
inputs,
|
||||||
|
input_chunks: Some(Input {
|
||||||
|
chunks: input_chunks,
|
||||||
|
}),
|
||||||
|
// We truncate the input on the server side to be sure that it has the correct size
|
||||||
|
truncate,
|
||||||
|
// Blocks and slots will be set on the server side if we use paged attention
|
||||||
|
blocks: vec![],
|
||||||
|
slots: vec![],
|
||||||
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 0.9,
|
||||||
|
top_k: 10,
|
||||||
|
top_p: 0.9,
|
||||||
|
typical_p: 0.9,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.2,
|
||||||
|
frequency_penalty: 0.1,
|
||||||
|
watermark: true,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: max_total_tokens - truncate,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: true,
|
||||||
|
}),
|
||||||
|
prefill_logprobs: true,
|
||||||
|
top_n_tokens: 20,
|
||||||
|
});
|
||||||
|
n_tokens += max_input_length;
|
||||||
|
|
||||||
|
// Check max_batch_size
|
||||||
|
if Some(requests.len()) == max_batch_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let batch = Batch {
|
||||||
|
id: 0,
|
||||||
|
size: requests.len() as u32,
|
||||||
|
requests,
|
||||||
|
max_tokens: max_input_length,
|
||||||
|
max_blocks: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
|
batch: Some(batch),
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
|
let response = self.stub.warmup(request).await?.into_inner();
|
||||||
|
Ok(response.max_supported_total_tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
|
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||||
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
|
Ok((
|
||||||
|
response.generations,
|
||||||
|
response.batch,
|
||||||
|
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
|
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||||
|
let response = self.stub.decode(request).await?.into_inner();
|
||||||
|
Ok((
|
||||||
|
response.generations,
|
||||||
|
response.batch,
|
||||||
|
DecodeTimings::new(
|
||||||
|
response.concat_ns,
|
||||||
|
response.forward_ns,
|
||||||
|
response.decode_ns,
|
||||||
|
response.total_ns,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PrefillTimings {
|
||||||
|
pub forward: Duration,
|
||||||
|
pub decode: Duration,
|
||||||
|
pub total: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PrefillTimings {
|
||||||
|
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
forward: Duration::from_nanos(forward_ns),
|
||||||
|
decode: Duration::from_nanos(decode_ns),
|
||||||
|
total: Duration::from_nanos(total_ns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DecodeTimings {
|
||||||
|
pub concat: Option<Duration>,
|
||||||
|
pub forward: Duration,
|
||||||
|
pub decode: Duration,
|
||||||
|
pub total: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecodeTimings {
|
||||||
|
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
concat: concat_ns.map(Duration::from_nanos),
|
||||||
|
forward: Duration::from_nanos(forward_ns),
|
||||||
|
decode: Duration::from_nanos(decode_ns),
|
||||||
|
total: Duration::from_nanos(total_ns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
13 router/client/src/v3/mod.rs Normal file
@@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;

mod client;
mod sharded_client;

pub use client::Client;
pub use pb::generate::v3::{
    input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
    StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

1 router/client/src/v3/pb/.gitignore vendored Normal file
@@ -0,0 +1 @@
*
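Not part of the diff: a sketch of how a v3 Request can carry both the new chunked input and the legacy string form. The prompt text is illustrative and the remaining request fields are left at their defaults for brevity.

use text_generation_client::v3::{Chunk, Input, InputChunk, Request};
use text_generation_client::ChunksToString;

fn chunked_request(id: u64, prompt: &str) -> Request {
    let chunks: Vec<InputChunk> = vec![Chunk::Text(prompt.to_string()).into()];
    Request {
        id,
        // Legacy stringly-typed input, kept for backends without chunk support.
        inputs: chunks.chunks_to_string(),
        // New chunked input introduced by the v3 protocol.
        input_chunks: Some(Input { chunks }),
        // Sampling parameters, stopping criteria, blocks/slots, etc. left at their defaults here.
        ..Default::default()
    }
}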
258
router/client/src/v3/sharded_client.rs
Normal file
258
router/client/src/v3/sharded_client.rs
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
/// Multi shard Client
|
||||||
|
use crate::{v3, Health, ShardInfo};
|
||||||
|
use crate::{ClientError, Result};
|
||||||
|
|
||||||
|
use crate::v3::{Chunk, InfoResponse, Input};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::future::join_all;
|
||||||
|
use tonic::transport::Uri;
|
||||||
|
use tracing::instrument;
|
||||||
|
use v3::client::{DecodeTimings, PrefillTimings};
|
||||||
|
use v3::{
|
||||||
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Text Generation Inference gRPC multi client
|
||||||
|
pub struct ShardedClient {
|
||||||
|
clients: Vec<Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShardedClient {
|
||||||
|
fn new(clients: Vec<Client>) -> Self {
|
||||||
|
Self { clients }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
|
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||||
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||||
|
// Get all uris/unix sockets from the master client
|
||||||
|
let uris = master_client.service_discovery().await?;
|
||||||
|
let futures = uris.into_iter().map(Client::connect_uds);
|
||||||
|
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||||
|
Ok(Self::new(clients?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given uri
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let master_client = Client::connect(uri).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let master_client = Client::connect_uds(path).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.info())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GRPC health check
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.health())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.clear_cache(batch_id))
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.into_iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||||
|
.collect();
|
||||||
|
// all shards return the same message
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| {
|
||||||
|
Box::pin(client.warmup(
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
// Take the minimum value
|
||||||
|
let results = join_all(futures)
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||||
|
Ok(results.into_iter().flatten().min())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InfoResponse> for ShardInfo {
|
||||||
|
fn from(value: InfoResponse) -> Self {
|
||||||
|
Self {
|
||||||
|
requires_padding: value.requires_padding,
|
||||||
|
dtype: value.dtype,
|
||||||
|
device_type: value.device_type,
|
||||||
|
window_size: value.window_size,
|
||||||
|
speculate: value.speculate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Health for ShardedClient {
|
||||||
|
async fn device_health(&self) -> Result<()> {
|
||||||
|
self.clone().health().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_health(&self) -> Result<()> {
|
||||||
|
// Dummy batch of 1 token and 1 generated token
|
||||||
|
let liveness_request = Request {
|
||||||
|
id: u64::MAX,
|
||||||
|
inputs: "liveness".to_string(),
|
||||||
|
input_chunks: Some(Input {
|
||||||
|
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||||
|
}),
|
||||||
|
truncate: 10,
|
||||||
|
prefill_logprobs: false,
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_k: 0,
|
||||||
|
top_p: 1.0,
|
||||||
|
typical_p: 1.0,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
watermark: false,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: 1,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: false,
|
||||||
|
}),
|
||||||
|
top_n_tokens: 0,
|
||||||
|
// Block 0 is reserved for health checks
|
||||||
|
blocks: vec![0],
|
||||||
|
slots: (0..16).collect(),
|
||||||
|
};
|
||||||
|
let batch = Batch {
|
||||||
|
id: u64::MAX,
|
||||||
|
requests: vec![liveness_request],
|
||||||
|
size: 1,
|
||||||
|
max_tokens: 2,
|
||||||
|
max_blocks: 1,
|
||||||
|
};
|
||||||
|
self.clone().prefill(batch).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
@@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct LlavaNext {
-    text_config: TextConfig,
+    pub(crate) text_config: TextConfig,
-    vision_config: VisionConfig,
+    pub(crate) vision_config: VisionConfig,
-    image_grid_pinpoints: Vec<(usize, usize)>,
+    pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
 }

 fn get_anyres_image_grid_shape(
@@ -119,13 +119,13 @@ impl Idefics2 {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct PaliTextConfig {
-    num_image_tokens: usize,
+    pub(crate) num_image_tokens: usize,
 }

 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct Paligemma {
-    text_config: PaliTextConfig,
+    pub(crate) text_config: PaliTextConfig,
 }

 impl Paligemma {
@@ -175,8 +175,8 @@ pub struct TextConfig {}
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct VisionConfig {
-    image_size: usize,
+    pub(crate) image_size: usize,
-    patch_size: usize,
+    pub(crate) patch_size: usize,
 }

 #[cfg(test)]
@ -1,72 +0,0 @@
|
|||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use text_generation_client::GrammarType as ProtoGrammarType;
|
|
||||||
use text_generation_client::{
|
|
||||||
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Note: Request ids and batch ids cannot collide.
|
|
||||||
const LIVENESS_ID: u64 = u64::MAX;
|
|
||||||
const BATCH_ID: u64 = u64::MAX;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub(crate) struct Health {
|
|
||||||
client: ShardedClient,
|
|
||||||
generation_health: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Health {
|
|
||||||
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
|
|
||||||
Self {
|
|
||||||
client,
|
|
||||||
generation_health,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn check(&mut self) -> bool {
|
|
||||||
if self.generation_health.load(Ordering::SeqCst) {
|
|
||||||
// Generation is healthy, we only check that the shards are answering gRPC calls
|
|
||||||
self.client.health().await.is_ok()
|
|
||||||
} else {
|
|
||||||
// Generation is unhealthy or have not sent any generation request yet
|
|
||||||
|
|
||||||
// Dummy batch of 1 token and 1 generated token
|
|
||||||
let liveness_request = Request {
|
|
||||||
id: LIVENESS_ID,
|
|
||||||
inputs: "liveness".to_string(),
|
|
||||||
truncate: 10,
|
|
||||||
prefill_logprobs: false,
|
|
||||||
parameters: Some(NextTokenChooserParameters {
|
|
||||||
temperature: 1.0,
|
|
||||||
top_k: 0,
|
|
||||||
top_p: 1.0,
|
|
||||||
typical_p: 1.0,
|
|
||||||
do_sample: false,
|
|
||||||
seed: 0,
|
|
||||||
repetition_penalty: 1.0,
|
|
||||||
frequency_penalty: 0.0,
|
|
||||||
watermark: false,
|
|
||||||
grammar: String::new(),
|
|
||||||
grammar_type: ProtoGrammarType::None as i32,
|
|
||||||
}),
|
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
|
||||||
max_new_tokens: 1,
|
|
||||||
stop_sequences: vec![],
|
|
||||||
ignore_eos_token: false,
|
|
||||||
}),
|
|
||||||
top_n_tokens: 0,
|
|
||||||
};
|
|
||||||
let batch = Batch {
|
|
||||||
id: BATCH_ID,
|
|
||||||
requests: vec![liveness_request],
|
|
||||||
size: 1,
|
|
||||||
max_tokens: 2,
|
|
||||||
};
|
|
||||||
// Skips the queue
|
|
||||||
let value = self.client.prefill(batch).await.is_ok();
|
|
||||||
// Update generation health
|
|
||||||
self.generation_health.store(value, Ordering::SeqCst);
|
|
||||||
value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
34 router/src/infer/health.rs Normal file
@@ -0,0 +1,34 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::Health;

#[derive(Clone)]
pub(crate) struct HealthCheck {
    client: Arc<dyn Health + Send + Sync>,
    generation_health: Arc<AtomicBool>,
}

impl HealthCheck {
    pub(crate) fn new(
        client: Arc<dyn Health + Send + Sync>,
        generation_health: Arc<AtomicBool>,
    ) -> Self {
        Self {
            client,
            generation_health,
        }
    }

    pub(crate) async fn check(&mut self) -> bool {
        let value = if self.generation_health.load(Ordering::SeqCst) {
            // Generation is healthy, we only check that the shards can allocate on device
            self.client.device_health().await
        } else {
            self.client.model_health().await
        }
        .is_ok();
        // Update generation health
        self.generation_health.store(value, Ordering::SeqCst);
        value
    }
}
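Not part of the diff: a minimal sketch (route path and state wiring are assumptions) of how HealthCheck::check could back a readiness endpoint using the axum 0.7 dependency declared above.

use axum::{extract::State, http::StatusCode, routing::get, Router};
// Assumed path: HealthCheck is re-exported from the new infer module shown below.
use crate::infer::HealthCheck;

async fn health(State(mut health_check): State<HealthCheck>) -> StatusCode {
    if health_check.check().await {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    }
}

fn health_router(health_check: HealthCheck) -> Router {
    // `HealthCheck` derives `Clone`, so it can be shared as router state.
    Router::new()
        .route("/health", get(health))
        .with_state(health_check)
}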
519
router/src/infer/mod.rs
Normal file
519
router/src/infer/mod.rs
Normal file
@ -0,0 +1,519 @@
|
|||||||
|
mod health;
|
||||||
|
pub(crate) mod v2;
|
||||||
|
pub(crate) mod v3;
|
||||||
|
|
||||||
|
pub(crate) use health::HealthCheck;
|
||||||
|
|
||||||
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
|
use crate::{
|
||||||
|
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||||
|
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
|
||||||
|
};
|
||||||
|
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||||
|
use futures::future::try_join_all;
|
||||||
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
|
use minijinja_contrib::pycompat;
|
||||||
|
|
||||||
|
use serde_json::{json, Map, Value};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
pub(crate) trait Scheduler {
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
request: ValidGenerateRequest,
|
||||||
|
permit: OwnedSemaphorePermit,
|
||||||
|
) -> Result<GenerateStreamResponse, InferError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inference struct
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Infer {
|
||||||
|
/// Validation
|
||||||
|
validation: Validation,
|
||||||
|
/// Request scheduler
|
||||||
|
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||||
|
/// Chat template
|
||||||
|
chat_template: Option<ChatTemplate>,
|
||||||
|
/// Inference limit
|
||||||
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Infer {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) fn new(
|
||||||
|
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||||
|
validation: Validation,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
tokenizer_config: HubTokenizerConfig,
|
||||||
|
processor_config: HubProcessorConfig,
|
||||||
|
) -> Self {
|
||||||
|
let chat_template = tokenizer_config
|
||||||
|
.chat_template
|
||||||
|
.or(processor_config.chat_template)
|
||||||
|
.and_then(|t| match t {
|
||||||
|
ChatTemplateVersions::Single(template) => Some(template),
|
||||||
|
ChatTemplateVersions::Multiple(templates) => templates
|
||||||
|
.into_iter()
|
||||||
|
.find(|t| t.name == "default")
|
||||||
|
.map(|t| t.template),
|
||||||
|
})
|
||||||
|
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||||
|
|
||||||
|
// Inference limit with a semaphore
|
||||||
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
validation,
|
||||||
|
scheduler,
|
||||||
|
chat_template,
|
||||||
|
limit_concurrent_requests: semaphore,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<GenerateStreamResponse, InferError> {
|
||||||
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
|
let permit = self
|
||||||
|
.clone()
|
||||||
|
.limit_concurrent_requests
|
||||||
|
.try_acquire_owned()
|
||||||
|
.map_err(|err| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
err
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Validate request
|
||||||
|
let valid_request = self.validation.validate(request).await.map_err(|err| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
err
|
||||||
|
})?;
|
||||||
|
|
||||||
|
self.scheduler.schedule(valid_request, permit)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tokenizer the input
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) async fn tokenize(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<Option<tokenizers::Encoding>, InferError> {
|
||||||
|
// Tokenize request
|
||||||
|
let inputs = request.inputs;
|
||||||
|
let truncate = request.parameters.truncate;
|
||||||
|
let encoding = self
|
||||||
|
.validation
|
||||||
|
.tokenize(inputs, truncate)
|
||||||
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
tracing::error!("Tokenization {err}");
|
||||||
|
err
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Return Encoding
|
||||||
|
Ok(encoding.map(|(encoding, _)| encoding))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply the chat template to the chat request
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) fn apply_chat_template(
|
||||||
|
&self,
|
||||||
|
messages: Vec<Message>,
|
||||||
|
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||||
|
) -> Result<String, InferError> {
|
||||||
|
self.chat_template
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
|
.apply(messages, grammar_with_prompt)
|
||||||
|
.map_err(|e| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
|
tracing::error!("{e}");
|
||||||
|
e
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a new request to the queue and return a InferResponse
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) async fn generate(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<InferResponse, InferError> {
|
||||||
|
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
||||||
|
|
||||||
|
// Create stream and keep semaphore permit as long as generate lives
|
||||||
|
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
|
||||||
|
|
||||||
|
// Return values
|
||||||
|
let mut result_prefill = Vec::new();
|
||||||
|
let mut result_tokens = Vec::new();
|
||||||
|
let mut result_top_tokens = Vec::new();
|
||||||
|
let mut result_generated_text = None;
|
||||||
|
let mut result_start = None;
|
||||||
|
let mut result_queued = None;
|
||||||
|
|
||||||
|
// Iterate on stream
|
||||||
|
while let Some(response) = stream.next().await {
|
||||||
|
match response? {
|
||||||
|
// Add prefill tokens
|
||||||
|
InferStreamResponse::Prefill(prefill_tokens) => {
|
||||||
|
result_prefill = prefill_tokens;
|
||||||
|
}
|
||||||
|
// Push last token
|
||||||
|
InferStreamResponse::Intermediate { token, top_tokens } => {
|
||||||
|
result_tokens.push(token);
|
||||||
|
result_top_tokens.push(top_tokens);
|
||||||
|
}
|
||||||
|
// Final message
|
||||||
|
// Set return values
|
||||||
|
InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
generated_text,
|
||||||
|
start,
|
||||||
|
queued,
|
||||||
|
top_tokens,
|
||||||
|
} => {
|
||||||
|
result_tokens.push(token);
|
||||||
|
result_top_tokens.push(top_tokens);
|
||||||
|
result_generated_text = Some(generated_text);
|
||||||
|
result_start = Some(start);
|
||||||
|
result_queued = Some(queued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we received a `InferStreamResponse::End` message
|
||||||
|
if let (Some(generated_text), Some(queued), Some(start)) =
|
||||||
|
(result_generated_text, result_queued, result_start)
|
||||||
|
{
|
||||||
|
Ok(InferResponse {
|
||||||
|
prefill: result_prefill,
|
||||||
|
_input_length,
|
||||||
|
tokens: result_tokens,
|
||||||
|
generated_text,
|
||||||
|
queued,
|
||||||
|
start,
|
||||||
|
top_tokens: if use_top_tokens {
|
||||||
|
result_top_tokens
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
let err = InferError::IncompleteGeneration;
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Add best_of new requests to the queue and return a InferResponse of the sequence with
|
||||||
|
/// the highest log probability per token
|
||||||
|
#[instrument(skip(self, request))]
|
||||||
|
pub(crate) async fn generate_best_of(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
best_of: usize,
|
||||||
|
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
|
||||||
|
// validate best_of parameter separately
|
||||||
|
let best_of = self.validation.validate_best_of(best_of)?;
|
||||||
|
|
||||||
|
// create multiple generate requests
|
||||||
|
let mut infer_responses: Vec<InferResponse> =
|
||||||
|
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
|
||||||
|
|
||||||
|
// get the sequence with the highest log probability per token
|
||||||
|
let mut max_index = 0;
|
||||||
|
let mut max_logprob: f32 = f32::MIN;
|
||||||
|
|
||||||
|
for (i, response) in infer_responses.iter().enumerate() {
|
||||||
|
// mean logprobs of the generated tokens
|
||||||
|
let sequence_logprob = response
|
||||||
|
.tokens
|
||||||
|
.iter()
|
||||||
|
.map(|token| token.logprob)
|
||||||
|
.sum::<f32>()
|
||||||
|
/ response.tokens.len() as f32;
|
||||||
|
|
||||||
|
// set best sequence
|
||||||
|
if sequence_logprob > max_logprob {
|
||||||
|
max_index = i;
|
||||||
|
max_logprob = sequence_logprob;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let best_response = infer_responses.remove(max_index);
|
||||||
|
Ok((best_response, infer_responses))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raise a exception (custom function) used in the chat templates
|
||||||
|
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||||
|
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ChatTemplate {
|
||||||
|
template: Template<'static, 'static>,
|
||||||
|
bos_token: Option<String>,
|
||||||
|
eos_token: Option<String>,
|
||||||
|
use_default_tool_template: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatTemplate {
|
||||||
|
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||||
|
let mut env = Box::new(Environment::new());
|
||||||
|
// enable things like .strip() or .capitalize()
|
||||||
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
|
let template_str = template.into_boxed_str();
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
|
        // check if contains the tools variable within the template
        let use_default_tool_template =
            !template_str.as_ref().replace(' ', "").contains("{{tools}}");
        // leaking env and template_str as read-only, static resources for performance.
        let template = Box::leak(env)
            .template_from_str(Box::leak(template_str))
            .unwrap();

        Self {
            template,
            bos_token,
            eos_token,
            use_default_tool_template,
        }
    }

    fn apply(
        &self,
        mut messages: Vec<Message>,
        grammar_with_prompt: Option<(GrammarType, String)>,
    ) -> Result<String, InferError> {
        if self.use_default_tool_template {
            if let Some(last_message) = messages.last_mut() {
                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
                    last_message.content.push(MessageChunk::Text(Text {
                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
                    }));
                }
            }
        }

        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();

        self.template
            .render(ChatTemplateInputs {
                messages,
                bos_token: self.bos_token.as_deref(),
                eos_token: self.eos_token.as_deref(),
                add_generation_prompt: true,
                tools: None,
                tools_prompt: None,
            })
            .map_err(InferError::TemplateError)
    }
}

pub struct ToolGrammar {}

impl ToolGrammar {
    pub fn apply(
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolType>,
    ) -> Result<Option<Tools>, InferError> {
        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
            // let tool_prompt = tool_prompt.unwrap_or_default();
            let tools_to_use = match tool_choice {
                ToolType::FunctionName(name) => {
                    vec![req_tools
                        .iter()
                        .find(|tool| tool.function.name == *name)
                        .unwrap_or_else(|| panic!("Tool with name {} not found", name))
                        .clone()]
                }
                ToolType::OneOf => req_tools.to_owned(),
            };

            // adds the error notification function for LLM feedback if required
            let mut text_response_properties = Map::new();
            text_response_properties.insert(
                "error".to_string(),
                serde_json::json!({
                    "type": "string",
                    "description": "The error or issue to notify"
                }),
            );
            text_response_properties.insert(
                "_name".to_string(),
                serde_json::json!({
                    "type": "string",
                    "const": "notify_error"
                }),
            );

            let functions: HashMap<String, serde_json::Value> = tools_to_use
                .iter()
                .map(|tool| {
                    let func = tool.function.clone();

                    // Clone the existing parameters, which are expected to be a JSON object
                    let mut params = if let Value::Object(params) = &func.arguments {
                        params.clone()
                    } else {
                        Map::new()
                    };

                    // Insert the function's description at the top level, outside of properties
                    params.insert(
                        "description".to_string(),
                        Value::String(func.description.clone().unwrap_or_default()),
                    );

                    // Ensure 'properties' exists and is an object
                    let properties = params
                        .entry("properties".to_string())
                        .or_insert_with(|| json!({}))
                        .as_object_mut()
                        .unwrap();

                    // Insert the constant for the function name inside 'properties'
                    properties.insert(
                        "_name".to_string(),
                        json!({
                            "type": "string",
                            "const": func.name.clone(),
                            // "description": "The name of the function"
                        }),
                    );

                    // Check if 'required' exists, and it is an array. If not, create an empty array.
                    let required = params
                        .entry("required".to_string())
                        .or_insert_with(|| json!([]))
                        .as_array_mut()
                        .unwrap();

                    // Add 'name' to the 'required' array if it is not already present
                    if !required.iter().any(|r| r == "_name") {
                        required.push(json!("_name"));
                    }

                    (func.name, Value::Object(params))
                })
                .chain([(
                    "notify_error".to_string(),
                    serde_json::json!({
                        "properties": text_response_properties,
                        "required": ["error", "_name"],
                        "type": "object"
                    }),
                )])
                .collect();

            let tools = Tools {
                functions_map: FunctionsMap { functions },
                properties: Properties {
                    function: tools_to_use
                        .iter()
                        .map(|tool| FunctionRef {
                            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
                        })
                        .chain(std::iter::once(FunctionRef {
                            ref_path: "#/$functions/notify_error".to_string(),
                        }))
                        .collect(),
                },
            };

            return Ok(Some(tools));
        }
        // Err(InferError::ToolError("No tools provided".to_string()))
        Ok(None)
    }
}

/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
    OwnedSemaphorePermit,
    u32, // input_length
    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);

#[derive(Debug)]
pub(crate) struct GeneratedText {
    pub(crate) text: String,
    pub(crate) generated_tokens: u32,
    pub(crate) finish_reason: FinishReason,
    pub(crate) seed: Option<u64>,
}

#[derive(Debug)]
pub(crate) enum InferStreamResponse {
    // Optional first message
    Prefill(Vec<PrefillToken>),
    // Intermediate messages
    Intermediate {
        token: Token,
        top_tokens: Vec<Token>,
    },
    // Last message
    End {
        token: Token,
        top_tokens: Vec<Token>,
        generated_text: GeneratedText,
        start: Instant,
        queued: Instant,
    },
}

#[derive(Debug)]
pub(crate) struct InferResponse {
    /// input_length is the input as perceived by the rust tokenizer in the
    /// validation pathway. It is redundant with prefill.len() but prefill
    /// has data only if the user asked for it. This will always be filled.
    pub(crate) _input_length: u32,
    pub(crate) prefill: Vec<PrefillToken>,
    pub(crate) tokens: Vec<Token>,
    pub(crate) generated_text: GeneratedText,
    pub(crate) queued: Instant,
    pub(crate) start: Instant,
    pub(crate) top_tokens: Vec<Vec<Token>>,
}

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded(#[from] TryAcquireError),
    #[error("Input validation error: {0}")]
    ValidationError(#[from] ValidationError),
    #[error("Incomplete generation")]
    IncompleteGeneration,
    #[error("Template error: {0}")]
    TemplateError(#[from] minijinja::Error),
    #[error("Tool error: {0}")]
    ToolError(String),
}

impl InferError {
    pub(crate) fn error_type(&self) -> &str {
        match self {
            InferError::GenerationError(_) => "generation",
            InferError::Overloaded(_) => "overloaded",
            InferError::ValidationError(_) => "validation",
            InferError::IncompleteGeneration => "incomplete_generation",
            InferError::TemplateError(_) => "template_error",
            InferError::ToolError(_) => "tool_error",
        }
    }
}
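The `Box::leak` calls in `ChatTemplate::new` above trade a deliberate, one-time leak for a `Template<'static, 'static>` that can be stored in a struct without lifetime parameters. Below is a minimal, self-contained sketch of that pattern, assuming only the `minijinja` crate; the helper name and the template string are placeholders, not TGI code.

use minijinja::{context, Environment, Template};

fn leaked_template(source: String) -> Template<'static, 'static> {
    let env = Box::new(Environment::new());
    let source = source.into_boxed_str();
    // Leak both the environment and the source so the returned template can
    // borrow them for 'static. The leak is intentional: it happens once per
    // template and both values live for the whole process anyway.
    Box::leak(env).template_from_str(Box::leak(source)).unwrap()
}

fn main() {
    let template = leaked_template("Hello {{ name }}!".to_string());
    let rendered = template.render(context! { name => "world" }).unwrap();
    assert_eq!(rendered, "Hello world!");
}

The cost is bounded (one environment and one template source per chat template), which is why the comment in the diff calls them "read-only, static resources".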
4  router/src/infer/v2/mod.rs  Normal file
@ -0,0 +1,4 @@
mod queue;
mod scheduler;

pub(crate) use scheduler::SchedulerV2;
@ -1,10 +1,14 @@
-use crate::infer::InferError;
-use crate::infer::InferStreamResponse;
-use crate::validation::ValidGenerateRequest;
+use crate::infer::{InferError, InferStreamResponse};
+use crate::validation::{
+    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
+};
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
-use text_generation_client::{Batch, Request};
+use text_generation_client::v2::{
+    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+};
+use text_generation_client::ChunksToString;
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
@ -55,7 +59,6 @@ impl Queue {
         Self { queue_sender }
     }

-    /// Append an entry to the queue
     #[instrument(skip_all)]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
@ -278,10 +281,14 @@ impl State {
             batch_requests.push(Request {
                 id,
                 prefill_logprobs: entry.request.decoder_input_details,
-                inputs: entry.request.inputs.clone(),
+                inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
-                parameters: Some(entry.request.parameters.clone()),
-                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+                parameters: Some(NextTokenChooserParameters::from(
+                    entry.request.parameters.clone(),
+                )),
+                stopping_parameters: Some(StoppingCriteriaParameters::from(
+                    entry.request.stopping_parameters.clone(),
+                )),
                 top_n_tokens: entry.request.top_n_tokens,
             });
             // Set batch_time
@ -297,7 +304,7 @@ impl State {

         // Empty batch
         if batch_requests.is_empty() {
-            tracing::debug!("Filterered out all entries");
+            tracing::debug!("Filtered out all entries");
             return None;
         }

@ -350,12 +357,46 @@ enum QueueCommand {
     },
 }

+impl From<ValidParameters> for NextTokenChooserParameters {
+    fn from(value: ValidParameters) -> Self {
+        let (grammar, grammar_type) = match value.grammar {
+            None => (String::new(), GrammarType::None),
+
+            Some(grammar) => match grammar {
+                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
+                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
+            },
+        };
+
+        Self {
+            temperature: value.temperature,
+            top_k: value.top_k,
+            top_p: value.top_p,
+            typical_p: value.typical_p,
+            do_sample: value.do_sample,
+            seed: value.seed,
+            repetition_penalty: value.repetition_penalty,
+            frequency_penalty: value.frequency_penalty,
+            watermark: value.watermark,
+            grammar,
+            grammar_type: grammar_type.into(),
+        }
+    }
+}
+
+impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
+    fn from(value: ValidStoppingParameters) -> Self {
+        Self {
+            max_new_tokens: value.max_new_tokens,
+            stop_sequences: value.stop_sequences,
+            ignore_eos_token: value.ignore_eos_token,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{
-        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
-    };
     use tracing::info_span;

     fn default_entry() -> (
@ -366,11 +407,11 @@ mod tests {

         let entry = Entry {
             request: ValidGenerateRequest {
-                inputs: String::new(),
+                inputs: vec![],
                 input_length: 0,
                 truncate: 0,
                 decoder_input_details: false,
-                parameters: NextTokenChooserParameters {
+                parameters: ValidParameters {
                     temperature: 0.0,
                     top_k: 0,
                     top_p: 0.0,
@ -380,10 +421,9 @@ mod tests {
                     repetition_penalty: 0.0,
                     frequency_penalty: 0.0,
                     watermark: false,
-                    grammar: String::new(),
-                    grammar_type: ProtoGrammarType::None as i32,
+                    grammar: None,
                 },
-                stopping_parameters: StoppingCriteriaParameters {
+                stopping_parameters: ValidStoppingParameters {
                     ignore_eos_token: false,
                     max_new_tokens: 1,
                     stop_sequences: vec![],
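The `From<ValidParameters>` conversion added above flattens the router's typed `Option<ValidGrammar>` into the flat `(grammar, grammar_type)` pair that the protobuf message carries. A standalone sketch of just that mapping, with stand-in enums rather than the real router or generated protobuf types:

// Stand-in types for illustration only; the real definitions live in the
// router's validation module and the generated v2 client code.
#[derive(Debug)]
enum ValidGrammar {
    Json(String),
    Regex(String),
}

#[derive(Debug, PartialEq)]
enum GrammarType {
    None,
    Json,
    Regex,
}

// Flatten the optional, typed grammar into the (payload, tag) pair used on the wire.
fn flatten_grammar(grammar: Option<ValidGrammar>) -> (String, GrammarType) {
    match grammar {
        None => (String::new(), GrammarType::None),
        Some(ValidGrammar::Json(s)) => (s, GrammarType::Json),
        Some(ValidGrammar::Regex(s)) => (s, GrammarType::Regex),
    }
}

fn main() {
    assert_eq!(flatten_grammar(None).1, GrammarType::None);
    assert_eq!(
        flatten_grammar(Some(ValidGrammar::Json("{}".into()))).1,
        GrammarType::Json
    );
    let (payload, tag) = flatten_grammar(Some(ValidGrammar::Regex("[a-z]+".into())));
    assert_eq!((payload.as_str(), tag), ("[a-z]+", GrammarType::Regex));
}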
@ -1,78 +1,46 @@
|
|||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::validation::{Validation, ValidationError};
|
use crate::infer::v2::queue::{Entry, Queue};
|
||||||
use crate::{
|
use crate::infer::{
|
||||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
||||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
|
|
||||||
};
|
};
|
||||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
use crate::validation::ValidGenerateRequest;
|
||||||
use futures::future::try_join_all;
|
use crate::{FinishReason, PrefillToken, Token};
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use serde_json::{json, Map, Value};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
atomic::{AtomicBool, Ordering},
|
atomic::{AtomicBool, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
};
|
};
|
||||||
use text_generation_client::{
|
use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
|
||||||
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
|
use text_generation_client::ClientError;
|
||||||
};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tokio_stream::StreamExt;
|
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
|
||||||
/// Inference struct
|
pub(crate) struct SchedulerV2 {
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Infer {
|
|
||||||
/// Validation
|
|
||||||
validation: Validation,
|
|
||||||
/// Request queue
|
/// Request queue
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
/// Shared state
|
/// Notify batcher on queue appends
|
||||||
shared: Arc<Shared>,
|
batching_task_notifier: Arc<Notify>,
|
||||||
/// Chat template
|
|
||||||
chat_template: Option<ChatTemplate>,
|
|
||||||
/// Inference limit
|
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer shared state
|
impl SchedulerV2 {
|
||||||
struct Shared {
|
|
||||||
/// Batching background Tokio task notifier
|
|
||||||
batching_task: Notify,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Raise a exception (custom function) used in the chat templates
|
|
||||||
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
|
||||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Infer {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
validation: Validation,
|
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
max_concurrent_requests: usize,
|
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
tokenizer_config: HubTokenizerConfig,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
|
||||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||||
let shared = Arc::new(Shared {
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
batching_task: Notify::new(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Spawn batching background task that contains all the inference logic
|
// Spawn batching background task that contains all the inference logic
|
||||||
tokio::spawn(batching_task(
|
tokio::spawn(batching_task(
|
||||||
@ -83,68 +51,31 @@ impl Infer {
|
|||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
queue.clone(),
|
queue.clone(),
|
||||||
shared.clone(),
|
batching_task_notifier.clone(),
|
||||||
generation_health,
|
generation_health,
|
||||||
));
|
));
|
||||||
|
|
||||||
let chat_template = tokenizer_config
|
|
||||||
.chat_template
|
|
||||||
.and_then(|t| match t {
|
|
||||||
ChatTemplateVersions::Single(template) => Some(template),
|
|
||||||
ChatTemplateVersions::Multiple(templates) => templates
|
|
||||||
.into_iter()
|
|
||||||
.find(|t| t.name == "default")
|
|
||||||
.map(|t| t.template),
|
|
||||||
})
|
|
||||||
.map(|t| {
|
|
||||||
// .strip() is not supported in minijinja
|
|
||||||
let t = t.replace(".strip()", " | trim");
|
|
||||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Inference limit with a semaphore
|
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
validation,
|
|
||||||
queue,
|
queue,
|
||||||
shared,
|
batching_task_notifier,
|
||||||
chat_template,
|
}
|
||||||
limit_concurrent_requests: semaphore,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
impl Scheduler for SchedulerV2 {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) async fn generate_stream(
|
fn schedule(
|
||||||
&self,
|
&self,
|
||||||
request: GenerateRequest,
|
request: ValidGenerateRequest,
|
||||||
|
permit: OwnedSemaphorePermit,
|
||||||
) -> Result<GenerateStreamResponse, InferError> {
|
) -> Result<GenerateStreamResponse, InferError> {
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
||||||
let permit = self
|
|
||||||
.clone()
|
|
||||||
.limit_concurrent_requests
|
|
||||||
.try_acquire_owned()
|
|
||||||
.map_err(|err| {
|
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
|
|
||||||
tracing::error!("{err}");
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Validate request
|
|
||||||
let valid_request = self.validation.validate(request).await.map_err(|err| {
|
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
|
||||||
tracing::error!("{err}");
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// MPSC channel to communicate with the background batching task
|
// MPSC channel to communicate with the background batching task
|
||||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
let input_length = valid_request.input_length;
|
let input_length = request.input_length;
|
||||||
|
|
||||||
// Append the request to the queue
|
// Append the request to the queue
|
||||||
self.queue.append(Entry {
|
self.queue.append(Entry {
|
||||||
request: valid_request,
|
request,
|
||||||
response_tx,
|
response_tx,
|
||||||
span: Span::current(),
|
span: Span::current(),
|
||||||
temp_span: None,
|
temp_span: None,
|
||||||
@ -154,7 +85,7 @@ impl Infer {
|
|||||||
|
|
||||||
// Notify the background task that we have a new entry in the queue that needs
|
// Notify the background task that we have a new entry in the queue that needs
|
||||||
// to be batched
|
// to be batched
|
||||||
self.shared.batching_task.notify_one();
|
self.batching_task_notifier.notify_one();
|
||||||
|
|
||||||
// Return stream
|
// Return stream
|
||||||
Ok((
|
Ok((
|
||||||
@ -163,343 +94,6 @@ impl Infer {
|
|||||||
UnboundedReceiverStream::new(response_rx),
|
UnboundedReceiverStream::new(response_rx),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tokenizer the input
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
pub(crate) async fn tokenize(
|
|
||||||
&self,
|
|
||||||
request: GenerateRequest,
|
|
||||||
) -> Result<Option<tokenizers::Encoding>, InferError> {
|
|
||||||
// Tokenize request
|
|
||||||
let inputs = request.inputs;
|
|
||||||
let truncate = request.parameters.truncate;
|
|
||||||
let encoding = self
|
|
||||||
.validation
|
|
||||||
.tokenize(inputs, truncate)
|
|
||||||
.await
|
|
||||||
.map_err(|err| {
|
|
||||||
tracing::error!("Tokenization {err}");
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Return Encoding
|
|
||||||
Ok(encoding.map(|(encoding, _)| encoding))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Apply the chat template to the chat request
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
pub(crate) fn apply_chat_template(
|
|
||||||
&self,
|
|
||||||
messages: Vec<Message>,
|
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
|
||||||
) -> Result<String, InferError> {
|
|
||||||
self.chat_template
|
|
||||||
.as_ref()
|
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
|
||||||
.apply(messages, grammar_with_prompt)
|
|
||||||
.map_err(|e| {
|
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
|
||||||
tracing::error!("{e}");
|
|
||||||
e
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a new request to the queue and return a InferResponse
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
pub(crate) async fn generate(
|
|
||||||
&self,
|
|
||||||
request: GenerateRequest,
|
|
||||||
) -> Result<InferResponse, InferError> {
|
|
||||||
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
|
||||||
|
|
||||||
// Create stream and keep semaphore permit as long as generate lives
|
|
||||||
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
|
|
||||||
|
|
||||||
// Return values
|
|
||||||
let mut result_prefill = Vec::new();
|
|
||||||
let mut result_tokens = Vec::new();
|
|
||||||
let mut result_top_tokens = Vec::new();
|
|
||||||
let mut result_generated_text = None;
|
|
||||||
let mut result_start = None;
|
|
||||||
let mut result_queued = None;
|
|
||||||
|
|
||||||
// Iterate on stream
|
|
||||||
while let Some(response) = stream.next().await {
|
|
||||||
match response? {
|
|
||||||
// Add prefill tokens
|
|
||||||
InferStreamResponse::Prefill(tokens) => {
|
|
||||||
// Create Token objects
|
|
||||||
// We do that here instead of in the Python code as Rust for loops are faster
|
|
||||||
result_prefill = tokens
|
|
||||||
.ids
|
|
||||||
.into_iter()
|
|
||||||
.zip(tokens.logprobs.into_iter())
|
|
||||||
.zip(tokens.texts.into_iter())
|
|
||||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
|
||||||
.collect();
|
|
||||||
}
|
|
||||||
// Push last token
|
|
||||||
InferStreamResponse::Intermediate { token, top_tokens } => {
|
|
||||||
result_tokens.push(token);
|
|
||||||
result_top_tokens.push(top_tokens);
|
|
||||||
}
|
|
||||||
// Final message
|
|
||||||
// Set return values
|
|
||||||
InferStreamResponse::End {
|
|
||||||
token,
|
|
||||||
generated_text,
|
|
||||||
start,
|
|
||||||
queued,
|
|
||||||
top_tokens,
|
|
||||||
} => {
|
|
||||||
result_tokens.push(token);
|
|
||||||
result_top_tokens.push(top_tokens);
|
|
||||||
result_generated_text = Some(generated_text);
|
|
||||||
result_start = Some(start);
|
|
||||||
result_queued = Some(queued)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that we received a `InferStreamResponse::End` message
|
|
||||||
if let (Some(generated_text), Some(queued), Some(start)) =
|
|
||||||
(result_generated_text, result_queued, result_start)
|
|
||||||
{
|
|
||||||
Ok(InferResponse {
|
|
||||||
prefill: result_prefill,
|
|
||||||
_input_length,
|
|
||||||
tokens: result_tokens,
|
|
||||||
generated_text,
|
|
||||||
queued,
|
|
||||||
start,
|
|
||||||
top_tokens: if use_top_tokens {
|
|
||||||
result_top_tokens
|
|
||||||
} else {
|
|
||||||
Vec::new()
|
|
||||||
},
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
let err = InferError::IncompleteGeneration;
|
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
|
||||||
tracing::error!("{err}");
|
|
||||||
Err(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// Add best_of new requests to the queue and return a InferResponse of the sequence with
|
|
||||||
/// the highest log probability per token
|
|
||||||
#[instrument(skip(self, request))]
|
|
||||||
pub(crate) async fn generate_best_of(
|
|
||||||
&self,
|
|
||||||
request: GenerateRequest,
|
|
||||||
best_of: usize,
|
|
||||||
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
|
|
||||||
// validate best_of parameter separately
|
|
||||||
let best_of = self.validation.validate_best_of(best_of)?;
|
|
||||||
|
|
||||||
// create multiple generate requests
|
|
||||||
let mut infer_responses: Vec<InferResponse> =
|
|
||||||
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
|
|
||||||
|
|
||||||
// get the sequence with the highest log probability per token
|
|
||||||
let mut max_index = 0;
|
|
||||||
let mut max_logprob: f32 = f32::MIN;
|
|
||||||
|
|
||||||
for (i, response) in infer_responses.iter().enumerate() {
|
|
||||||
// mean logprobs of the generated tokens
|
|
||||||
let sequence_logprob = response
|
|
||||||
.tokens
|
|
||||||
.iter()
|
|
||||||
.map(|token| token.logprob)
|
|
||||||
.sum::<f32>()
|
|
||||||
/ response.tokens.len() as f32;
|
|
||||||
|
|
||||||
// set best sequence
|
|
||||||
if sequence_logprob > max_logprob {
|
|
||||||
max_index = i;
|
|
||||||
max_logprob = sequence_logprob;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let best_response = infer_responses.remove(max_index);
|
|
||||||
Ok((best_response, infer_responses))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct ChatTemplate {
|
|
||||||
template: Template<'static, 'static>,
|
|
||||||
bos_token: Option<String>,
|
|
||||||
eos_token: Option<String>,
|
|
||||||
use_default_tool_template: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatTemplate {
|
|
||||||
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
|
||||||
let mut env = Box::new(Environment::new());
|
|
||||||
let template_str = template.into_boxed_str();
|
|
||||||
env.add_function("raise_exception", raise_exception);
|
|
||||||
|
|
||||||
// check if contains the tools variable within the template
|
|
||||||
let use_default_tool_template =
|
|
||||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
|
||||||
let template = Box::leak(env)
|
|
||||||
.template_from_str(Box::leak(template_str))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
template,
|
|
||||||
bos_token,
|
|
||||||
eos_token,
|
|
||||||
use_default_tool_template,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply(
|
|
||||||
&self,
|
|
||||||
mut messages: Vec<Message>,
|
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
|
||||||
) -> Result<String, InferError> {
|
|
||||||
if self.use_default_tool_template {
|
|
||||||
if let Some(last_message) = messages.last_mut() {
|
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
|
||||||
last_message.content.push(MessageChunk::Text(Text {
|
|
||||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
|
||||||
|
|
||||||
self.template
|
|
||||||
.render(ChatTemplateInputs {
|
|
||||||
messages,
|
|
||||||
bos_token: self.bos_token.as_deref(),
|
|
||||||
eos_token: self.eos_token.as_deref(),
|
|
||||||
add_generation_prompt: true,
|
|
||||||
tools: None,
|
|
||||||
tools_prompt: None,
|
|
||||||
})
|
|
||||||
.map_err(InferError::TemplateError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ToolGrammar {}
|
|
||||||
|
|
||||||
impl ToolGrammar {
|
|
||||||
pub fn apply(
|
|
||||||
tools: Option<Vec<Tool>>,
|
|
||||||
tool_choice: Option<ToolType>,
|
|
||||||
) -> Result<Option<Tools>, InferError> {
|
|
||||||
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
|
|
||||||
// let tool_prompt = tool_prompt.unwrap_or_default();
|
|
||||||
let tools_to_use = match tool_choice {
|
|
||||||
ToolType::FunctionName(name) => {
|
|
||||||
vec![req_tools
|
|
||||||
.iter()
|
|
||||||
.find(|tool| tool.function.name == *name)
|
|
||||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
|
||||||
.clone()]
|
|
||||||
}
|
|
||||||
ToolType::OneOf => req_tools.to_owned(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// adds the error notification function for LLM feedback if required
|
|
||||||
let mut text_response_properties = Map::new();
|
|
||||||
text_response_properties.insert(
|
|
||||||
"error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"description": "The error or issue to notify"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
text_response_properties.insert(
|
|
||||||
"_name".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"const": "notify_error"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
|
||||||
.iter()
|
|
||||||
.map(|tool| {
|
|
||||||
let func = tool.function.clone();
|
|
||||||
|
|
||||||
// Clone the existing parameters, which are expected to be a JSON object
|
|
||||||
let mut params = if let Value::Object(params) = &func.arguments {
|
|
||||||
params.clone()
|
|
||||||
} else {
|
|
||||||
Map::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Insert the function's description at the top level, outside of properties
|
|
||||||
params.insert(
|
|
||||||
"description".to_string(),
|
|
||||||
Value::String(func.description.clone().unwrap_or_default()),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Ensure 'properties' exists and is an object
|
|
||||||
let properties = params
|
|
||||||
.entry("properties".to_string())
|
|
||||||
.or_insert_with(|| json!({}))
|
|
||||||
.as_object_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Insert the constant for the function name inside 'properties'
|
|
||||||
properties.insert(
|
|
||||||
"_name".to_string(),
|
|
||||||
json!({
|
|
||||||
"type": "string",
|
|
||||||
"const": func.name.clone(),
|
|
||||||
// "description": "The name of the function"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
|
||||||
let required = params
|
|
||||||
.entry("required".to_string())
|
|
||||||
.or_insert_with(|| json!([]))
|
|
||||||
.as_array_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Add 'name' to the 'required' array if it is not already present
|
|
||||||
if !required.iter().any(|r| r == "_name") {
|
|
||||||
required.push(json!("_name"));
|
|
||||||
}
|
|
||||||
|
|
||||||
(func.name, Value::Object(params))
|
|
||||||
})
|
|
||||||
.chain([(
|
|
||||||
"notify_error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"properties": text_response_properties,
|
|
||||||
"required": ["error", "_name"],
|
|
||||||
"type": "object"
|
|
||||||
}),
|
|
||||||
)])
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let tools = Tools {
|
|
||||||
functions_map: FunctionsMap { functions },
|
|
||||||
properties: Properties {
|
|
||||||
function: tools_to_use
|
|
||||||
.iter()
|
|
||||||
.map(|tool| FunctionRef {
|
|
||||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
|
||||||
})
|
|
||||||
.chain(std::iter::once(FunctionRef {
|
|
||||||
ref_path: "#/$functions/notify_error".to_string(),
|
|
||||||
}))
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
return Ok(Some(tools));
|
|
||||||
}
|
|
||||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Batching logic
|
/// Batching logic
|
||||||
@ -507,7 +101,7 @@ impl ToolGrammar {
|
|||||||
///
|
///
|
||||||
/// Batches requests and sends them to the inference server
|
/// Batches requests and sends them to the inference server
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn batching_task(
|
pub(crate) async fn batching_task(
|
||||||
mut client: ShardedClient,
|
mut client: ShardedClient,
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
@ -515,13 +109,13 @@ async fn batching_task(
|
|||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
shared: Arc<Shared>,
|
notifier: Arc<Notify>,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
) {
|
) {
|
||||||
// Infinite loop
|
// Infinite loop
|
||||||
loop {
|
loop {
|
||||||
// Wait for a notification from the Infer struct
|
// Wait for a notification from the Infer struct
|
||||||
shared.batching_task.notified().await;
|
notifier.notified().await;
|
||||||
|
|
||||||
// Get the next batch from the queue
|
// Get the next batch from the queue
|
||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
@ -787,6 +381,16 @@ fn send_responses(
|
|||||||
let mut stopped = false;
|
let mut stopped = false;
|
||||||
|
|
||||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
|
// Create Token objects
|
||||||
|
// We do that here instead of in the Python code as Rust for loops are faster
|
||||||
|
let prefill_tokens = prefill_tokens
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(prefill_tokens.logprobs)
|
||||||
|
.zip(prefill_tokens.texts)
|
||||||
|
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||||
|
.collect();
|
||||||
|
|
||||||
// Send message
|
// Send message
|
||||||
entry
|
entry
|
||||||
.response_tx
|
.response_tx
|
||||||
@ -837,7 +441,7 @@ fn send_responses(
|
|||||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||||
token,
|
token,
|
||||||
top_tokens,
|
top_tokens,
|
||||||
generated_text: generated_text.clone(),
|
generated_text: GeneratedText::from(generated_text.clone()),
|
||||||
queued: entry.queue_time,
|
queued: entry.queue_time,
|
||||||
start: entry.batch_time.unwrap(),
|
start: entry.batch_time.unwrap(),
|
||||||
}))?;
|
}))?;
|
||||||
@ -872,64 +476,21 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
impl From<text_generation_client::v2::GeneratedText> for GeneratedText {
|
||||||
pub(crate) enum InferStreamResponse {
|
fn from(value: text_generation_client::v2::GeneratedText) -> Self {
|
||||||
// Optional first message
|
let v2_finish_reason =
|
||||||
Prefill(Tokens),
|
text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap();
|
||||||
// Intermediate messages
|
let finish_reason = match v2_finish_reason {
|
||||||
Intermediate {
|
text_generation_client::v2::FinishReason::Length => FinishReason::Length,
|
||||||
token: Token,
|
text_generation_client::v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||||
top_tokens: Vec<Token>,
|
text_generation_client::v2::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||||
},
|
};
|
||||||
// Last message
|
|
||||||
End {
|
|
||||||
token: Token,
|
|
||||||
top_tokens: Vec<Token>,
|
|
||||||
generated_text: GeneratedText,
|
|
||||||
start: Instant,
|
|
||||||
queued: Instant,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
Self {
|
||||||
pub(crate) struct InferResponse {
|
text: value.text,
|
||||||
/// input_length is the input as perceived by the rust tokenizer in the
|
generated_tokens: value.generated_tokens,
|
||||||
/// validation pathway. It is redundant with prefill.len() but prefill
|
finish_reason,
|
||||||
/// has data only if the user asked for it. This will always be filled.
|
seed: value.seed,
|
||||||
pub(crate) _input_length: u32,
|
|
||||||
pub(crate) prefill: Vec<PrefillToken>,
|
|
||||||
pub(crate) tokens: Vec<Token>,
|
|
||||||
pub(crate) generated_text: GeneratedText,
|
|
||||||
pub(crate) queued: Instant,
|
|
||||||
pub(crate) start: Instant,
|
|
||||||
pub(crate) top_tokens: Vec<Vec<Token>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum InferError {
|
|
||||||
#[error("Request failed during generation: {0}")]
|
|
||||||
GenerationError(String),
|
|
||||||
#[error("Model is overloaded")]
|
|
||||||
Overloaded(#[from] TryAcquireError),
|
|
||||||
#[error("Input validation error: {0}")]
|
|
||||||
ValidationError(#[from] ValidationError),
|
|
||||||
#[error("Incomplete generation")]
|
|
||||||
IncompleteGeneration,
|
|
||||||
#[error("Template error: {0}")]
|
|
||||||
TemplateError(#[from] minijinja::Error),
|
|
||||||
#[error("Tool error: {0}")]
|
|
||||||
ToolError(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InferError {
|
|
||||||
pub(crate) fn error_type(&self) -> &str {
|
|
||||||
match self {
|
|
||||||
InferError::GenerationError(_) => "generation",
|
|
||||||
InferError::Overloaded(_) => "overloaded",
|
|
||||||
InferError::ValidationError(_) => "validation",
|
|
||||||
InferError::IncompleteGeneration => "incomplete_generation",
|
|
||||||
InferError::TemplateError(_) => "template_error",
|
|
||||||
InferError::ToolError(_) => "tool_error",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
136  router/src/infer/v3/block_allocator.rs  Normal file
@ -0,0 +1,136 @@
use std::cmp::min;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
    pub blocks: Vec<u32>,
    pub slots: Vec<u32>,
    block_allocator: BlockAllocator,
}

impl Drop for BlockAllocation {
    fn drop(&mut self) {
        self.block_allocator.free(self.blocks.clone())
    }
}

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
    /// Channel to communicate with the background task
    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}

impl BlockAllocator {
    pub(crate) fn new(
        max_batch_total_tokens: u32,
        block_size: u32,
        window_size: Option<u32>,
    ) -> Self {
        // Create channel
        let (sender, receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(block_allocator_task(
            max_batch_total_tokens / block_size,
            block_size,
            window_size,
            receiver,
        ));

        Self {
            block_allocator: sender,
        }
    }

    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.block_allocator
            .send(BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            })
            .unwrap();

        response_receiver
            .await
            .unwrap()
            .map(|(blocks, slots)| BlockAllocation {
                blocks,
                slots,
                block_allocator: self.clone(),
            })
    }

    pub(crate) fn free(&self, blocks: Vec<u32>) {
        self.block_allocator
            .send(BlockAllocatorCommand::Free { blocks })
            .unwrap();
    }
}

async fn block_allocator_task(
    blocks: u32,
    block_size: u32,
    window_size: Option<u32>,
    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
    // Block 0 is reserved for health checks
    let mut free_blocks: Vec<u32> = (1..blocks).collect();
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
            BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            } => {
                // Apply window size
                let (required_blocks, repeats) = {
                    let (tokens, repeats) = match window_size {
                        None => (tokens, 1),
                        Some(window_size) => {
                            let repeats = (tokens + window_size - 1) / window_size;
                            let tokens = min(tokens, window_size);
                            (tokens, repeats as usize)
                        }
                    };
                    // Pad to a multiple of block size
                    let required_blocks = (tokens + block_size - 1) / block_size;
                    (required_blocks, repeats)
                };

                let tokens = tokens as usize;
                let allocation = if required_blocks > free_blocks.len() as u32 {
                    None
                } else {
                    let blocks =
                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
                    let mut slots = Vec::with_capacity(
                        (required_blocks * block_size * repeats as u32) as usize,
                    );

                    'slots: for block_id in blocks.repeat(repeats).iter() {
                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
                            slots.push(s);
                            if slots.len() == tokens {
                                break 'slots;
                            }
                        }
                    }
                    Some((blocks, slots))
                };
                response_sender.send(allocation).unwrap();
            }
        }
    }
}

#[derive(Debug)]
enum BlockAllocatorCommand {
    Free {
        blocks: Vec<u32>,
    },
    Allocate {
        tokens: u32,
        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
    },
}
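The sizing logic inside `block_allocator_task` rounds a request's token count up to whole KV-cache blocks and, when a sliding window is configured, reuses the same blocks across window-sized repeats. The following is a small, self-contained sketch of just that arithmetic; the block and window sizes in `main` are illustrative, not TGI defaults.

// How many blocks a request needs, and how many times the sequence wraps over
// them when a sliding window caps the number of distinct cached tokens.
fn required_blocks(tokens: u32, block_size: u32, window_size: Option<u32>) -> (u32, usize) {
    let (tokens, repeats) = match window_size {
        None => (tokens, 1usize),
        Some(window) => {
            // Longer requests cycle through the same window-sized region.
            let repeats = ((tokens + window - 1) / window) as usize;
            (tokens.min(window), repeats)
        }
    };
    // Round up to a whole number of blocks.
    (((tokens + block_size - 1) / block_size), repeats)
}

fn main() {
    // 50 tokens, 16-token blocks, no window: ceil(50 / 16) = 4 blocks, used once.
    assert_eq!(required_blocks(50, 16, None), (4, 1));
    // 50 tokens with a 32-token sliding window: only 32 tokens are cached
    // (2 blocks), and the sequence wraps over them twice.
    assert_eq!(required_blocks(50, 16, Some(32)), (2, 2));
}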
5  router/src/infer/v3/mod.rs  Normal file
@ -0,0 +1,5 @@
mod block_allocator;
mod queue;
mod scheduler;

pub(crate) use scheduler::SchedulerV3;
730  router/src/infer/v3/queue.rs  Normal file
@ -0,0 +1,730 @@
|
|||||||
|
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
||||||
|
use crate::infer::InferError;
|
||||||
|
use crate::infer::InferStreamResponse;
|
||||||
|
use crate::validation::{
|
||||||
|
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||||
|
};
|
||||||
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
|
use std::cmp::{max, min};
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use text_generation_client::v3::{
|
||||||
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
use text_generation_client::ChunksToString;
|
||||||
|
use text_generation_client::Input;
|
||||||
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
|
||||||
|
/// Queue entry
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct Entry {
|
||||||
|
/// Request
|
||||||
|
pub request: ValidGenerateRequest,
|
||||||
|
/// Response sender to communicate between the Infer struct and the batching_task
|
||||||
|
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||||
|
/// Span that will live as long as entry
|
||||||
|
pub span: Span,
|
||||||
|
/// Temporary span used as a guard when logging inference, wait times...
|
||||||
|
pub temp_span: Option<Span>,
|
||||||
|
/// Instant when this entry was queued
|
||||||
|
pub queue_time: Instant,
|
||||||
|
/// Instant when this entry was added to a batch
|
||||||
|
pub batch_time: Option<Instant>,
|
||||||
|
/// Block Allocation
|
||||||
|
pub block_allocation: Option<BlockAllocation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Request Queue
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct Queue {
|
||||||
|
/// Channel to communicate with the background queue task
|
||||||
|
queue_sender: mpsc::UnboundedSender<QueueCommand>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Queue {
|
||||||
|
pub(crate) fn new(
|
||||||
|
requires_padding: bool,
|
||||||
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
) -> Self {
|
||||||
|
// Create channel
|
||||||
|
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
// Launch background queue task
|
||||||
|
tokio::spawn(queue_task(
|
||||||
|
requires_padding,
|
||||||
|
block_size,
|
||||||
|
window_size,
|
||||||
|
speculate,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
queue_receiver,
|
||||||
|
));
|
||||||
|
|
||||||
|
Self { queue_sender }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append an entry to the queue
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) fn append(&self, entry: Entry) {
|
||||||
|
// Send append command to the background task managing the state
|
||||||
|
// Unwrap is safe here
|
||||||
|
self.queue_sender
|
||||||
|
.send(QueueCommand::Append(Box::new(entry), Span::current()))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the next batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub(crate) async fn next_batch(
|
||||||
|
&self,
|
||||||
|
min_size: Option<usize>,
|
||||||
|
max_size: Option<usize>,
|
||||||
|
prefill_token_budget: u32,
|
||||||
|
token_budget: u32,
|
||||||
|
) -> Option<NextBatch> {
|
||||||
|
// Create response channel
|
||||||
|
let (response_sender, response_receiver) = oneshot::channel();
|
||||||
|
// Send next batch command to the background task managing the state
|
||||||
|
// Unwrap is safe here
|
||||||
|
self.queue_sender
|
||||||
|
.send(QueueCommand::NextBatch {
|
||||||
|
min_size,
|
||||||
|
max_size,
|
||||||
|
prefill_token_budget,
|
||||||
|
token_budget,
|
||||||
|
response_sender,
|
||||||
|
span: Span::current(),
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
// Await on response channel
|
||||||
|
// Unwrap is safe here
|
||||||
|
response_receiver.await.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Background task responsible of the queue state
|
||||||
|
async fn queue_task(
|
||||||
|
requires_padding: bool,
|
||||||
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||||
|
) {
|
||||||
|
let mut state = State::new(
|
||||||
|
requires_padding,
|
||||||
|
block_size,
|
||||||
|
window_size,
|
||||||
|
speculate,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
);
|
||||||
|
|
||||||
|
while let Some(cmd) = receiver.recv().await {
|
||||||
|
match cmd {
|
||||||
|
QueueCommand::Append(entry, span) => {
|
||||||
|
span.in_scope(|| state.append(*entry));
|
||||||
|
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||||
|
}
|
||||||
|
QueueCommand::NextBatch {
|
||||||
|
min_size,
|
||||||
|
max_size,
|
||||||
|
prefill_token_budget,
|
||||||
|
token_budget,
|
||||||
|
response_sender,
|
||||||
|
span,
|
||||||
|
} => {
|
||||||
|
let next_batch = state
|
||||||
|
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||||
|
.instrument(span)
|
||||||
|
.await;
|
||||||
|
response_sender.send(next_batch).unwrap();
|
||||||
|
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Queue State
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct State {
|
||||||
|
/// Queue entries organized in a Vec
|
||||||
|
entries: VecDeque<(u64, Entry)>,
|
||||||
|
|
||||||
|
/// Id of the next entry
|
||||||
|
next_id: u64,
|
||||||
|
|
||||||
|
/// Id of the next batch
|
||||||
|
next_batch_id: u64,
|
||||||
|
|
||||||
|
/// Paged Attention block size
|
||||||
|
block_size: u32,
|
||||||
|
|
||||||
|
/// Sliding window
|
||||||
|
window_size: Option<u32>,
|
||||||
|
|
||||||
|
/// Speculation amount
|
||||||
|
speculate: u32,
|
||||||
|
|
||||||
|
/// Paged Attention Block Allocation
|
||||||
|
block_allocator: Option<BlockAllocator>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State {
|
||||||
|
fn new(
|
||||||
|
requires_padding: bool,
|
||||||
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
) -> Self {
|
||||||
|
let block_allocator = (!requires_padding)
|
||||||
|
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
entries: VecDeque::with_capacity(128),
|
||||||
|
next_id: 0,
|
||||||
|
next_batch_id: 0,
|
||||||
|
block_size,
|
||||||
|
window_size,
|
||||||
|
speculate,
|
||||||
|
block_allocator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append an entry to the queue
|
||||||
|
fn append(&mut self, mut entry: Entry) {
|
||||||
|
// Create a span that will live as long as the entry is in the queue waiting to be batched
|
||||||
|
let queue_span = info_span!(parent: &entry.span, "queued");
|
||||||
|
entry.temp_span = Some(queue_span);
|
||||||
|
|
||||||
|
// Push entry in the queue
|
||||||
|
self.entries.push_back((self.next_id, entry));
|
||||||
|
self.next_id += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the next batch
|
||||||
|
async fn next_batch(
|
||||||
|
&mut self,
|
||||||
|
min_size: Option<usize>,
|
||||||
|
max_size: Option<usize>,
|
||||||
|
prefill_token_budget: u32,
|
||||||
|
token_budget: u32,
|
||||||
|
) -> Option<NextBatch> {
|
||||||
|
if self.entries.is_empty() {
|
||||||
|
tracing::debug!("No queue");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have enough entries
|
||||||
|
if let Some(min_size) = min_size {
|
||||||
|
if self.entries.len() < min_size {
|
||||||
|
tracing::debug!("Not enough entries");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pad prefill_token_budget to be a multiple of block size
|
||||||
|
let prefill_token_budget =
|
||||||
|
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
||||||
|
|
||||||
|
// Create span for this batch to add context to inference calls
|
||||||
|
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||||
|
next_batch_span.follows_from(&Span::current());
|
||||||
|
|
||||||
|
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||||
|
let mut batch_entries =
|
||||||
|
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||||
|
|
||||||
|
let mut max_input_length = 0;
|
||||||
|
let mut prefill_tokens: u32 = 0;
|
||||||
|
let mut decode_tokens: u32 = 0;
|
||||||
|
let mut max_blocks = 0;
|
||||||
|
|
||||||
|
// Pop entries starting from the front of the queue
|
||||||
|
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||||
|
// Filter entries where the response receiver was dropped (== entries where the request
|
||||||
|
// was dropped by the client)
|
||||||
|
if entry.response_tx.is_closed() {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||||
|
tracing::debug!("Dropping entry");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let block_allocation = match &self.block_allocator {
|
||||||
|
None => {
|
||||||
|
// We pad to max input length in the Python shards
|
||||||
|
// We need to take these padding tokens into the equation
|
||||||
|
max_input_length = max_input_length.max(entry.request.input_length);
|
||||||
|
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
|
||||||
|
|
||||||
|
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||||
|
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
||||||
|
|
||||||
|
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
|
||||||
|
// Entry is over budget
|
||||||
|
// Add it back to the front
|
||||||
|
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
break 'entry_loop;
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Some(block_allocator) => {
|
||||||
|
prefill_tokens += entry.request.input_length;
|
||||||
|
let max_new_tokens = match self.window_size {
|
||||||
|
None => entry.request.stopping_parameters.max_new_tokens,
|
||||||
|
Some(window_size) => min(
|
||||||
|
window_size.saturating_sub(entry.request.input_length),
|
||||||
|
entry.request.stopping_parameters.max_new_tokens,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
decode_tokens += max_new_tokens;
|
||||||
|
|
||||||
|
if prefill_tokens > prefill_token_budget
|
||||||
|
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||||
|
{
|
||||||
|
// Entry is over budget
|
||||||
|
// Add it back to the front
|
||||||
|
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||||
|
                        self.entries.push_front((id, entry));
                        break;
                    }

                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;

                    match block_allocator.allocate(tokens).await {
                        None => {
                            // Entry is over budget
                            // Add it back to the front
                            tracing::debug!("Over budget: not enough free blocks");
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                        Some(block_allocation) => {
                            tracing::debug!("Allocation: {block_allocation:?}");
                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                            Some(block_allocation)
                        }
                    }
                }
            };

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            let (blocks, slots) = match &block_allocation {
                None => (Vec::new(), Vec::new()),
                Some(block_allocation) => (
                    block_allocation.blocks.clone(),
                    block_allocation.slots.clone(),
                ),
            };

            entry.block_allocation = block_allocation;

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                input_chunks: Some(Input {
                    chunks: entry.request.inputs.clone(),
                }),
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Check if max_size
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filterered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
            max_blocks,
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size", batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),
            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_length: 0,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        };
        (entry, receiver_tx)
    }

    #[tokio::test]
    async fn test_append() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0, 16);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[tokio::test]
    async fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[tokio::test]
    async fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0, 16);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}
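A standalone restatement of the block-allocator sizing used near the top of this excerpt may help when skimming the tests above; this is an editor's sketch, not part of the commit:

// Editor's sketch (not in the diff): mirrors `input_length + max_new_tokens + speculate - 1`
// from the allocation call above; the allocator is asked for enough slots for the prompt plus
// every token the entry may still produce, including speculative ones.
fn tokens_to_allocate(input_length: u32, max_new_tokens: u32, speculate: u32) -> u32 {
    input_length + max_new_tokens + speculate - 1
}

fn main() {
    // e.g. a 12-token prompt, up to 20 new tokens, speculation depth 2 -> 33 slots requested
    assert_eq!(tokens_to_allocate(12, 20, 2), 33);
}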
1184 router/src/infer/v3/scheduler.rs Normal file
File diff suppressed because it is too large

247 router/src/kserve.rs Normal file
@@ -0,0 +1,247 @@
use crate::{
    default_parameters,
    server::{generate_internal, ComputeType},
    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
};
use axum::extract::{Extension, Path};
use axum::response::{IntoResponse, Response};
use axum::Json;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use reqwest::header::HeaderMap;
use reqwest::StatusCode;

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
    pub name: String,
    pub shape: Vec<usize>,
    pub datatype: String,
    pub data: Vec<u8>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceOutput {
    pub id: String,
    pub outputs: Vec<OutputChunk>,
}

#[derive(Debug, Deserialize, ToSchema)]
pub(crate) struct InferenceRequest {
    pub id: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
    pub inputs: Vec<Input>,
    pub outputs: Vec<Output>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Input {
    pub name: String,
    pub shape: Vec<usize>,
    pub datatype: String,
    pub data: Vec<u8>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Output {
    pub name: String,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct LiveResponse {
    pub live: bool,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ReadyResponse {
    pub live: bool,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MetadataServerResponse {
    pub name: String,
    pub version: String,
    pub extensions: Vec<String>,
}

// Routes

#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v2/health/live",
    responses(
        (status = 200, description = "Service is live", body = LiveReponse),
        (status = 404, description = "Service not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let data = LiveResponse { live: true };
    Ok((HeaderMap::new(), Json(data)).into_response())
}

#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v2/health/ready",
    responses(
        (status = 200, description = "Service is ready", body = ReadyResponse),
        (status = 404, description = "Service not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let data = ReadyResponse { live: true };
    Ok((HeaderMap::new(), Json(data)).into_response())
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2",
    responses(
        (status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
        (status = 404, description = "Service not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let data = MetadataServerResponse {
        name: "text-generation-inference".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
        extensions: vec![
            "health".to_string(),
            "models".to_string(),
            "metrics".to_string(),
        ],
    };
    Ok((HeaderMap::new(), Json(data)).into_response())
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}",
    responses(
        (status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_metadata(
    Path((model_name, model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let data = MetadataServerResponse {
        name: model_name,
        version: model_version,
        extensions: vec!["infer".to_string(), "ready".to_string()],
    };
    Ok((HeaderMap::new(), Json(data)).into_response())
}

#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}/infer",
    request_body = Json<InferenceRequest>,
    responses(
        (status = 200, description = "Inference executed successfully", body = InferenceOutput),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_infer(
    infer: Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let id = payload.id.clone();
    let str_inputs = payload
        .inputs
        .iter()
        .map(|input| {
            std::str::from_utf8(&input.data).map_err(|e| {
                (
                    StatusCode::UNPROCESSABLE_ENTITY,
                    Json(ErrorResponse {
                        error: e.to_string(),
                        error_type: "utf8".to_string(),
                    }),
                )
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    if str_inputs.len() != payload.outputs.len() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Inputs and outputs length mismatch".to_string(),
                error_type: "length mismatch".to_string(),
            }),
        ));
    }

    let output_chunks = str_inputs
        .iter()
        .zip(&payload.outputs)
        .map(|(str_input, output)| {
            let generate_request = GenerateRequest {
                inputs: str_input.to_string(),
                parameters: payload.parameters.clone(),
            };
            let infer = infer.clone();
            let compute_type = compute_type.clone();
            let span = tracing::Span::current();
            async move {
                generate_internal(infer, compute_type, Json(generate_request), span)
                    .await
                    .map(|(_, Json(generation))| {
                        let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
                        OutputChunk {
                            name: output.name.clone(),
                            shape: vec![1, generation_as_bytes.len()],
                            datatype: "BYTES".to_string(),
                            data: generation_as_bytes,
                        }
                    })
                    .map_err(|_| {
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(ErrorResponse {
                                error: "Incomplete generation".into(),
                                error_type: "Incomplete generation".into(),
                            }),
                        )
                    })
            }
        })
        .collect::<FuturesUnordered<_>>()
        .try_collect::<Vec<_>>()
        .await?;

    let inference_output = InferenceOutput {
        id: id.clone(),
        outputs: output_chunks,
    };

    Ok((HeaderMap::new(), Json(inference_output)).into_response())
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/v2/models/{model_name}/versions/{model_version}/ready",
    responses(
        (status = 200, description = "Model version is ready", body = ReadyResponse),
        (status = 404, description = "Model or version not found", body = ErrorResponse,
        example = json!({"error": "No response"}))
    )
)]
pub async fn kserve_model_metadata_ready(
    Path((_model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let data = ReadyResponse { live: true };
    Ok((HeaderMap::new(), Json(data)).into_response())
}
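For orientation, a hypothetical request body for the new infer route follows; it is an editor's sketch based on the InferenceRequest/Input/Output structs above, and none of the values come from the commit itself:

// Editor's sketch (not in the diff): builds a JSON body for
// POST /v2/models/{model_name}/versions/{model_version}/infer.
// The handler above decodes each `data` field as UTF-8 and requires one `outputs`
// entry per input; `shape` and `datatype` are carried through but not validated there.
use serde_json::json;

fn main() {
    let body = json!({
        "id": "example-request-1",
        "inputs": [{
            "name": "input-0",
            "shape": [1],
            "datatype": "BYTES",
            // Vec<u8> deserializes from a plain JSON array of byte values
            "data": b"What is Deep Learning?".to_vec()
        }],
        "outputs": [{ "name": "output-0" }]
        // "parameters" may be omitted; it falls back to default_parameters()
    });
    println!("{body}");
}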
@@ -1,27 +1,17 @@
-pub mod config;
-mod health;
 /// Text Generation Inference Webserver
+pub mod config;
 mod infer;
-mod queue;
 pub mod server;
 mod validation;
 
-use infer::{Infer, InferError, InferStreamResponse};
-use queue::{Entry, Queue};
+#[cfg(feature = "kserve")]
+mod kserve;
+
 use serde::{Deserialize, Serialize};
-use tokio::sync::OwnedSemaphorePermit;
-use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;
 
-/// Type alias for generation responses
-pub(crate) type GenerateStreamResponse = (
-    OwnedSemaphorePermit,
-    u32, // input_length
-    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-);
-
 #[derive(Clone, Deserialize, ToSchema)]
 pub(crate) struct VertexInstance {
     #[schema(example = "What is Deep Learning?")]
@@ -80,6 +70,20 @@ impl HubTokenizerConfig {
     }
 }
 
+#[derive(Debug, Clone, Deserialize, Default)]
+pub struct HubProcessorConfig {
+    pub chat_template: Option<ChatTemplateVersions>,
+    pub image_seq_len: usize,
+    pub processor_class: Option<String>,
+}
+
+impl HubProcessorConfig {
+    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
+        let content = std::fs::read_to_string(filename).ok()?;
+        serde_json::from_str(&content).ok()
+    }
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
@@ -88,6 +92,7 @@ pub(crate) enum GrammarType {
     /// JSON Schema is a declarative language that allows to annotate JSON documents
     /// with types and descriptions.
     #[serde(rename = "json")]
+    #[serde(alias = "json_object")]
     #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
     Json(serde_json::Value),
     #[serde(rename = "regex")]
@@ -144,7 +149,7 @@ pub struct Info {
     #[schema(example = "4")]
     pub max_stop_sequences: usize,
     #[schema(example = "1024")]
-    pub max_input_length: usize,
+    pub max_input_tokens: usize,
     #[schema(example = "2048")]
     pub max_total_tokens: usize,
     #[schema(example = "1.2")]
@@ -402,6 +407,11 @@ pub struct CompletionRequest {
     #[serde(default)]
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
+
+    /// Up to 4 sequences where the API will stop generating further tokens.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stop: Option<Vec<String>>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
@@ -785,6 +795,13 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
     pub tool_choice: Option<ToolType>,
+
+    /// Response format constraints for the generation.
+    ///
+    /// NOTE: A request can use `response_format` OR `tools` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub response_format: Option<GrammarType>,
 }
 
 fn default_tool_prompt() -> Option<String> {
@@ -1068,7 +1085,7 @@ pub struct SimpleToken {
     stop: usize,
 }
 
-#[derive(Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
 #[schema(example = "Length")]
 pub(crate) enum FinishReason {
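Since the `json_object` alias above is easy to miss, here is an editor's sketch of the two tag spellings it should make equivalent on the wire (the schema value reuses the example from the attribute; this snippet is not part of the commit):

// Editor's sketch (not in the diff): with #[serde(tag = "type", content = "value")],
// rename = "json" and the new alias = "json_object", both payloads below name the
// same GrammarType::Json variant when deserialized.
use serde_json::json;

fn main() {
    let canonical = json!({"type": "json", "value": {"properties": {"location": {"type": "string"}}}});
    let aliased = json!({"type": "json_object", "value": {"properties": {"location": {"type": "string"}}}});
    println!("{canonical}\n{aliased}");
}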
@@ -12,15 +12,14 @@ use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
-use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::config::Config;
-use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
+use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
-use tracing_subscriber::{EnvFilter, Layer};
+use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -206,11 +205,18 @@ async fn main() -> Result<(), RouterError> {
     };
 
     // Load tokenizer and model info
-    let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
+    let (
+        tokenizer_filename,
+        config_filename,
+        tokenizer_config_filename,
+        processor_config_filename,
+        model_info,
+    ) = match api {
         Type::None => (
             Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
+            Some(local_path.join("processor_config.json")),
             None,
         ),
         Type::Api(api) => {
@@ -226,6 +232,7 @@ async fn main() -> Result<(), RouterError> {
             };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+            let processor_config_filename = api_repo.get("processor_config.json").await.ok();
 
             let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
                 Some(model_info)
@@ -237,6 +244,7 @@ async fn main() -> Result<(), RouterError> {
                 tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
+                processor_config_filename,
                 model_info,
             )
         }
@@ -250,6 +258,7 @@ async fn main() -> Result<(), RouterError> {
                 repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
+                repo.get("processor_config.json"),
                 None,
             )
         }
@@ -286,6 +295,10 @@ async fn main() -> Result<(), RouterError> {
         HubTokenizerConfig::default()
     });
 
+    let processor_config = processor_config_filename
+        .and_then(HubProcessorConfig::from_file)
+        .unwrap_or_default();
+
     tracing::info!("Using config {config:?}");
     if tokenizer.is_none() {
         tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
@@ -301,59 +314,6 @@ async fn main() -> Result<(), RouterError> {
         Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
     };
 
-    // Instantiate sharded client from the master unix socket
-    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
-        .await
-        .map_err(RouterError::Connection)?;
-    // Clear the cache; useful if the webserver rebooted
-    sharded_client
-        .clear_cache(None)
-        .await
-        .map_err(RouterError::Cache)?;
-    // Get info from the shard
-    let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
-
-    // Warmup model
-    tracing::info!("Warming up model");
-    let max_supported_batch_total_tokens = match sharded_client
-        .warmup(
-            max_input_tokens as u32,
-            max_batch_prefill_tokens,
-            max_total_tokens as u32,
-            max_batch_size,
-        )
-        .await
-        .map_err(RouterError::Warmup)?
-    {
-        // Older models do not support automatic max-batch-total-tokens
-        None => {
-            let max_batch_total_tokens = max_batch_total_tokens
-                .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
-            tracing::warn!("Model does not support automatic max batch total tokens");
-            max_batch_total_tokens
-        }
-        // Flash attention models return their max supported total tokens
-        Some(max_supported_batch_total_tokens) => {
-            // Warn if user added his own max-batch-total-tokens as we will ignore it
-            if max_batch_total_tokens.is_some() {
-                tracing::warn!(
-                    "`--max-batch-total-tokens` is deprecated for Flash \
-                    Attention models."
-                );
-                tracing::warn!(
-                    "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
-                );
-            }
-            if max_total_tokens as u32 > max_supported_batch_total_tokens {
-                return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
-            }
-
-            max_supported_batch_total_tokens
-        }
-    };
-    tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
-    tracing::info!("Connected");
-
     // Determine the server port based on the feature and environment variable.
     let port = if cfg!(feature = "google") {
         std::env::var("AIP_HTTP_PORT")
@@ -373,8 +333,8 @@ async fn main() -> Result<(), RouterError> {
 
     // Run server
     server::run(
+        master_shard_uds_path,
         model_info,
-        shard_info,
         compat_return_full_text,
         max_concurrent_requests,
         max_best_of,
@@ -384,10 +344,9 @@ async fn main() -> Result<(), RouterError> {
         max_total_tokens,
         waiting_served_ratio,
         max_batch_prefill_tokens,
-        max_supported_batch_total_tokens,
+        max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
-        sharded_client,
         tokenizer,
         config,
         validation_workers,
@@ -397,6 +356,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
+        processor_config,
         messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
@@ -454,8 +414,21 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
     }
 
     // Filter events with LOG_LEVEL
-    let env_filter =
-        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
+    let varname = "LOG_LEVEL";
+    let env_filter = if let Ok(log_level) = std::env::var(varname) {
+        // Override to avoid simple logs to be spammed with tokio level informations
+        let log_level = match &log_level[..] {
+            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
+            "info" => "text_generation_launcher=info,text_generation_router=info",
+            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
+            log_level => log_level,
+        };
+        EnvFilter::builder()
+            .with_default_directive(LevelFilter::INFO.into())
+            .parse_lossy(log_level)
+    } else {
+        EnvFilter::new("info")
+    };
 
     tracing_subscriber::registry()
         .with(env_filter)
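A compact restatement of the new LOG_LEVEL handling, as an editor's sketch (not part of the commit; directive strings are taken verbatim from the hunk above):

// Editor's sketch (not in the diff): bare "info"/"warn"/"debug" values of LOG_LEVEL are
// expanded into crate-scoped tracing directives; anything else is passed to EnvFilter as-is.
fn expand_log_level(log_level: &str) -> &str {
    match log_level {
        "warn" => "text_generation_launcher=warn,text_generation_router=warn",
        "info" => "text_generation_launcher=info,text_generation_router=info",
        "debug" => "text_generation_launcher=debug,text_generation_router=debug",
        other => other,
    }
}

fn main() {
    // Custom directives (e.g. "my_crate=trace") are forwarded unchanged.
    assert_eq!(expand_log_level("my_crate=trace"), "my_crate=trace");
}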
@@ -529,16 +502,8 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
 enum RouterError {
     #[error("Argument validation error: {0}")]
     ArgumentValidation(String),
-    #[error("Unable to connect to the Python model shards: {0}")]
-    Connection(ClientError),
-    #[error("Unable to clear the Python model shards cache: {0}")]
-    Cache(ClientError),
-    #[error("Unable to get the Python model shards info: {0}")]
-    Info(ClientError),
-    #[error("Unable to warmup the Python model shards: {0}")]
-    Warmup(ClientError),
+    #[error("WebServer error: {0}")]
+    WebServer(#[from] server::WebServerError),
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
-    #[error("Axum webserver failed: {0}")]
-    Axum(#[from] axum::BoxError),
 }
@@ -1,13 +1,20 @@
-use crate::config::Config;
 /// HTTP Server logic
-use crate::health::Health;
-use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
+use crate::config::Config;
+use crate::infer::v2::SchedulerV2;
+use crate::infer::v3::SchedulerV3;
+use crate::infer::{HealthCheck, Scheduler};
+use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+#[cfg(feature = "kserve")]
+use crate::kserve::{
+    kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
+    kserve_model_metadata, kserve_model_metadata_ready,
+};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
-    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
-    Validation,
+    GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
+    Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
+    Usage, Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -34,7 +41,8 @@ use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
-use text_generation_client::{ShardInfo, ShardedClient};
+use text_generation_client::{v2, v3, ClientError, ShardInfo};
+use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
@@ -115,7 +123,9 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
 )]
 #[instrument(skip(health))]
 /// Health check method
-async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+async fn health(
+    mut health: Extension<HealthCheck>,
+) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     match health.check().await {
         true => Ok(()),
         false => Err((
@@ -167,7 +177,7 @@ async fn generate(
     generate_internal(infer, ComputeType(compute_type), Json(req), span).await
 }
 
-async fn generate_internal(
+pub(crate) async fn generate_internal(
     infer: Extension<Infer>,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
@@ -213,9 +223,7 @@ async fn generate_internal(
 
                     BestOfSequence {
                         generated_text: output_text,
-                        finish_reason: FinishReason::from(
-                            response.generated_text.finish_reason,
-                        ),
+                        finish_reason: response.generated_text.finish_reason,
                         generated_tokens: response.generated_text.generated_tokens,
                         prefill: response.prefill,
                         tokens: response.tokens,
@@ -227,7 +235,7 @@ async fn generate_internal(
                 });
 
             Some(Details {
-                finish_reason: FinishReason::from(response.generated_text.finish_reason),
+                finish_reason: response.generated_text.finish_reason,
                 generated_tokens: response.generated_text.generated_tokens,
                 prefill: response.prefill,
                 tokens: response.tokens,
@@ -468,7 +476,7 @@ async fn generate_stream_internal(
                         // Token details
                         let details = match details {
                             true => Some(StreamDetails {
-                                finish_reason: FinishReason::from(generated_text.finish_reason),
+                                finish_reason: generated_text.finish_reason,
                                 generated_tokens: generated_text.generated_tokens,
                                 seed: generated_text.seed,
                             }),
@@ -597,9 +605,22 @@ async fn completions(
     let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
 
-    let stream = req.stream;
-    let max_new_tokens = req.max_tokens.or(Some(100));
-    let seed = req.seed;
+    let CompletionRequest {
+        max_tokens,
+        seed,
+        stop,
+        stream,
+        temperature,
+        ..
+    } = req;
+
+    let max_new_tokens = max_tokens.or(Some(100));
+    let stop = stop.unwrap_or_default();
+    // enable greedy only when temperature is 0
+    let (do_sample, temperature) = match temperature {
+        Some(temperature) if temperature == 0.0 => (false, None),
+        other => (true, other),
+    };
 
     // if suffix is present throw an error
     if req.suffix.is_some() {
@@ -635,16 +656,16 @@ async fn completions(
             inputs: prompt.to_string(),
             parameters: GenerateParameters {
                 best_of: None,
-                temperature: req.temperature,
+                temperature,
                 repetition_penalty: req.repetition_penalty,
                 frequency_penalty: req.frequency_penalty,
                 top_k: None,
                 top_p: req.top_p,
                 typical_p: None,
-                do_sample: true,
+                do_sample,
                 max_new_tokens,
                 return_full_text: None,
-                stop: Vec::new(),
+                stop: stop.clone(),
                 truncate: None,
                 watermark: false,
                 details: true,
@@ -1000,6 +1021,7 @@ async fn chat_completions(
         tool_choice,
         tool_prompt,
         temperature,
+        response_format,
         ..
     } = req;
 
@@ -1014,6 +1036,18 @@ async fn chat_completions(
         other => (true, other),
     };
 
+    // response_format and tools are mutually exclusive
+    if response_format.is_some() && tools.as_ref().is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Grammar and tools are mutually exclusive".to_string(),
+                error_type: "grammar and tools".to_string(),
+            }),
+        ));
+    }
+
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
         Ok(grammar) => grammar,
@@ -1030,16 +1064,21 @@ async fn chat_completions(
         }
     };
 
-    let grammar_with_prompt = tool_grammar
+    // determine the appropriate arguments for apply_chat_template
+    let tools_grammar_prompt = tool_grammar
         .as_ref()
         .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
 
-    let typed_grammar = grammar_with_prompt
-        .as_ref()
-        .map(|(grammar, _)| grammar.clone());
+    let (tools_grammar_prompt, grammar) = match response_format {
+        Some(response_format) => (None, Some(response_format)),
+        None => (
+            tools_grammar_prompt.clone(),
+            tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
+        ),
+    };
 
     // apply chat template to flatten the request into a single input
-    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
+    let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1075,7 +1114,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: req.top_logprobs,
-            grammar: typed_grammar,
+            grammar,
         },
     };
 
@@ -1320,7 +1359,8 @@ async fn tokenize(
             .iter()
             .zip(encoding.get_offsets())
             .map(|(&id, &(start, stop))| {
-                let text: String = input.chars().skip(start).take(stop - start).collect();
+                let text: String =
+                    String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
                 SimpleToken {
                     id,
                     text,
@@ -1358,34 +1398,34 @@ pub(crate) struct ComputeType(String);
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
+    master_shard_uds_path: String,
     model_info: HubModelInfo,
-    shard_info: ShardInfo,
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_best_of: usize,
     max_stop_sequences: usize,
     max_top_n_tokens: u32,
-    max_input_length: usize,
+    max_input_tokens: usize,
     max_total_tokens: usize,
     waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
-    max_batch_total_tokens: u32,
+    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
-    client: ShardedClient,
     tokenizer: Option<Tokenizer>,
     config: Option<Config>,
     validation_workers: usize,
     addr: SocketAddr,
     allow_origin: Option<AllowOrigin>,
     ngrok: bool,
-    ngrok_authtoken: Option<String>,
-    ngrok_edge: Option<String>,
+    _ngrok_authtoken: Option<String>,
+    _ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
+    processor_config: HubProcessorConfig,
     messages_api_enabled: bool,
     grammar_support: bool,
     max_client_batch_size: usize,
-) -> Result<(), axum::BoxError> {
+) -> Result<(), WebServerError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
@@ -1455,6 +1495,141 @@ pub async fn run(
     struct ApiDoc;
 
     // Create state
+
+    // Open connection, get model info and warmup
+    let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
+        Arc<dyn Scheduler + Send + Sync>,
+        HealthCheck,
+        ShardInfo,
+        u32,
+    ) = {
+        // Helper function to check both v2 and v3
+        let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
+            match max_supported_batch_total_tokens {
+                // Older models do not support automatic max-batch-total-tokens
+                None => {
+                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
+                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
+                    );
+                    tracing::warn!("Model does not support automatic max batch total tokens");
+                    Ok(max_batch_total_tokens)
+                }
+                // Flash attention models return their max supported total tokens
+                Some(max_supported_batch_total_tokens) => {
+                    // Warn if user added his own max-batch-total-tokens as we will ignore it
+                    if max_batch_total_tokens.is_some() {
+                        tracing::warn!(
+                            "`--max-batch-total-tokens` is deprecated for Flash \
+                            Attention models."
+                        );
+                        tracing::warn!(
+                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
+                        );
+                    }
+                    if max_total_tokens as u32 > max_supported_batch_total_tokens {
+                        return Err(WebServerError::NotEnoughMemory(max_total_tokens));
+                    }
+
+                    Ok(max_supported_batch_total_tokens)
+                }
+            }
+        };
+
+        let generation_health = Arc::new(AtomicBool::new(false));
+
+        match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
+            Ok(mut sharded_client) => {
+                // server is running on v3
+                // Clear the cache; useful if the webserver rebooted
+                sharded_client
+                    .clear_cache(None)
+                    .await
+                    .map_err(WebServerError::Cache)?;
+                // Get info from the shard
+                let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
+
+                // Warmup model
+                tracing::info!("Warming up model");
+                let max_batch_total_tokens = check_max_batch_total_tokens(
+                    sharded_client
+                        .warmup(
+                            max_input_tokens as u32,
+                            max_batch_prefill_tokens,
+                            max_total_tokens as u32,
+                            max_batch_size,
+                        )
+                        .await
+                        .map_err(WebServerError::Warmup)?,
+                )?;
+
+                let health_ext =
+                    HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
+                let scheduler = Arc::new(SchedulerV3::new(
+                    sharded_client,
+                    waiting_served_ratio,
+                    max_batch_prefill_tokens,
+                    max_batch_total_tokens,
+                    max_waiting_tokens,
+                    max_batch_size,
+                    shard_info.requires_padding,
+                    shard_info.window_size,
+                    shard_info.speculate,
+                    generation_health,
+                ));
+                tracing::info!("Using scheduler V3");
+
+                (scheduler, health_ext, shard_info, max_batch_total_tokens)
+            }
+            Err(_) => {
+                let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
+                    .await
+                    .map_err(WebServerError::Connection)?;
+
+                // server is running on v2
+                // Clear the cache; useful if the webserver rebooted
+                sharded_client
+                    .clear_cache(None)
+                    .await
+                    .map_err(WebServerError::Cache)?;
+                // Get info from the shard
+                let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
+
+                // Warmup model
+                tracing::info!("Warming up model");
+                let max_batch_total_tokens = check_max_batch_total_tokens(
+                    sharded_client
+                        .warmup(
+                            max_input_tokens as u32,
+                            max_batch_prefill_tokens,
+                            max_total_tokens as u32,
+                            max_batch_size,
+                        )
+                        .await
+                        .map_err(WebServerError::Warmup)?,
+                )?;
+
+                let health_ext =
+                    HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
+                let scheduler = Arc::new(SchedulerV2::new(
+                    sharded_client,
+                    waiting_served_ratio,
+                    max_batch_prefill_tokens,
+                    max_batch_total_tokens,
+                    max_waiting_tokens,
+                    max_batch_size,
+                    shard_info.requires_padding,
+                    shard_info.window_size,
+                    shard_info.speculate,
+                    generation_health,
+                ));
+                tracing::info!("Using scheduler V2");
+
+                (scheduler, health_ext, shard_info, max_batch_total_tokens)
+            }
+        }
+    };
+    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
+
     let validation = Validation::new(
         validation_workers,
         tokenizer,
@@ -1462,26 +1637,17 @@ pub async fn run(
         max_best_of,
         max_stop_sequences,
         max_top_n_tokens,
-        max_input_length,
+        max_input_tokens,
         max_total_tokens,
         grammar_support,
     );
-    let generation_health = Arc::new(AtomicBool::new(false));
-    let health_ext = Health::new(client.clone(), generation_health.clone());
+
     let infer = Infer::new(
-        client,
+        scheduler,
         validation,
-        waiting_served_ratio,
-        max_batch_prefill_tokens,
-        max_batch_total_tokens,
-        max_waiting_tokens,
-        max_batch_size,
         max_concurrent_requests,
-        shard_info.requires_padding,
-        shard_info.window_size,
-        shard_info.speculate,
-        generation_health,
        tokenizer_config,
+        processor_config,
     );
 
     // Duration buckets
@ -1498,7 +1664,7 @@ pub async fn run(
|
|||||||
// Input Length buckets
|
// Input Length buckets
|
||||||
let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
|
let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
|
||||||
let input_length_buckets: Vec<f64> = (0..100)
|
let input_length_buckets: Vec<f64> = (0..100)
|
||||||
.map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64)
|
.map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64)
|
||||||
.collect();
|
.collect();
|
||||||
// Generated tokens buckets
|
// Generated tokens buckets
|
||||||
let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
|
let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
|
||||||
@ -1552,7 +1718,7 @@ pub async fn run(
|
|||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
max_best_of,
|
max_best_of,
|
||||||
max_stop_sequences,
|
max_stop_sequences,
|
||||||
max_input_length,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
waiting_served_ratio,
|
waiting_served_ratio,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
@ -1566,9 +1732,9 @@ pub async fn run(
|
|||||||
docker_label: option_env!("DOCKER_LABEL"),
|
docker_label: option_env!("DOCKER_LABEL"),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
#[allow(unused_mut)] // mut is needed for conditional compilation
|
||||||
let doc = {
|
let mut doc = ApiDoc::openapi();
|
||||||
// avoid `mut` if possible
|
|
||||||
#[cfg(feature = "google")]
|
#[cfg(feature = "google")]
|
||||||
{
|
{
|
||||||
use crate::VertexInstance;
|
use crate::VertexInstance;
|
||||||
@ -1578,16 +1744,46 @@ pub async fn run(
|
|||||||
paths(vertex_compatibility),
|
paths(vertex_compatibility),
|
||||||
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
||||||
)]
|
)]
|
||||||
struct VertextApiDoc;
|
struct VertexApiDoc;
|
||||||
|
|
||||||
// limiting mutability to the smallest scope necessary
|
doc.merge(VertexApiDoc::openapi());
|
||||||
let mut doc = ApiDoc::openapi();
|
|
||||||
doc.merge(VertextApiDoc::openapi());
|
|
||||||
doc
|
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "google"))]
|
|
||||||
ApiDoc::openapi()
|
#[cfg(feature = "kserve")]
|
||||||
|
{
|
||||||
|
use crate::kserve::{
|
||||||
|
InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
|
||||||
|
ReadyResponse,
|
||||||
};
|
};
|
||||||
|
use crate::kserve::{
|
||||||
|
__path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,
|
||||||
|
__path_kserve_model_infer, __path_kserve_model_metadata,
|
||||||
|
__path_kserve_model_metadata_ready,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(OpenApi)]
|
||||||
|
#[openapi(
|
||||||
|
paths(
|
||||||
|
kserve_model_infer,
|
||||||
|
kserve_health_live,
|
||||||
|
kserve_health_ready,
|
||||||
|
kerve_server_metadata,
|
||||||
|
kserve_model_metadata,
|
||||||
|
kserve_model_metadata_ready,
|
||||||
|
),
|
||||||
|
components(schemas(
|
||||||
|
InferenceOutput,
|
||||||
|
InferenceRequest,
|
||||||
|
LiveResponse,
|
||||||
|
MetadataServerResponse,
|
||||||
|
OutputChunk,
|
||||||
|
ReadyResponse,
|
||||||
|
))
|
||||||
|
)]
|
||||||
|
struct KServeApiDoc;
|
||||||
|
|
||||||
|
doc.merge(KServeApiDoc::openapi());
|
||||||
|
}
|
||||||
|
|
||||||
// Configure Swagger UI
|
// Configure Swagger UI
|
||||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||||
@ -1637,6 +1833,27 @@ pub async fn run(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "kserve")]
|
||||||
|
{
|
||||||
|
tracing::info!("Built with `kserve` feature");
|
||||||
|
app = app
|
||||||
|
.route(
|
||||||
|
"/v2/models/:model_name/versions/:model_version/infer",
|
||||||
|
post(kserve_model_infer),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/v2/models/:model_name/versions/:model_version",
|
||||||
|
get(kserve_model_metadata),
|
||||||
|
)
|
||||||
|
.route("/v2/health/ready", get(kserve_health_ready))
|
||||||
|
.route("/v2/health/live", get(kserve_health_live))
|
||||||
|
.route("/v2", get(kerve_server_metadata))
|
||||||
|
.route(
|
||||||
|
"/v2/models/:model_name/versions/:model_version/ready",
|
||||||
|
get(kserve_model_metadata_ready),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// add layers after routes
|
// add layers after routes
|
||||||
app = app
|
app = app
|
||||||
.layer(Extension(info))
|
.layer(Extension(info))
|
||||||
@ -1648,49 +1865,14 @@ pub async fn run(
|
|||||||
.layer(OtelAxumLayer::default())
|
.layer(OtelAxumLayer::default())
|
||||||
.layer(cors_layer);
|
.layer(cors_layer);
|
||||||
|
|
||||||
|
tracing::info!("Connected");
|
||||||
|
|
||||||
if ngrok {
|
if ngrok {
|
||||||
#[cfg(feature = "ngrok")]
|
#[cfg(feature = "ngrok")]
|
||||||
{
|
{
|
||||||
use ngrok::config::TunnelBuilder;
|
panic!("ngrok feature is not functional with axum=0.7 and hyper=1, waiting on https://github.com/ngrok/ngrok-rust/pull/137/files to re-enable.");
|
||||||
|
|
||||||
let _ = addr;
|
|
||||||
|
|
||||||
let authtoken =
|
|
||||||
ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
|
|
||||||
|
|
||||||
let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");
|
|
||||||
|
|
||||||
let tunnel = ngrok::Session::builder()
|
|
||||||
.authtoken(authtoken)
|
|
||||||
.connect()
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.labeled_tunnel()
|
|
||||||
.label("edge", edge);
|
|
||||||
|
|
||||||
let listener = tunnel.listen().await.unwrap();
|
|
||||||
|
|
||||||
// Run prom metrics and health locally too
|
|
||||||
tokio::spawn(
|
|
||||||
axum::Server::bind(&addr)
|
|
||||||
.serve(
|
|
||||||
Router::new()
|
|
||||||
.route("/health", get(health))
|
|
||||||
.route("/metrics", get(metrics))
|
|
||||||
.layer(Extension(health_ext))
|
|
||||||
.layer(Extension(prom_handle))
|
|
||||||
.into_make_service(),
|
|
||||||
)
|
|
||||||
//Wait until all requests are finished to shut down
|
|
||||||
.with_graceful_shutdown(shutdown_signal()),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
axum::Server::builder(listener)
|
|
||||||
.serve(app.into_make_service())
|
|
||||||
//Wait until all requests are finished to shut down
|
|
||||||
.with_graceful_shutdown(shutdown_signal())
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "ngrok"))]
|
#[cfg(not(feature = "ngrok"))]
|
||||||
{
|
{
|
||||||
@ -1703,11 +1885,12 @@ pub async fn run(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Run server
|
// Run server
|
||||||
axum::Server::bind(&addr)
|
|
||||||
.serve(app.into_make_service())
|
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||||
// Wait until all requests are finished to shut down
|
axum::serve(listener, app)
|
||||||
.with_graceful_shutdown(shutdown_signal())
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(|err| WebServerError::Axum(Box::new(err)))?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1740,17 +1923,6 @@ async fn shutdown_signal() {
|
|||||||
opentelemetry::global::shutdown_tracer_provider();
|
opentelemetry::global::shutdown_tracer_provider();
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<i32> for FinishReason {
|
|
||||||
fn from(finish_reason: i32) -> Self {
|
|
||||||
let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
|
|
||||||
match finish_reason {
|
|
||||||
text_generation_client::FinishReason::Length => FinishReason::Length,
|
|
||||||
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
|
||||||
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert to Axum supported formats
|
/// Convert to Axum supported formats
|
||||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
fn from(err: InferError) -> Self {
|
fn from(err: InferError) -> Self {
|
||||||
@ -1783,3 +1955,19 @@ impl From<InferError> for Event {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum WebServerError {
|
||||||
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
|
Connection(ClientError),
|
||||||
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
|
Cache(ClientError),
|
||||||
|
#[error("Unable to get the Python model shards info: {0}")]
|
||||||
|
Info(ClientError),
|
||||||
|
#[error("Unable to warmup the Python model shards: {0}")]
|
||||||
|
Warmup(ClientError),
|
||||||
|
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||||
|
NotEnoughMemory(usize),
|
||||||
|
#[error("Axum error: {0}")]
|
||||||
|
Axum(#[from] axum::BoxError),
|
||||||
|
}
|
||||||
|
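Since the hunk above also swaps `axum::Server::bind` for the axum 0.7 combination of `tokio::net::TcpListener` and `axum::serve`, a compact, self-contained sketch of that serve-with-graceful-shutdown pattern may help; it assumes axum 0.7 and tokio with its full feature set, and the route and address below are placeholders rather than TGI's.

use axum::{routing::get, Router};

async fn shutdown_signal() {
    // Resolve on Ctrl+C; the real router also listens for SIGTERM on Unix.
    tokio::signal::ctrl_c()
        .await
        .expect("failed to install Ctrl+C handler");
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder app; TGI registers its generate/chat/metrics routes here.
    let app = Router::new().route("/health", get(|| async { "ok" }));

    // axum 0.7: bind a tokio listener, then serve the router on it.
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}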
@ -1,19 +1,16 @@
-use crate::config::Config;
/// Payload validation logic
+use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use base64::{engine::general_purpose::STANDARD, Engine};
+use image::{io::Reader as ImageReader, ImageFormat};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
-use text_generation_client::{
-    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
-};
+use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
-// use tokenizers::TruncationDirection;
-use base64::{engine::general_purpose::STANDARD, Engine};
-use image::{io::Reader as ImageReader, ImageFormat};
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
@ -89,7 +86,7 @@ impl Validation {
        &self,
        inputs: String,
        truncate: Option<usize>,
-    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
+    ) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
@ -115,7 +112,7 @@ impl Validation {
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
-    ) -> Result<(String, usize, u32), ValidationError> {
+    ) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
        // If we have a fast tokenizer
        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
            // Create response channel
@ -172,13 +169,13 @@ impl Validation {
            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                input_length = input_length.saturating_sub(max_new_tokens as usize);
-                // return Err(ValidationError::MaxNewTokens(
-                //     self.max_total_tokens - self.max_input_length,
-                //     max_new_tokens,
-                // ));
            }

-            Ok((inputs, input_length, max_new_tokens))
+            Ok((
+                vec![Chunk::Text(inputs).into()],
+                input_length,
+                max_new_tokens,
+            ))
        }
    }

@ -322,13 +319,13 @@ impl Validation {
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type) = match grammar {
+        let grammar = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
                    return Err(ValidationError::Grammar);
                }
-                match grammar {
+                let valid_grammar = match grammar {
                    GrammarType::Json(json) => {
                        let json = match json {
                            // if value is a string, we need to parse it again to make sure its
@ -345,20 +342,20 @@ impl Validation {
                            .compile(&json)
                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

-                        (
                        // Serialize json to string
+                        ValidGrammar::Json(
                            serde_json::to_string(&json)
                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-                            ProtoGrammarType::Json.into(),
                        )
                    }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
+                    GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
+                };
+                Some(valid_grammar)
            }
-            }
-            None => (String::new(), ProtoGrammarType::None.into()),
+            None => None,
        };

-        let parameters = NextTokenChooserParameters {
+        let parameters = ValidParameters {
            temperature,
            repetition_penalty,
            frequency_penalty,
@ -369,9 +366,8 @@ impl Validation {
            seed,
            watermark,
            grammar,
-            grammar_type,
        };
-        let stopping_parameters = StoppingCriteriaParameters {
+        let stopping_parameters = ValidStoppingParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
@ -453,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
        _ => None,
    }
}
+
fn format_to_mimetype(format: ImageFormat) -> String {
    match format {
        ImageFormat::Png => "image/png",
@ -465,7 +462,7 @@ fn format_to_mimetype(format: ImageFormat) -> String {
    .to_string()
}

-fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
+fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
    if input.starts_with("![](http://") || input.starts_with("![](https://") {
        let url = &input["![](".len()..input.len() - 1];
        let data = reqwest::blocking::get(url)?.bytes()?;
@ -476,9 +473,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
        let mimetype = format_to_mimetype(format);
-        let encoded = STANDARD.encode(data);
-        let data_uri = format!("![](data:{mimetype};base64,{encoded})");
-        Ok((data_uri, height, width))
+        Ok((data.to_vec(), mimetype, height, width))
    } else if input.starts_with("![](data:") {
        // Remove ![](....)
        let content = &input["![](data:".len()..input.len() - 1];
@ -495,9 +490,9 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {

        let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
        let img = if let Some(format) = format_from_mimetype(mimetype) {
-            ImageReader::with_format(Cursor::new(data), format).decode()?
+            ImageReader::with_format(Cursor::new(&data), format).decode()?
        } else {
-            ImageReader::new(Cursor::new(data))
+            ImageReader::new(Cursor::new(&data))
                .with_guessed_format()
                .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
                .decode()?
@ -505,7 +500,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {

        let height: usize = img.height().try_into()?;
        let width: usize = img.width().try_into()?;
-        Ok((input.to_string(), height, width))
+        Ok((data, mimetype.to_string(), height, width))
    } else {
        Err(ValidationError::InvalidImageContent(input.to_string()))
    }
@ -513,113 +508,110 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {

/// Get input length and optionally truncate it
fn prepare_input(
-    mut inputs: String,
+    inputs: String,
    _truncate: Option<usize>,
    tokenizer: &Tokenizer,
    config: &Option<Config>,
-) -> Result<(tokenizers::Encoding, String), ValidationError> {
+) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
-    let tokenizer_query = match config {
+    let (tokenizer_query, input_chunks) = match config {
        Some(Config::LlavaNext(config)) => {
-            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
-                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
-                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
+                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                tokenizer_query.push_str(&"<image>".repeat(slots));
-                modified_inputs.push_str(&image_uri);
                start = chunk_end;
            }
-            if start != inputs.len() - 1 {
-                modified_inputs.push_str(&inputs[start..]);
+            if start != inputs.len() {
+                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
-            inputs = modified_inputs;
-            tokenizer_query
+            (tokenizer_query, input_chunks)
        }
        Some(Config::Paligemma(config)) => {
-            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
-                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
-                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
+                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                tokenizer_query.push_str(&"<image>".repeat(slots));
-                modified_inputs.push_str(&image_uri);
                start = chunk_end;
            }
-            if start != inputs.len() - 1 {
-                modified_inputs.push_str(&inputs[start..]);
+            if start != inputs.len() {
+                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
-            inputs = modified_inputs;
-            tokenizer_query
+            (tokenizer_query, input_chunks)
        }
        Some(Config::Idefics2(config)) => {
-            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
-                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
-                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = config.get_number_of_features(height, width);
                tokenizer_query.push_str("<fake_token_around_image>");
                tokenizer_query.push_str(&"<image>".repeat(slots));
                tokenizer_query.push_str("<fake_token_around_image>");

-                modified_inputs.push_str(&image_uri);
+                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                start = chunk_end;
            }
-            if start != inputs.len() - 1 {
-                modified_inputs.push_str(&inputs[start..]);
+            if start != inputs.len() {
+                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
-            inputs = modified_inputs;
-            tokenizer_query
+            (tokenizer_query, input_chunks)
        }
        Some(Config::Idefics) => {
-            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
            for chunk in RE.find_iter(&inputs) {
                let chunk_start = chunk.start();
                let chunk_end = chunk.end();
                if chunk_start != start {
-                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                    tokenizer_query.push_str(&inputs[start..chunk_start]);
                }
-                let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let (data, mimetype, _height, _width) =
+                    fetch_image(&inputs[chunk_start..chunk_end])?;
                let slots = 1;
                tokenizer_query.push_str(&"<image>".repeat(slots));
-                modified_inputs.push_str(&image_uri);
+                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
                start = chunk_end;
            }
-            if start != inputs.len() - 1 {
-                modified_inputs.push_str(&inputs[start..]);
+            if start != inputs.len() {
+                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                tokenizer_query.push_str(&inputs[start..]);
            }
-            inputs = modified_inputs;
-            tokenizer_query
+            (tokenizer_query, input_chunks)
        }
-        _ => inputs.clone(),
+        _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
    };

    // Get the number of tokens in the input
@ -627,23 +619,64 @@ fn prepare_input(
        .encode(tokenizer_query, true)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

-    Ok((encoding, inputs))
+    Ok((encoding, input_chunks))
}

type TokenizerRequest = (
    (String, Option<usize>),
-    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
+    oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
    Span,
);

+#[derive(Debug, Clone)]
+pub(crate) enum ValidGrammar {
+    Json(String),
+    Regex(String),
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidParameters {
+    /// / exponential scaling output probability distribution
+    pub temperature: f32,
+    /// / restricting to the k highest probability elements
+    pub top_k: u32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub top_p: f32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub typical_p: f32,
+    /// / apply sampling on the logits
+    pub do_sample: bool,
+    /// / random seed for sampling
+    pub seed: u64,
+    /// / repetition penalty
+    pub repetition_penalty: f32,
+    /// / frequency penalty
+    pub frequency_penalty: f32,
+    /// / token watermarking using "A Watermark for Large Language Models"
+    pub watermark: bool,
+    /// / grammar (applied if not empty)
+    pub grammar: Option<ValidGrammar>,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidStoppingParameters {
+    /// / Maximum number of generated tokens
+    pub max_new_tokens: u32,
+    /// / Optional stopping sequences
+    pub stop_sequences: Vec<String>,
+    /// / Ignore end of sequence token
+    /// / used for benchmarking
+    pub ignore_eos_token: bool,
+}
+
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
-    pub inputs: String,
+    pub inputs: Vec<InputChunk>,
    pub input_length: u32,
    pub truncate: u32,
    pub decoder_input_details: bool,
-    pub parameters: NextTokenChooserParameters,
+    pub parameters: ValidParameters,
-    pub stopping_parameters: StoppingCriteriaParameters,
+    pub stopping_parameters: ValidStoppingParameters,
    pub top_n_tokens: u32,
}

@ -714,6 +747,7 @@ pub enum ValidationError {
#[cfg(test)]
mod tests {
    use super::*;
+    use crate::config::{PaliTextConfig, Paligemma};
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

@ -964,4 +998,61 @@ mod tests {

        assert_eq!(valid_request.top_n_tokens, 0);
    }
+
+    static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==";
+
+    #[tokio::test]
+    async fn test_prepare_input_chunks() {
+        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
+
+        let tokenizer = Some(get_tokenizer().await);
+
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
+        let disable_grammar_support = true;
+        let workers = 1;
+        let config = Config::Paligemma(Paligemma {
+            text_config: PaliTextConfig {
+                num_image_tokens: 1,
+            },
+        });
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            Some(config),
+            max_best_of,
+            max_stop_sequence,
+            max_top_n_tokens,
+            max_input_length,
+            max_total_tokens,
+            disable_grammar_support,
+        );
+
+        let chunks = match validation
+            .tokenize(
+                format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
+                None,
+            )
+            .await
+        {
+            Ok(Some((_encoding, chunks))) => chunks,
+            _ => panic!("Unexpected tokenization failure"),
+        };
+
+        assert!(
+            chunks
+                == vec![
+                    Chunk::Text("test".to_string()).into(),
+                    Chunk::Image(Image {
+                        data: pixel_data.clone(),
+                        mimetype: "image/gif".to_string()
+                    })
+                    .into()
+                ],
+            "Failed to process images",
+        );
+    }
}
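To make the new chunking in `prepare_input` easier to follow outside of the diff, here is a small, self-contained Rust sketch of the same idea: split the prompt on markdown-style `![](...)` image markers and keep text and image references as separate chunks. It assumes the `regex` and `once_cell` crates; `Chunk` below is a local stand-in for the generated `InputChunk` proto type, not the real one.

use once_cell::sync::Lazy;
use regex::Regex;

#[derive(Debug, PartialEq)]
enum Chunk {
    Text(String),
    // The raw `![](...)` marker; TGI decodes it into bytes + mimetype instead.
    Image(String),
}

// Same pattern the router uses to find image markers in the prompt.
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());

fn chunk_prompt(inputs: &str) -> Vec<Chunk> {
    let mut chunks = Vec::new();
    let mut start = 0;
    for m in RE.find_iter(inputs) {
        if m.start() != start {
            chunks.push(Chunk::Text(inputs[start..m.start()].to_string()));
        }
        chunks.push(Chunk::Image(inputs[m.start()..m.end()].to_string()));
        start = m.end();
    }
    // Comparing against `inputs.len()` (not `len() - 1`, as the diff fixes)
    // keeps a single trailing character after the last image marker.
    if start != inputs.len() {
        chunks.push(Chunk::Text(inputs[start..].to_string()));
    }
    chunks
}

fn main() {
    let prompt = "What is in this image? ![](data:image/gif;base64,AAAA)";
    println!("{:?}", chunk_prompt(prompt));
}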
@ -1,5 +1,5 @@
[toolchain]
-# Released on: 02 May, 2024
-# https://releases.rs/docs/1.78.0/
-channel = "1.78.0"
+# Released on: June 13, 2024
+# https://releases.rs/docs/1.79.0/
+channel = "1.79.0"
components = ["rustfmt", "clippy"]
@ -10,18 +10,26 @@ unit-tests:

gen-server:
	# Compile protos
-	pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
+	pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
	mkdir text_generation_server/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \
-		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto
+	python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
+		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
	touch text_generation_server/pb/__init__.py

-install: gen-server
+install-server: gen-server
	pip install pip --upgrade
	pip install -r requirements_cuda.txt
	pip install -e ".[bnb, accelerate, quantize, peft, outlines]"

+
+install: install-cuda
+	echo "Installed server"
+
+install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
+
+install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
+
run-dev:
	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
@ -1,16 +1,12 @@
flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec

-flash-attention:
-	# Clone flash attention
-	pip install -U packaging ninja --no-cache-dir
-	git clone https://github.com/HazyResearch/flash-attention.git
-
-build-flash-attention: flash-attention
-	cd flash-attention && git fetch && git checkout $(flash_att_commit)
-	cd flash-attention && python setup.py build
-	cd flash-attention/csrc/rotary && python setup.py build
-	cd flash-attention/csrc/layer_norm && python setup.py build
+build-flash-attention:
+	if [ ! -d 'flash-attention' ]; then \
+		pip install -U packaging ninja --no-cache-dir && \
+		git clone https://github.com/HazyResearch/flash-attention.git; \
+	fi
+	cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
+	MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build

install-flash-attention: build-flash-attention
-	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
-	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
+	cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
@ -1,29 +1,21 @@
-flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
+flash_att_v2_commit_cuda := v2.5.9.post1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6

-flash-attention-v2-cuda:
-	# Clone flash attention
-	pip install -U packaging ninja --no-cache-dir
-	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
-
-build-flash-attention-v2-cuda: flash-attention-v2-cuda
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)
-	cd flash-attention-v2 && git submodule update --init --recursive
-	cd flash-attention-v2 && python setup.py build
+build-flash-attention-v2-cuda:
+	pip install -U packaging wheel
+	pip install flash-attn==$(flash_att_v2_commit_cuda)

install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
-	cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
+	echo "Flash v2 installed"

-flash-attention-v2-rocm:
-	# Clone flash attention
-	pip install -U packaging ninja --no-cache-dir
-	git clone https://github.com/ROCm/flash-attention.git flash-attention-v2
-
-build-flash-attention-v2-rocm: flash-attention-v2-rocm
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
-	cd flash-attention-v2 && git submodule update --init --recursive
-	cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
+build-flash-attention-v2-rocm:
+	if [ ! -d 'flash-attention-v2' ]; then \
+		pip install -U packaging ninja --no-cache-dir && \
+		git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
+		cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
+		git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
+	fi

install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
-	cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
+	cd flash-attention-v2 && \
+	GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
@ -1,25 +1,23 @@
-vllm-cuda:
-	# Clone vllm
-	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/Narsil/vllm.git vllm
-
-build-vllm-cuda: vllm-cuda
-	cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
-	cd vllm && python setup.py build
+commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
+commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+build-vllm-cuda:
+	if [ ! -d 'vllm' ]; then \
+		pip install -U ninja packaging --no-cache-dir && \
+		git clone https://github.com/Narsil/vllm.git vllm; \
+	fi
+	cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build

install-vllm-cuda: build-vllm-cuda
-	pip uninstall vllm -y || true
-	cd vllm && python setup.py install
+	cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e .

-vllm-rocm:
-	# Clone vllm
-	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/fxmarty/rocm-vllm.git vllm
-
-build-vllm-rocm: vllm-rocm
-	cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
-	cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
+build-vllm-rocm:
+	if [ ! -d 'vllm' ]; then \
+		pip install -U ninja packaging --no-cache-dir && \
+		git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
+	fi
+	cd vllm && git fetch && git checkout $(commit_rocm) && \
+	PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

install-vllm-rocm: build-vllm-rocm
-	pip uninstall vllm -y || true
-	cd vllm && python setup.py install
+	cd vllm && git fetch && git checkout $(commit_rocm) && \
+	PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
20	server/marlin/COPYRIGHT	Normal file
@ -0,0 +1,20 @@
+These kernels were vendored from VLLM. The Marlin kernels were developed
+by Elias Frantar and extended by Neural Magic.
+
+---
+
+Copyright (C) Marlin.2024 Elias Frantar
+Modified by Neural Magic
+Copyright 2024 The vLLM team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
44	server/marlin/marlin_kernels/__init__.pyi	Normal file
@ -0,0 +1,44 @@
+import torch
+
+def gptq_marlin_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    g_idx: torch.Tensor,
+    perm: torch.Tensor,
+    workspace: torch.Tensor,
+    num_bits: int,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+    is_k_full: bool,
+) -> torch.Tensor:
+    """
+    Matrix multiplication using Marlin kernels. This is an extension of
+    `marlin_gemm` that supports converted GPTQ kernels.
+    """
+    ...
+
+def gptq_marlin_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    """Repack GPTQ parameters for Marlin kernels."""
+    ...
+
+def marlin_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    workspace: torch.Tensor,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+) -> torch.Tensor:
+    """
+    Matrix multiplication using Marlin kernels.
+    """
+    ...
11	server/marlin/marlin_kernels/ext.cpp	Normal file
@ -0,0 +1,11 @@
+#include <torch/extension.h>
+
+#include "ext.hh"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
+        "Marlin gemm with GPTQ compatibility");
+  m.def("gptq_marlin_repack", &gptq_marlin_repack,
+        "Repack GPTQ parameters for Marlin");
+  m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
+}
Some files were not shown because too many files have changed in this diff.